@@ -1,7 +1,11 @@
[中文](#支持批并行的LatticeLSTM)
[English](#Batch-Parallel-LatticeLSTM)
# 支持批并行的LatticeLSTM
+ Original paper: https://arxiv.org/abs/1805.02023
+ With batch size 10, this implementation is already noticeably faster than the [original code](https://github.com/jiesutd/LatticeLSTM).
+ Add the paths of the three embedding files and of the corresponding dataset in main.py and it is ready to run.
+ This code has been merged into fastNLP.
## Environment:
+ python >= 3.7.3
@@ -18,7 +22,7 @@
## Performance:
|Dataset|F1 achieved so far (test)|F1 in the paper (test)|
|:----:|:----:|:----:|
|Weibo|62.73|58.79|
|Weibo|58.66|58.79|
|Resume|95.18|94.46|
|Ontonote|73.62|73.88|
@@ -26,3 +30,36 @@
## If you have any questions, please contact:
+ lixiaonan_xdu@outlook.com
---
# Batch Parallel LatticeLSTM
+ Paper: https://arxiv.org/abs/1805.02023
+ With batch size 10, the computation speed already exceeds that of the [original code](https://github.com/jiesutd/LatticeLSTM).
+ Set the paths of the three embeddings and of the corresponding corpus before you run main.py (see the path sketch after this list).
+ This code has been merged into fastNLP.
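
A minimal sketch of the path variables that main.py ultimately reads (the names come from pathes.py and load_data.py in this repo, presumably pointing at the pretrained embeddings released with the original LatticeLSTM); every value below is a placeholder to be replaced with your local locations:

```python
# pathes.py -- placeholder values only; point these at your local copies.

# Pretrained embeddings (unigram / bigram / word lexicon):
yangjie_rich_pretrain_unigram_path = '/path/to/unigram_char_embeddings.vec'
yangjie_rich_pretrain_bigram_path = '/path/to/bigram_embeddings.vec'
yangjie_rich_pretrain_word_path = '/path/to/word_lexicon_embeddings.vec'

# Dataset roots, each containing the train/dev/test files expected by load_data.py:
ontonote4ner_cn_path = '/path/to/OntoNote4NER'
resume_ner_path = '/path/to/ResumeNER'
weibo_ner_path = '/path/to/WeiboNER'
```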
## Environment:
+ python >= 3.7.3
+ fastNLP >= dev.0.5.0
+ pytorch >= 1.1.0
+ numpy >= 1.16.4
+ fitlog >= 0.2.0
## Dataset:
+ Resume, downloaded from [here](https://github.com/jiesutd/LatticeLSTM)
+ Ontonote
+ [Weibo](https://github.com/hltcoe/golden-horse)

For datasets not included here, you can write an interface function whose output format matches that of *load_ontonotes4ner* in load_data.py, as in the sketch below.
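
A minimal sketch of such an interface function, modeled on the existing loaders in load_data.py; `load_my_ner`, the file names, and the assumption of two-column character/tag CoNLL files are illustrative only:

```python
import os
from fastNLP import Vocabulary
from fastNLP.io.loader import ConllLoader
from fastNLP.embeddings import StaticEmbedding
from utils import get_bigrams  # helper from this repo

def load_my_ner(path, unigram_embedding_path=None, bigram_embedding_path=None, index_token=True):
    # One character and one tag per line, sentences separated by blank lines.
    loader = ConllLoader(['chars', 'target'])
    datasets = {split: loader.load(os.path.join(path, fn)).datasets['train']
                for split, fn in [('train', 'train.char.bmes'),
                                  ('dev', 'dev.char.bmes'),
                                  ('test', 'test.char.bmes')]}

    for ds in datasets.values():
        ds.apply_field(get_bigrams, 'chars', 'bigrams')
        ds.add_seq_len('chars', new_field_name='seq_len')

    char_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None, unknown=None)
    char_vocab.from_dataset(datasets['train'], field_name='chars',
                            no_create_entry_dataset=[datasets['dev'], datasets['test']])
    bigram_vocab.from_dataset(datasets['train'], field_name='bigrams',
                              no_create_entry_dataset=[datasets['dev'], datasets['test']])
    label_vocab.from_dataset(datasets['train'], field_name='target')
    vocabs = {'char': char_vocab, 'bigram': bigram_vocab, 'label': label_vocab}

    if index_token:
        char_vocab.index_dataset(*datasets.values(), field_name='chars', new_field_name='chars')
        bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams', new_field_name='bigrams')
        label_vocab.index_dataset(*datasets.values(), field_name='target', new_field_name='target')

    embeddings = {}
    if unigram_embedding_path is not None:
        embeddings['char'] = StaticEmbedding(char_vocab, model_dir_or_name=unigram_embedding_path,
                                             word_dropout=0.01)
    if bigram_embedding_path is not None:
        embeddings['bigram'] = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                               word_dropout=0.01)

    # main.py consumes exactly this triple.
    return datasets, vocabs, embeddings
```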
## Performance:
|Dataset|F1 of this code (test)|F1 in paper (test)|
|:----:|:----:|:----:|
|Weibo|58.66|58.79|
|Resume|95.18|94.46|
|Ontonote|73.62|73.88|

PS: The Weibo dataset used here is V2, i.e. the revised version.
## If you have any questions, please contact:
+ lixiaonan_xdu@outlook.com
@@ -1,252 +0,0 @@
import torch.nn as nn
from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

# sys.path.append('.')
# sys.path.append('..')
# for p in sys.path:
#     print(p)
# fitlog.add_hyper_in_file (__file__)  # record your hyperparameters

########hyper
########hyper
parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cpu')
parser.add_argument('--debug',default=True)
parser.add_argument('--batch',default=1)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.015)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)#in paper it's false
parser.add_argument('--hidden',default=100)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=False)
parser.add_argument('--embed_dropout',default=0)
parser.add_argument('--output_dropout',default=0)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)
args = parser.parse_args()

set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
# fitlog.commit(__file__,fit_msg=fit_msg)
# fitlog.add_hyper(args)

device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

refresh_data = False
if args.dataset == 'ontonote':
    datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False)
elif args.dataset == 'resume':
    datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                 _refresh=refresh_data,index_token=False)
# exit()

w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
                                                         _refresh=refresh_data,_cache_fp=cache_name)

print('中:embedding:{}'.format(embeddings['char'](24)))
print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout))

# for k, v in datasets.items():
#     # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count')
#     # v.apply_field(lambda x:
#     #               list(map(lambda y:
#     #                        list(map(lambda z: vocabs['word'].to_index(z), y)), x)),
#     #               'skips_l2r_word')

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
    print('{}:{}'.format(k,len(v)))
# print(datasets['dev'][0])
# print(datasets['test'][0])
# print(datasets['train'].get_all_fields().keys())

for k,v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
        if args.bi:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source')
        v.set_padder('skips_l2r_word',LatticeLexiconPadder())
        v.set_padder('skips_l2r_source',LatticeLexiconPadder())
        v.set_input('chars','bigrams','seq_len','target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')
        v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])

if args.model =='lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
                                    hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                    embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                    skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    vocabs=vocabs
                                    )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
                          hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)

for k,v in model.state_dict().items():
    print('{}:{}'.format(k,v.size()))
# exit()

weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb'))
# print(weight_dict.keys())
# for k,v in weight_dict.items():
#     print('{}:{}'.format(k,v.size()))

def state_dict_param(model):
    param_list = list(model.named_parameters())
    print(len(param_list))
    param_dict = {}
    for i in range(len(param_list)):
        param_dict[param_list[i][0]] = param_list[i][1]
    return param_dict

def copy_yangjie_lattice_weight(target,source_dict):
    t = state_dict_param(target)
    with torch.no_grad():
        t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih'])
        t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh'])
        t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih'])
        t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh'])
        t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias'])
        t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias'])
        t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih'])
        t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh'])
        t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias'])
        t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih'])
        t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh'])
        t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih'])
        t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh'])
        t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias'])
        t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias'])
        t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih'])
        t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh'])
        t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias'])
    for k,v in t.items():
        print('{}:{}'.format(k,v))

copy_yangjie_lattice_weight(model,weight_dict)
# print(vocabs['label'].word2idx.keys())

loss = LossInForward()
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)

# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device)
# test_result = tester.test()
# print(test_result)

callbacks = [
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep))
]

print(datasets['train'][:2])
print(vocabs['char'].to_index(':'))
# StaticEmbedding
# datasets['train'] = datasets['train'][1:]
from fastNLP import SequentialSampler
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks,
                  check_code_level=-1,
                  sampler=SequentialSampler())
trainer.train()
@@ -416,10 +416,92 @@ def load_conllized_ontonote_pkl_yf(path):
    return task_lst, vocabs

@cache_results(_cache_fp='weiboNER old uni+bi', _refresh=False)
def load_weibo_ner_old(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                       normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.data_loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
    # from fastNLP.io.file_reader import _read_conll
    # from fastNLP.core import Instance,DataSet
    # def _load(path):
    #     ds = DataSet()
    #     for idx, data in _read_conll(path, indexes=loader.indexes, dropna=loader.dropna,
    #                                  encoding='ISO-8859-1'):
    #         ins = {h: data[i] for i, h in enumerate(loader.headers)}
    #         ds.append(Instance(**ins))
    #     return ds
    # from fastNLP.io.utils import check_loader_paths
    # paths = check_loader_paths(path)
    # datasets = {name: _load(path) for name, path in paths.items()}

    datasets = {}
    train_path = os.path.join(path,'train.all.bmes')
    dev_path = os.path.join(path,'dev.all.bmes')
    test_path = os.path.join(path,'test.all.bmes')
    datasets['train'] = loader.load(train_path).datasets['train']
    datasets['dev'] = loader.load(dev_path).datasets['train']
    datasets['test'] = loader.load(test_path).datasets['train']

    for k,v in datasets.items():
        print('{}:{}'.format(k,len(v)))

    vocabs = {}
    word_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)

    for k,v in datasets.items():
        # ignore the word segmentation tag appended to each character
        v.apply_field(lambda x: [w[0] for w in x],'chars','chars')
        v.apply_field(get_bigrams,'chars','bigrams')

    word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word))

    for k,v in datasets.items():
        # v.set_pad_val('target',-100)
        v.add_seq_len('chars',new_field_name='seq_len')

    vocabs['char'] = word_vocab
    vocabs['label'] = label_vocab

    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    if index_token:
        # convert tokens to vocabulary indices in place
        word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
        bigram_vocab.index_dataset(*list(datasets.values()), field_name='bigrams', new_field_name='bigrams')
        label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')

    # for k,v in datasets.items():
    #     v.set_input('chars','bigrams','seq_len','target')
    #     v.set_target('target','seq_len')

    vocabs['bigram'] = bigram_vocab

    embeddings = {}
    if unigram_embedding_path is not None:
        unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
                                            word_dropout=0.01,normalize=normlize['char'])
        embeddings['char'] = unigram_embedding
    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                           word_dropout=0.01,normalize=normlize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets, vocabs, embeddings
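
For reference, a rough usage sketch of this new loader; the paths are placeholders, and the call mirrors the `weibo_old` branch added to main.py later in this diff:

```python
# Placeholder paths; the real values live in pathes.py.
datasets, vocabs, embeddings = load_weibo_ner_old(
    '/path/to/WeiboNER_old',
    unigram_embedding_path='/path/to/unigram_char_embeddings.vec',
    bigram_embedding_path='/path/to/bigram_embeddings.vec',
    index_token=False)

print(datasets['train'][0])                    # chars / bigrams / target / seq_len fields
print({k: len(v) for k, v in vocabs.items()})  # sizes of the 'char', 'bigram', 'label' vocabularies
```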
@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False)
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                   normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.data_loader import ConllLoader
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
@@ -492,7 +574,7 @@ def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,i
@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False)
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
                       normalize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.data_loader import ConllLoader
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    train_path = os.path.join(path,'train.char.bmes')
@@ -1,6 +1,8 @@
import torch.nn as nn
# print(1111111111)
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,\
    load_resume_ner,load_weibo_ner,load_weibo_ner_old
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
@@ -18,23 +20,24 @@ from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:4')
parser.add_argument('--device',default='cuda:1')
parser.add_argument('--debug',default=False)

parser.add_argument('--norm_embed',default=True)
parser.add_argument('--batch',default=10)
parser.add_argument('--norm_embed',default=False)
parser.add_argument('--batch',default=1)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)#in paper it's false
parser.add_argument('--hidden',default=100)
parser.add_argument('--hidden',default=113)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--dataset',default='weibo',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)
parser.add_argument('--embed_dropout',default=0.5)
parser.add_argument('--gaz_dropout',default=-1)
parser.add_argument('--output_dropout',default=0.5)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)
@@ -49,8 +52,6 @@ if args.model == 'lattice':
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__,fit_msg=fit_msg)
fitlog.add_hyper(args)

device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)
@@ -78,6 +79,10 @@ elif args.dataset == 'weibo':
                                                     _refresh=refresh_data,index_token=False,
                                                     )
elif args.dataset == 'weibo_old':
    datasets,vocabs,embeddings = load_weibo_ner_old(weibo_ner_old_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )

if args.dataset == 'ontonote':
    args.batch = 10
    args.lr = 0.045
@@ -85,9 +90,18 @@ elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.batch = 10
    args.gaz_dropout = 0.1
    args.embed_dropout = 0.1
    args.output_dropout = 0.1
elif args.dataset == 'weibo_old':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1

# a negative gaz_dropout means: fall back to the embedding dropout rate
if args.gaz_dropout < 0:
    args.gaz_dropout = args.embed_dropout

fitlog.add_hyper(args)

w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)
@@ -145,7 +159,8 @@ if args.model =='lattice':
                                    hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                    embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                    skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram
                                    skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    gaz_dropout=args.gaz_dropout
                                    )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
@@ -156,11 +171,12 @@ elif args.model == 'lstm':

loss = LossInForward()
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
# Weibo labels use the BIO scheme; the other corpora use BMESO
encoding_type = 'bmeso'
if args.dataset == 'weibo':
    encoding_type = 'bio'
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]
metrics = [f1_metric,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
@@ -174,7 +190,7 @@ callbacks = [
    FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]
print('label_vocab:{}\n{}'.format(len(vocabs['label']),vocabs['label'].idx2word))

trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
@@ -3,7 +3,7 @@ from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import LSTM, ConditionalRandomField
import torch
from fastNLP import seq_len_to_mask
from utils import better_init_rnn
from utils import better_init_rnn,print_info

class LatticeLSTM_SeqLabel(nn.Module):
@@ -120,7 +120,7 @@ class LatticeLSTM_SeqLabel(nn.Module):
class LatticeLSTM_SeqLabel_V1(nn.Module):
    def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
                 device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
                 skip_before_head=False,use_bigram=True,vocabs=None):
                 skip_before_head=False,use_bigram=True,vocabs=None,gaz_dropout=0):
        if device is None:
            self.device = torch.device('cpu')
        else:
@@ -173,6 +173,7 @@ class LatticeLSTM_SeqLabel_V1(nn.Module):
        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.gaz_dropout = nn.Dropout(gaz_dropout)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, chars, bigrams, seq_len, target,
@@ -257,15 +258,22 @@ class LSTM_SeqLabel(nn.Module):
        better_init_rnn(self.encoder.lstm)

        self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
        self.debug = False
        self.debug = True
        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.output_dropout = nn.Dropout(output_dropout)
        self.crf = ConditionalRandomField(label_size, True)

    def forward(self, chars, bigrams, seq_len, target):
        if self.debug:
            # debug-only: dump the shapes of the batch tensors
            print_info('chars:{}'.format(chars.size()))
            print_info('bigrams:{}'.format(bigrams.size()))
            print_info('seq_len:{}'.format(seq_len.size()))
            print_info('target:{}'.format(target.size()))
        embed_char = self.char_embed(chars)

        if self.use_bigram:
@@ -291,6 +299,9 @@ class LSTM_SeqLabel(nn.Module):
        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        if self.debug:
            # debug-only: stop right after the first forward pass
            print('debug mode:finish')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
@@ -326,7 +326,7 @@ class MultiInputLSTMCell_V1(nn.Module):
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

        skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1])
        skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]).float()
        skip_mask = 1 - skip_mask
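
The `.float()` added here (and the switch to `~` in the next hunk) appears to be needed because recent PyTorch versions return bool masks and no longer allow `1 - mask` on a bool tensor; a tiny self-contained illustration of the two alternatives, assuming a bool mask:

```python
import torch

mask = torch.tensor([True, True, False])   # e.g. what a seq-len-to-mask helper returns

# inverted = 1 - mask                      # RuntimeError on PyTorch >= 1.2 for bool tensors
inverted_float = 1 - mask.float()          # this hunk: cast first, then subtract
inverted_bool = ~mask                      # next hunk: logical NOT, stays bool

print(inverted_float)   # tensor([0., 0., 1.])
print(inverted_bool)    # tensor([False, False,  True])
```

The float form suits places where the inverted mask enters arithmetic, while `~` keeps a bool mask, which is what `masked_fill` expects.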
@@ -622,8 +622,8 @@ class LatticeLSTMLayer_sup_back_V1(nn.Module):
            h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))

            h_1_mask = h_1.masked_fill(1-mask_for_seq_len[:,i].unsqueeze(-1),0)
            c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)
            h_1_mask = h_1.masked_fill(~ mask_for_seq_len[:,i].unsqueeze(-1),0)
            c_1_mask = c_1.masked_fill(~ mask_for_seq_len[:, i].unsqueeze(-1), 0)

            h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
@@ -20,4 +20,5 @@ sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/'
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER'
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER'
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER'
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER'
weibo_ner_old_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER_old'