@@ -1,7 +1,11 @@
[中文](#支持批并行的LatticeLSTM)
[English](#Batch-Parallel-LatticeLSTM)
# 支持批并行的LatticeLSTM
+ Original paper: https://arxiv.org/abs/1805.02023
+ With a batch size of 10, this implementation is already clearly faster than the [original code](https://github.com/jiesutd/LatticeLSTM).
+ Add the file paths of the three embeddings and the path of the corresponding dataset in main.py, then run it.
+ This code has also been added to fastNLP.
## Environment:
+ python >= 3.7.3
@@ -18,7 +22,7 @@
## Performance:
|Dataset|F1 of this code (test)|F1 in the paper (test)|
|:----:|:----:|:----:|
|Weibo|58.66|58.79|
|Resume|95.18|94.46|
|Ontonote|73.62|73.88|
@@ -26,3 +30,36 @@
## If you have any questions, please contact:
+ lixiaonan_xdu@outlook.com
---
# Batch Parallel LatticeLSTM
+ Paper: https://arxiv.org/abs/1805.02023
+ With a batch size of 10, the computation speed already exceeds that of the [original code](https://github.com/jiesutd/LatticeLSTM).
+ Set the paths of the three embeddings and of the corresponding dataset in main.py before you run it (see the path sketch below).
+ This code has also been added to fastNLP.
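A minimal sketch of that setup, assuming pathes.py (shown at the bottom of this diff) stays the single place where locations are configured; every concrete path and file name below is a placeholder for your own copies of the pretrained embeddings released with the original LatticeLSTM repo and of the datasets:

```python
# pathes.py -- placeholder values, adjust to your machine
yangjie_rich_pretrain_unigram_path = '/your/path/gigaword_chn.all.a2b.uni.ite50.vec'
yangjie_rich_pretrain_bigram_path = '/your/path/gigaword_chn.all.a2b.bi.ite50.vec'
yangjie_rich_pretrain_word_path = '/your/path/ctb.50d.vec'
ontonote4ner_cn_path = '/your/path/OntoNote4NER'
resume_ner_path = '/your/path/ResumeNER'
weibo_ner_path = '/your/path/WeiboNER'

# then, for example:
#   python main.py --dataset weibo --device cuda:0
```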
## Environment:
+ python >= 3.7.3
+ fastNLP >= dev.0.5.0
+ pytorch >= 1.1.0
+ numpy >= 1.16.4
+ fitlog >= 0.2.0
## Dataset:
+ Resume, downloaded from [here](https://github.com/jiesutd/LatticeLSTM)
+ Ontonote
+ [Weibo](https://github.com/hltcoe/golden-horse)

For datasets not included here, you can write an interface function whose output format matches that of *load_ontonotes4ner* in load_data.py; a sketch is given below.
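For illustration only, a hedged sketch of such an interface function, modelled on *load_weibo_ner_old* further down this page; `load_my_dataset` and the `train/dev/test.char.bmes` file names are hypothetical, and the caching (`@cache_results`) and token-indexing options of the real loaders are omitted:

```python
import os
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.loader import ConllLoader
from utils import get_bigrams

def load_my_dataset(path, unigram_embedding_path=None, bigram_embedding_path=None):
    # one CoNLL-style file per split: one "char label" pair per line
    loader = ConllLoader(['chars', 'target'])
    datasets = {name: loader.load(os.path.join(path, '{}.char.bmes'.format(name))).datasets['train']
                for name in ['train', 'dev', 'test']}
    for ds in datasets.values():
        ds.apply_field(get_bigrams, 'chars', 'bigrams')
        ds.add_seq_len('chars', new_field_name='seq_len')

    # vocabularies: build on train, register dev/test as no-create entries
    char_vocab, bigram_vocab = Vocabulary(), Vocabulary()
    label_vocab = Vocabulary(padding=None, unknown=None)
    char_vocab.from_dataset(datasets['train'], field_name='chars',
                            no_create_entry_dataset=[datasets['dev'], datasets['test']])
    bigram_vocab.from_dataset(datasets['train'], field_name='bigrams',
                              no_create_entry_dataset=[datasets['dev'], datasets['test']])
    label_vocab.from_dataset(datasets['train'], field_name='target')
    vocabs = {'char': char_vocab, 'bigram': bigram_vocab, 'label': label_vocab}

    # pretrained embeddings keyed the same way main.py expects
    embeddings = {}
    if unigram_embedding_path is not None:
        embeddings['char'] = StaticEmbedding(char_vocab, model_dir_or_name=unigram_embedding_path)
    if bigram_embedding_path is not None:
        embeddings['bigram'] = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path)
    return datasets, vocabs, embeddings
```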
## Performance:
|Dataset|F1 of this code (test)|F1 in the paper (test)|
|:----:|:----:|:----:|
|Weibo|58.66|58.79|
|Resume|95.18|94.46|
|Ontonote|73.62|73.88|

PS: The Weibo dataset used here is V2, i.e. the revised version.
## If you have any questions, please contact:
+ lixiaonan_xdu@outlook.com
@@ -1,252 +0,0 @@
import torch.nn as nn
from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

# sys.path.append('.')
# sys.path.append('..')
# for p in sys.path:
#     print(p)
# fitlog.add_hyper_in_file(__file__)  # record your hyperparameters

########hyper
########hyper
parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cpu')
parser.add_argument('--debug',default=True)
parser.add_argument('--batch',default=1)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.015)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in paper it's false
parser.add_argument('--hidden',default=100)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=False)
parser.add_argument('--embed_dropout',default=0)
parser.add_argument('--output_dropout',default=0)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)

args = parser.parse_args()
set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
# fitlog.commit(__file__,fit_msg=fit_msg)
# fitlog.add_hyper(args)
device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

refresh_data = False
if args.dataset == 'ontonote':
    datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False)
elif args.dataset == 'resume':
    datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                 _refresh=refresh_data,index_token=False)
# exit()

w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
                                                         _refresh=refresh_data,_cache_fp=cache_name)

print('中:embedding:{}'.format(embeddings['char'](24)))
print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout))

# for k, v in datasets.items():
#     # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count')
#     # v.apply_field(lambda x:
#     #               list(map(lambda y:
#     #                        list(map(lambda z: vocabs['word'].to_index(z), y)), x)),
#     #               'skips_l2r_word')

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
    print('{}:{}'.format(k,len(v)))
# print(datasets['dev'][0])
# print(datasets['test'][0])
# print(datasets['train'].get_all_fields().keys())

for k,v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
        if args.bi:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source')
        v.set_padder('skips_l2r_word',LatticeLexiconPadder())
        v.set_padder('skips_l2r_source',LatticeLexiconPadder())
        v.set_input('chars','bigrams','seq_len','target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)
print(datasets['dev']['skips_l2r_word'][100])

if args.model =='lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
                                    hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                    embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                    skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    vocabs=vocabs
                                    )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
                          hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)

for k,v in model.state_dict().items():
    print('{}:{}'.format(k,v.size()))
# exit()

# weights dumped from the original LatticeLSTM implementation, apparently loaded
# here for a parameter-level parity check against this reimplementation
weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb'))
# print(weight_dict.keys())
# for k,v in weight_dict.items():
#     print('{}:{}'.format(k,v.size()))

def state_dict_param(model):
    param_list = list(model.named_parameters())
    print(len(param_list))
    param_dict = {}
    for i in range(len(param_list)):
        param_dict[param_list[i][0]] = param_list[i][1]
    return param_dict

def copy_yangjie_lattice_weight(target,source_dict):
    t = state_dict_param(target)
    with torch.no_grad():
        t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih'])
        t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh'])
        t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih'])
        t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh'])
        t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias'])
        t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias'])
        t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih'])
        t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh'])
        t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias'])
        t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih'])
        t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh'])
        t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih'])
        t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh'])
        t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias'])
        t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias'])
        t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih'])
        t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh'])
        t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias'])
    for k,v in t.items():
        print('{}:{}'.format(k,v))

copy_yangjie_lattice_weight(model,weight_dict)
# print(vocabs['label'].word2idx.keys())

loss = LossInForward()

f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)

# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device)
# test_result = tester.test()
# print(test_result)

callbacks = [
    # per-epoch exponential decay: lr_ep = lr / 1.05 ** ep
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep))
]
print(datasets['train'][:2])
print(vocabs['char'].to_index(':'))
# StaticEmbedding
# datasets['train'] = datasets['train'][1:]
from fastNLP import SequentialSampler
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks,
                  check_code_level=-1,
                  sampler=SequentialSampler())
trainer.train()
@@ -416,10 +416,92 @@ def load_conllized_ontonote_pkl_yf(path):
    return task_lst, vocabs

@cache_results(_cache_fp='weiboNER old uni+bi', _refresh=False)
def load_weibo_ner_old(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                       normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.data_loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
    # from fastNLP.io.file_reader import _read_conll
    # from fastNLP.core import Instance,DataSet
    # def _load(path):
    #     ds = DataSet()
    #     for idx, data in _read_conll(path, indexes=loader.indexes, dropna=loader.dropna,
    #                                  encoding='ISO-8859-1'):
    #         ins = {h: data[i] for i, h in enumerate(loader.headers)}
    #         ds.append(Instance(**ins))
    #     return ds
    # from fastNLP.io.utils import check_loader_paths
    # paths = check_loader_paths(path)
    # datasets = {name: _load(path) for name, path in paths.items()}

    datasets = {}
    train_path = os.path.join(path,'train.all.bmes')
    dev_path = os.path.join(path,'dev.all.bmes')
    test_path = os.path.join(path,'test.all.bmes')

    datasets['train'] = loader.load(train_path).datasets['train']
    datasets['dev'] = loader.load(dev_path).datasets['train']
    datasets['test'] = loader.load(test_path).datasets['train']

    for k,v in datasets.items():
        print('{}:{}'.format(k,len(v)))

    vocabs = {}
    word_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)

    for k,v in datasets.items():
        # ignore the word segmentation tag
        v.apply_field(lambda x: [w[0] for w in x],'chars','chars')
        v.apply_field(get_bigrams,'chars','bigrams')

    word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word))

    for k,v in datasets.items():
        # v.set_pad_val('target',-100)
        v.add_seq_len('chars',new_field_name='seq_len')

    vocabs['char'] = word_vocab
    vocabs['label'] = label_vocab

    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    if index_token:
        word_vocab.index_dataset(*list(datasets.values()), field_name='raw_words', new_field_name='words')
        bigram_vocab.index_dataset(*list(datasets.values()),field_name='raw_bigrams',new_field_name='bigrams')
        label_vocab.index_dataset(*list(datasets.values()), field_name='raw_target', new_field_name='target')

    # for k,v in datasets.items():
    #     v.set_input('chars','bigrams','seq_len','target')
    #     v.set_target('target','seq_len')

    vocabs['bigram'] = bigram_vocab

    embeddings = {}
    if unigram_embedding_path is not None:
        unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
                                            word_dropout=0.01,normalize=normlize['char'])
        embeddings['char'] = unigram_embedding
    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                           word_dropout=0.01,normalize=normlize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets, vocabs, embeddings
@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False)
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                   normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
@@ -492,7 +574,7 @@ def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,i
@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False)
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
                       normalize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    train_path = os.path.join(path,'train.char.bmes')
@@ -1,6 +1,8 @@
import torch.nn as nn
# print(1111111111)
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,\
    load_resume_ner,load_weibo_ner,load_weibo_ner_old
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward

@@ -18,23 +20,24 @@ from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:1')
parser.add_argument('--debug',default=False)
parser.add_argument('--norm_embed',default=False)
parser.add_argument('--batch',default=1)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in paper it's false
parser.add_argument('--hidden',default=113)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='weibo',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)
parser.add_argument('--embed_dropout',default=0.5)
parser.add_argument('--gaz_dropout',default=-1)
parser.add_argument('--output_dropout',default=0.5)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)

@@ -49,8 +52,6 @@ if args.model == 'lattice':
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__,fit_msg=fit_msg)
fitlog.add_hyper(args)
device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

@@ -78,6 +79,10 @@ elif args.dataset == 'weibo':
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'weibo_old':
    datasets,vocabs,embeddings = load_weibo_ner_old(weibo_ner_old_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )

if args.dataset == 'ontonote':
    args.batch = 10
    args.lr = 0.045
@@ -85,9 +90,18 @@ elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.batch = 10
    args.gaz_dropout = 0.1
    args.embed_dropout = 0.1
    args.output_dropout = 0.1
elif args.dataset == 'weibo_old':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1

# a negative --gaz_dropout means "fall back to the embedding dropout rate"
if args.gaz_dropout < 0:
    args.gaz_dropout = args.embed_dropout

fitlog.add_hyper(args)

w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

@@ -145,7 +159,8 @@ if args.model =='lattice':
                                    hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                    embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                    skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    gaz_dropout=args.gaz_dropout
                                    )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],

@@ -156,11 +171,12 @@ elif args.model == 'lstm':
loss = LossInForward()
encoding_type = 'bmeso'
if args.dataset == 'weibo':
    encoding_type = 'bio'
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)

@@ -174,7 +190,7 @@ callbacks = [
    FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]
print('label_vocab:{}\n{}'.format(len(vocabs['label']),vocabs['label'].idx2word))
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,

@@ -3,7 +3,7 @@ from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import LSTM, ConditionalRandomField
import torch
from fastNLP import seq_len_to_mask
from utils import better_init_rnn,print_info

class LatticeLSTM_SeqLabel(nn.Module):

@@ -120,7 +120,7 @@ class LatticeLSTM_SeqLabel(nn.Module):
class LatticeLSTM_SeqLabel_V1(nn.Module):
    def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
                 device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
                 skip_before_head=False,use_bigram=True,vocabs=None,gaz_dropout=0):
        if device is None:
            self.device = torch.device('cpu')
        else:

@@ -173,6 +173,7 @@ class LatticeLSTM_SeqLabel_V1(nn.Module):
        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.gaz_dropout = nn.Dropout(gaz_dropout)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, chars, bigrams, seq_len, target,

@@ -257,15 +258,22 @@ class LSTM_SeqLabel(nn.Module):
        better_init_rnn(self.encoder.lstm)

        self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
        self.debug = True
        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.output_dropout = nn.Dropout(output_dropout)
        self.crf = ConditionalRandomField(label_size, True)

    def forward(self, chars, bigrams, seq_len, target):
        if self.debug:
            print_info('chars:{}'.format(chars.size()))
            print_info('bigrams:{}'.format(bigrams.size()))
            print_info('seq_len:{}'.format(seq_len.size()))
            print_info('target:{}'.format(target.size()))
        embed_char = self.char_embed(chars)

        if self.use_bigram:

@@ -291,6 +299,9 @@ class LSTM_SeqLabel(nn.Module):
        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        if self.debug:
            print('debug mode:finish')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}

@@ -326,7 +326,7 @@ class MultiInputLSTMCell_V1(nn.Module):
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
        # seq_len_to_mask gives a bool/uint8 mask on recent PyTorch; cast to float so the
        # "1 - skip_mask" arithmetic below stays valid
        skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]).float()
        skip_mask = 1 - skip_mask

@@ -622,8 +622,8 @@ class LatticeLSTMLayer_sup_back_V1(nn.Module):
            h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))
            # boolean masks cannot be negated with "1 - mask" on recent PyTorch, so invert with "~"
            h_1_mask = h_1.masked_fill(~ mask_for_seq_len[:,i].unsqueeze(-1),0)
            c_1_mask = c_1.masked_fill(~ mask_for_seq_len[:, i].unsqueeze(-1), 0)

            h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
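The two hunks above replace `1 - mask` arithmetic with an explicit `.float()` cast and bitwise negation. A minimal, self-contained sketch (not part of the repo) of the PyTorch behaviour these changes work around, assuming a recent PyTorch where comparison results and sequence-length masks are `torch.bool`:

```python
import torch

# On recent PyTorch, boolean masks cannot be negated with "1 - mask";
# either cast to float for arithmetic, or flip the mask with "~" for masked_fill.
mask = torch.tensor([[True, True, False],
                     [True, False, False]])
h = torch.randn(2, 3)

inv_float = 1 - mask.float()      # arithmetic: cast first, then 1 - mask is well defined
h_masked = h.masked_fill(~mask, 0)  # masked_fill: keep the mask boolean, invert with ~

print(inv_float)
print(h_masked)
```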
@@ -20,4 +20,5 @@ sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/'
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER'
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER'
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER'
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER'
weibo_ner_old_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER_old'