From 9cbcd74c58aa3dccd48382c1a753f3e67b672c03 Mon Sep 17 00:00:00 2001 From: LeeSureman <1349342500@QQ.com> Date: Tue, 1 Oct 2019 16:34:44 +0800 Subject: [PATCH] batch-support LatticeLSTM --- .../chinese_ner/LatticeLSTM/check_output.py | 252 ++++++ .../chinese_ner/LatticeLSTM/load_data.py | 772 ++++++++++++++++++ .../chinese_ner/LatticeLSTM/main.py | 189 +++++ .../chinese_ner/LatticeLSTM/models.py | 299 +++++++ .../chinese_ner/LatticeLSTM/modules.py | 638 +++++++++++++++ .../chinese_ner/LatticeLSTM/pathes.py | 23 + .../chinese_ner/LatticeLSTM/small.py | 126 +++ .../chinese_ner/LatticeLSTM/utils.py | 361 ++++++++ .../chinese_ner/LatticeLSTM/utils_.py | 405 +++++++++ 9 files changed, 3065 insertions(+) create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py create mode 100644 reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py new file mode 100644 index 00000000..fa8aeae3 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py @@ -0,0 +1,252 @@ +import torch.nn as nn +from pathes import * +from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner +from fastNLP.embeddings import StaticEmbedding +from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 +from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward +import torch.optim as optim +import argparse +import torch +import sys +from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ +from fastNLP import Tester +import fitlog +from fastNLP.core.callback import FitlogCallback +from utils import set_seed +import os +from fastNLP import LRScheduler +from torch.optim.lr_scheduler import LambdaLR + + + +# sys.path.append('.') +# sys.path.append('..') +# for p in sys.path: +# print(p) +# fitlog.add_hyper_in_file (__file__) # record your hyperparameters +########hyper + +########hyper + +parser = argparse.ArgumentParser() +parser.add_argument('--device',default='cpu') +parser.add_argument('--debug',default=True) + +parser.add_argument('--batch',default=1) +parser.add_argument('--test_batch',default=1024) +parser.add_argument('--optim',default='sgd',help='adam|sgd') +parser.add_argument('--lr',default=0.015) +parser.add_argument('--model',default='lattice',help='lattice|lstm') +parser.add_argument('--skip_before_head',default=False)#in paper it's false +parser.add_argument('--hidden',default=100) +parser.add_argument('--momentum',default=0) +parser.add_argument('--bi',default=True) +parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra') +parser.add_argument('--use_bigram',default=False) + +parser.add_argument('--embed_dropout',default=0) 
+parser.add_argument('--output_dropout',default=0) +parser.add_argument('--epoch',default=100) +parser.add_argument('--seed',default=100) + +args = parser.parse_args() + +set_seed(args.seed) + +fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] +if args.model == 'lattice': + fit_msg_list.append(str(args.skip_before_head)) +fit_msg = ' '.join(fit_msg_list) +# fitlog.commit(__file__,fit_msg=fit_msg) + + +# fitlog.add_hyper(args) +device = torch.device(args.device) +for k,v in args.__dict__.items(): + print(k,v) + +refresh_data = False +if args.dataset == 'ontonote': + datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, + _refresh=refresh_data,index_token=False) +elif args.dataset == 'resume': + datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, + _refresh=refresh_data,index_token=False) +# exit() +w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, + _refresh=refresh_data) + + + +cache_name = os.path.join('cache',args.dataset+'_lattice') +datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, + _refresh=refresh_data,_cache_fp=cache_name) + +print('中:embedding:{}'.format(embeddings['char'](24))) +print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout)) + +# for k, v in datasets.items(): +# # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count') +# # v.apply_field(lambda x: +# # list(map(lambda y: +# # list(map(lambda z: vocabs['word'].to_index(z), y)), x)), +# # 'skips_l2r_word') + +print(datasets['train'][0]) +print('vocab info:') +for k,v in vocabs.items(): + print('{}:{}'.format(k,len(v))) +# print(datasets['dev'][0]) +# print(datasets['test'][0]) +# print(datasets['train'].get_all_fields().keys()) +for k,v in datasets.items(): + if args.model == 'lattice': + v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') + if args.skip_before_head: + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_l2r_source',LatticeLexiconPadder()) + v.set_padder('skips_r2l_word',LatticeLexiconPadder()) + v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) + else: + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_r2l_word', LatticeLexiconPadder()) + v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) + v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) + if args.bi: + v.set_input('chars','bigrams','seq_len', + 'skips_l2r_word','skips_l2r_source','lexicon_count', + 'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', + 'target', + use_1st_ins_infer_dim_type=True) + else: + v.set_input('chars','bigrams','seq_len', + 'skips_l2r_word','skips_l2r_source','lexicon_count', + 'target', + use_1st_ins_infer_dim_type=True) + v.set_target('target','seq_len') + + v['target'].set_pad_val(0) + elif args.model == 'lstm': + v.set_ignore_type('skips_l2r_word','skips_l2r_source') + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_l2r_source',LatticeLexiconPadder()) + v.set_input('chars','bigrams','seq_len','target', + use_1st_ins_infer_dim_type=True) + v.set_target('target','seq_len') + + v['target'].set_pad_val(0) + +print(datasets['dev']['skips_l2r_word'][100]) + + +if args.model =='lattice': + model = 
LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], + hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, + embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, + skip_batch_first=True,bidirectional=args.bi,debug=args.debug, + skip_before_head=args.skip_before_head,use_bigram=args.use_bigram, + vocabs=vocabs + ) +elif args.model == 'lstm': + model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], + hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, + bidirectional=args.bi, + embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, + use_bigram=args.use_bigram) + +for k,v in model.state_dict().items(): + print('{}:{}'.format(k,v.size())) + + + +# exit() +weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')) +# print(weight_dict.keys()) +# for k,v in weight_dict.items(): +# print('{}:{}'.format(k,v.size())) +def state_dict_param(model): + param_list = list(model.named_parameters()) + print(len(param_list)) + param_dict = {} + for i in range(len(param_list)): + param_dict[param_list[i][0]] = param_list[i][1] + + return param_dict + + +def copy_yangjie_lattice_weight(target,source_dict): + t = state_dict_param(target) + with torch.no_grad(): + t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih']) + t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh']) + t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih']) + t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh']) + t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias']) + t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias']) + t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih']) + t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh']) + t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias']) + + t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih']) + t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh']) + t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih']) + t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh']) + t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias']) + t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias']) + t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih']) + t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh']) + t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias']) + + for k,v in t.items(): + print('{}:{}'.format(k,v)) + +copy_yangjie_lattice_weight(model,weight_dict) + +# print(vocabs['label'].word2idx.keys()) + + + + + + + + +loss = LossInForward() + +f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') +f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') +acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') +metrics = 
[f1_metric,f1_metric_yj,acc_metric] + +if args.optim == 'adam': + optimizer = optim.Adam(model.parameters(),lr=args.lr) +elif args.optim == 'sgd': + optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) + + + +# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device) +# test_result = tester.test() +# print(test_result) +callbacks = [ + LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep)) +] +print(datasets['train'][:2]) +print(vocabs['char'].to_index(':')) +# StaticEmbedding +# datasets['train'] = datasets['train'][1:] +from fastNLP import SequentialSampler +trainer = Trainer(datasets['train'],model, + optimizer=optimizer, + loss=loss, + metrics=metrics, + dev_data=datasets['dev'], + device=device, + batch_size=args.batch, + n_epochs=args.epoch, + dev_batch_size=args.test_batch, + callbacks=callbacks, + check_code_level=-1, + sampler=SequentialSampler()) + +trainer.train() \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py new file mode 100644 index 00000000..919f4e61 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py @@ -0,0 +1,772 @@ +from fastNLP.io import CSVLoader +from fastNLP import Vocabulary +from fastNLP import Const +import numpy as np +import fitlog +import pickle +import os +from fastNLP.embeddings import StaticEmbedding +from fastNLP import cache_results + + +@cache_results(_cache_fp='mtl16', _refresh=False) +def load_16_task(dict_path): + ''' + + :param dict_path: /remote-home/txsun/fnlp/MTL-LT/data + :return: + ''' + task_path = os.path.join(dict_path,'data.pkl') + embedding_path = os.path.join(dict_path,'word_embedding.npy') + + embedding = np.load(embedding_path).astype(np.float32) + + task_list = pickle.load(open(task_path, 'rb'))['task_lst'] + + for t in task_list: + t.train_set.rename_field('words_idx', 'words') + t.dev_set.rename_field('words_idx', 'words') + t.test_set.rename_field('words_idx', 'words') + + t.train_set.rename_field('label', 'target') + t.dev_set.rename_field('label', 'target') + t.test_set.rename_field('label', 'target') + + t.train_set.add_seq_len('words') + t.dev_set.add_seq_len('words') + t.test_set.add_seq_len('words') + + t.train_set.set_input(Const.INPUT, Const.INPUT_LEN) + t.dev_set.set_input(Const.INPUT, Const.INPUT_LEN) + t.test_set.set_input(Const.INPUT, Const.INPUT_LEN) + + return task_list,embedding + + +@cache_results(_cache_fp='SST2', _refresh=False) +def load_sst2(dict_path,embedding_path=None): + ''' + + :param dict_path: /remote-home/xnli/data/corpus/text_classification/SST-2/ + :param embedding_path: glove 300d txt + :return: + ''' + train_path = os.path.join(dict_path,'train.tsv') + dev_path = os.path.join(dict_path,'dev.tsv') + + loader = CSVLoader(headers=('words', 'target'), sep='\t') + train_data = loader.load(train_path).datasets['train'] + dev_data = loader.load(dev_path).datasets['train'] + + train_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') + dev_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') + + train_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') + dev_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') + + vocab = Vocabulary(min_freq=2) + vocab.from_dataset(train_data, field_name='words') + vocab.from_dataset(dev_data, field_name='words') + 
+ # pretrained_embedding = load_word_emb(embedding_path, 300, vocab) + + label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(train_data, field_name='target') + + label_vocab.index_dataset(train_data, field_name='target') + label_vocab.index_dataset(dev_data, field_name='target') + + vocab.index_dataset(train_data, field_name='words', new_field_name='words') + vocab.index_dataset(dev_data, field_name='words', new_field_name='words') + + train_data.set_input(Const.INPUT, Const.INPUT_LEN) + train_data.set_target(Const.TARGET) + + dev_data.set_input(Const.INPUT, Const.INPUT_LEN) + dev_data.set_target(Const.TARGET) + + if embedding_path is not None: + pretrained_embedding = load_word_emb(embedding_path, 300, vocab) + return (train_data,dev_data),(vocab,label_vocab),pretrained_embedding + + else: + return (train_data,dev_data),(vocab,label_vocab) + +@cache_results(_cache_fp='OntonotesPOS', _refresh=False) +def load_conllized_ontonote_POS(path,embedding_path=None): + from fastNLP.io.data_loader import ConllLoader + header2index = {'words':3,'POS':4,'NER':10} + headers = ['words','POS'] + + if 'NER' in headers: + print('警告!通过 load_conllized_ontonote 函数读出来的NER标签不是BIOS,是纯粹的conll格式,是错误的!') + indexes = list(map(lambda x:header2index[x],headers)) + + loader = ConllLoader(headers,indexes) + + bundle = loader.load(path) + + # print(bundle.datasets) + + train_set = bundle.datasets['train'] + dev_set = bundle.datasets['dev'] + test_set = bundle.datasets['test'] + + + + + # train_set = loader.load(os.path.join(path,'train.txt')) + # dev_set = loader.load(os.path.join(path, 'dev.txt')) + # test_set = loader.load(os.path.join(path, 'test.txt')) + + # print(len(train_set)) + + train_set.add_seq_len('words','seq_len') + dev_set.add_seq_len('words','seq_len') + test_set.add_seq_len('words','seq_len') + + + + # print(dataset['POS']) + + vocab = Vocabulary(min_freq=1) + vocab.from_dataset(train_set,field_name='words') + vocab.from_dataset(dev_set, field_name='words') + vocab.from_dataset(test_set, field_name='words') + + vocab.index_dataset(train_set,field_name='words') + vocab.index_dataset(dev_set, field_name='words') + vocab.index_dataset(test_set, field_name='words') + + + + + label_vocab_dict = {} + + for i,h in enumerate(headers): + if h == 'words': + continue + label_vocab_dict[h] = Vocabulary(min_freq=1,padding=None,unknown=None) + label_vocab_dict[h].from_dataset(train_set,field_name=h) + + label_vocab_dict[h].index_dataset(train_set,field_name=h) + label_vocab_dict[h].index_dataset(dev_set,field_name=h) + label_vocab_dict[h].index_dataset(test_set,field_name=h) + + train_set.set_input(Const.INPUT, Const.INPUT_LEN) + train_set.set_target(headers[1]) + + dev_set.set_input(Const.INPUT, Const.INPUT_LEN) + dev_set.set_target(headers[1]) + + test_set.set_input(Const.INPUT, Const.INPUT_LEN) + test_set.set_target(headers[1]) + + if len(headers) > 2: + print('警告:由于任务数量大于1,所以需要每次手动设置target!') + + + print('train:',len(train_set),'dev:',len(dev_set),'test:',len(test_set)) + + if embedding_path is not None: + pretrained_embedding = load_word_emb(embedding_path, 300, vocab) + return (train_set,dev_set,test_set),(vocab,label_vocab_dict),pretrained_embedding + else: + return (train_set, dev_set, test_set), (vocab, label_vocab_dict) + + +@cache_results(_cache_fp='OntonotesNER', _refresh=False) +def load_conllized_ontonote_NER(path,embedding_path=None): + from fastNLP.io.pipe.conll import OntoNotesNERPipe + ontoNotesNERPipe = OntoNotesNERPipe(lower=True,target_pad_val=-100) + bundle_NER = 
ontoNotesNERPipe.process_from_file(path) + + train_set_NER = bundle_NER.datasets['train'] + dev_set_NER = bundle_NER.datasets['dev'] + test_set_NER = bundle_NER.datasets['test'] + + train_set_NER.add_seq_len('words','seq_len') + dev_set_NER.add_seq_len('words','seq_len') + test_set_NER.add_seq_len('words','seq_len') + + + NER_vocab = bundle_NER.get_vocab('target') + word_vocab = bundle_NER.get_vocab('words') + + if embedding_path is not None: + + embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, + dropout=0.5,lower=True) + + + # pretrained_embedding = load_word_emb(embedding_path, 300, word_vocab) + return (train_set_NER,dev_set_NER,test_set_NER),\ + (word_vocab,NER_vocab),embed + else: + return (train_set_NER, dev_set_NER, test_set_NER), (NER_vocab, word_vocab) + +@cache_results(_cache_fp='OntonotesPOSNER', _refresh=False) + +def load_conllized_ontonote_NER_POS(path,embedding_path=None): + from fastNLP.io.pipe.conll import OntoNotesNERPipe + ontoNotesNERPipe = OntoNotesNERPipe(lower=True) + bundle_NER = ontoNotesNERPipe.process_from_file(path) + + train_set_NER = bundle_NER.datasets['train'] + dev_set_NER = bundle_NER.datasets['dev'] + test_set_NER = bundle_NER.datasets['test'] + + NER_vocab = bundle_NER.get_vocab('target') + word_vocab = bundle_NER.get_vocab('words') + + (train_set_POS,dev_set_POS,test_set_POS),(_,POS_vocab) = load_conllized_ontonote_POS(path) + POS_vocab = POS_vocab['POS'] + + train_set_NER.add_field('pos',train_set_POS['POS'],is_target=True) + dev_set_NER.add_field('pos', dev_set_POS['POS'], is_target=True) + test_set_NER.add_field('pos', test_set_POS['POS'], is_target=True) + + if train_set_NER.has_field('target'): + train_set_NER.rename_field('target','ner') + + if dev_set_NER.has_field('target'): + dev_set_NER.rename_field('target','ner') + + if test_set_NER.has_field('target'): + test_set_NER.rename_field('target','ner') + + + + if train_set_NER.has_field('pos'): + train_set_NER.rename_field('pos','posid') + if dev_set_NER.has_field('pos'): + dev_set_NER.rename_field('pos','posid') + if test_set_NER.has_field('pos'): + test_set_NER.rename_field('pos','posid') + + if train_set_NER.has_field('ner'): + train_set_NER.rename_field('ner','nerid') + if dev_set_NER.has_field('ner'): + dev_set_NER.rename_field('ner','nerid') + if test_set_NER.has_field('ner'): + test_set_NER.rename_field('ner','nerid') + + if embedding_path is not None: + + embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, + dropout=0.5,lower=True) + + return (train_set_NER,dev_set_NER,test_set_NER),\ + (word_vocab,POS_vocab,NER_vocab),embed + else: + return (train_set_NER, dev_set_NER, test_set_NER), (NER_vocab, word_vocab) + +@cache_results(_cache_fp='Ontonotes3', _refresh=True) +def load_conllized_ontonote_pkl(path,embedding_path=None): + + data_bundle = pickle.load(open(path,'rb')) + train_set = data_bundle.datasets['train'] + dev_set = data_bundle.datasets['dev'] + test_set = data_bundle.datasets['test'] + + train_set.rename_field('pos','posid') + train_set.rename_field('ner','nerid') + train_set.rename_field('chunk','chunkid') + + dev_set.rename_field('pos','posid') + dev_set.rename_field('ner','nerid') + dev_set.rename_field('chunk','chunkid') + + test_set.rename_field('pos','posid') + test_set.rename_field('ner','nerid') + test_set.rename_field('chunk','chunkid') + + + word_vocab = data_bundle.vocabs['words'] + pos_vocab = data_bundle.vocabs['pos'] + ner_vocab = data_bundle.vocabs['ner'] + chunk_vocab = 
data_bundle.vocabs['chunk'] + + + if embedding_path is not None: + + embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, + dropout=0.5,lower=True) + + return (train_set,dev_set,test_set),\ + (word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed + else: + return (train_set, dev_set, test_set), (word_vocab,ner_vocab) + # print(data_bundle) + + + + + + + + + + +# @cache_results(_cache_fp='Conll2003', _refresh=False) +# def load_conll_2003(path,embedding_path=None): +# f = open(path, 'rb') +# data_pkl = pickle.load(f) +# +# task_lst = data_pkl['task_lst'] +# vocabs = data_pkl['vocabs'] +# # word_vocab = vocabs['words'] +# # pos_vocab = vocabs['pos'] +# # chunk_vocab = vocabs['chunk'] +# # ner_vocab = vocabs['ner'] +# +# if embedding_path is not None: +# embed = StaticEmbedding(vocab=vocabs['words'], model_dir_or_name=embedding_path, word_dropout=0.01, +# dropout=0.5) +# return task_lst,vocabs,embed +# else: +# return task_lst,vocabs + +# @cache_results(_cache_fp='Conll2003_mine', _refresh=False) +@cache_results(_cache_fp='Conll2003_mine_embed_100', _refresh=True) +def load_conll_2003_mine(path,embedding_path=None,pad_val=-100): + f = open(path, 'rb') + + data_pkl = pickle.load(f) + # print(data_pkl) + # print(data_pkl) + train_set = data_pkl[0]['train'] + dev_set = data_pkl[0]['dev'] + test_set = data_pkl[0]['test'] + + train_set.set_pad_val('posid',pad_val) + train_set.set_pad_val('nerid', pad_val) + train_set.set_pad_val('chunkid', pad_val) + + dev_set.set_pad_val('posid',pad_val) + dev_set.set_pad_val('nerid', pad_val) + dev_set.set_pad_val('chunkid', pad_val) + + test_set.set_pad_val('posid',pad_val) + test_set.set_pad_val('nerid', pad_val) + test_set.set_pad_val('chunkid', pad_val) + + if train_set.has_field('task_id'): + + train_set.delete_field('task_id') + + if dev_set.has_field('task_id'): + dev_set.delete_field('task_id') + + if test_set.has_field('task_id'): + test_set.delete_field('task_id') + + if train_set.has_field('words_idx'): + train_set.rename_field('words_idx','words') + + if dev_set.has_field('words_idx'): + dev_set.rename_field('words_idx','words') + + if test_set.has_field('words_idx'): + test_set.rename_field('words_idx','words') + + + + word_vocab = data_pkl[1]['words'] + pos_vocab = data_pkl[1]['pos'] + ner_vocab = data_pkl[1]['ner'] + chunk_vocab = data_pkl[1]['chunk'] + + if embedding_path is not None: + embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, + dropout=0.5,lower=True) + return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed + else: + return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab) + + +def load_conllized_ontonote_pkl_yf(path): + def init_task(task): + task_name = task.task_name + for ds in [task.train_set, task.dev_set, task.test_set]: + if ds.has_field('words'): + ds.rename_field('words', 'x') + else: + ds.rename_field('words_idx', 'x') + if ds.has_field('label'): + ds.rename_field('label', 'y') + else: + ds.rename_field(task_name, 'y') + ds.set_input('x', 'y', 'task_id') + ds.set_target('y') + + if task_name in ['ner', 'chunk'] or 'pos' in task_name: + ds.set_input('seq_len') + ds.set_target('seq_len') + return task + #/remote-home/yfshao/workdir/datasets/conll03/data.pkl + def pload(fn): + with open(fn, 'rb') as f: + return pickle.load(f) + + DB = pload(path) + task_lst = DB['task_lst'] + vocabs = DB['vocabs'] + task_lst = [init_task(task) for task in task_lst] + + return task_lst, vocabs + + 
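# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original patch): the
# `equip_chinese_ner_with_skip` function defined later in this file relies on
# `Trie` / `get_skip_path` from utils_.py to find every lexicon word matching
# a span of the character sequence, then stores, per character position, the
# matched words ending there (skips_l2r) and starting there (skips_r2l).
# The helper below is a minimal, brute-force stand-in (no trie) that builds
# the same kind of structure, assuming the lexicon is a plain Python set of
# multi-character words; names and the example lexicon are hypothetical.
def toy_skip_paths(chars, lexicon, max_word_len=10):
    """Return (skips_l2r, skips_r2l) for a list of characters.

    skips_l2r[e] = list of [s, word] with word == ''.join(chars[s:e+1]) in lexicon
    skips_r2l[s] = list of [e, word] for the same matches, indexed by start.
    """
    n = len(chars)
    skips_l2r = [[] for _ in range(n)]
    skips_r2l = [[] for _ in range(n)]
    for s in range(n):
        # only multi-character words, mirroring drop_characters=True above
        for e in range(s + 1, min(n, s + max_word_len)):
            word = ''.join(chars[s:e + 1])
            if word in lexicon:
                skips_l2r[e].append([s, word])
                skips_r2l[s].append([e, word])
    return skips_l2r, skips_r2l

# Example with a hypothetical lexicon: for chars = list('南京市长江大桥') and
# lexicon = {'南京', '南京市', '长江', '长江大桥', '大桥'}, position 2 ('市')
# gets [[0, '南京市']] in skips_l2r, i.e. the lexicon word '南京市' ends there,
# which is exactly the per-position word information the lattice cells consume.
# ---------------------------------------------------------------------------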
+@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False) +def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True, + normlize={'char':True,'bigram':True,'word':False}): + from fastNLP.io.data_loader import ConllLoader + from utils import get_bigrams + + loader = ConllLoader(['chars','target']) + bundle = loader.load(path) + + datasets = bundle.datasets + for k,v in datasets.items(): + print('{}:{}'.format(k,len(v))) + # print(*list(datasets.keys())) + vocabs = {} + word_vocab = Vocabulary() + bigram_vocab = Vocabulary() + label_vocab = Vocabulary(padding=None,unknown=None) + + for k,v in datasets.items(): + # ignore the word segmentation tag + v.apply_field(lambda x: [w[0] for w in x],'chars','chars') + v.apply_field(get_bigrams,'chars','bigrams') + + + word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']]) + label_vocab.from_dataset(datasets['train'],field_name='target') + print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word)) + + + for k,v in datasets.items(): + # v.set_pad_val('target',-100) + v.add_seq_len('chars',new_field_name='seq_len') + + + vocabs['char'] = word_vocab + vocabs['label'] = label_vocab + + + bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']]) + if index_token: + word_vocab.index_dataset(*list(datasets.values()), field_name='raw_words', new_field_name='words') + bigram_vocab.index_dataset(*list(datasets.values()),field_name='raw_bigrams',new_field_name='bigrams') + label_vocab.index_dataset(*list(datasets.values()), field_name='raw_target', new_field_name='target') + + # for k,v in datasets.items(): + # v.set_input('chars','bigrams','seq_len','target') + # v.set_target('target','seq_len') + + vocabs['bigram'] = bigram_vocab + + embeddings = {} + + if unigram_embedding_path is not None: + unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path, + word_dropout=0.01,normalize=normlize['char']) + embeddings['char'] = unigram_embedding + + if bigram_embedding_path is not None: + bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path, + word_dropout=0.01,normalize=normlize['bigram']) + embeddings['bigram'] = bigram_embedding + + return datasets, vocabs, embeddings + + + +# datasets,vocabs = load_weibo_ner('/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo') +# +# print(datasets['train'][:5]) +# print(vocabs['word'].idx2word) +# print(vocabs['target'].idx2word) + + +@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False) +def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, + normalize={'char':True,'bigram':True,'word':False}): + from fastNLP.io.data_loader import ConllLoader + from utils import get_bigrams + + train_path = os.path.join(path,'train.char.bmes') + dev_path = os.path.join(path,'dev.char.bmes') + test_path = os.path.join(path,'test.char.bmes') + + loader = ConllLoader(['chars','target']) + train_bundle = loader.load(train_path) + dev_bundle = loader.load(dev_path) + test_bundle = loader.load(test_path) + + + datasets = dict() + datasets['train'] = train_bundle.datasets['train'] + datasets['dev'] = dev_bundle.datasets['train'] + datasets['test'] = test_bundle.datasets['train'] + + + datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') + datasets['dev'].apply_field(get_bigrams, field_name='chars', 
new_field_name='bigrams') + datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') + + datasets['train'].add_seq_len('chars') + datasets['dev'].add_seq_len('chars') + datasets['test'].add_seq_len('chars') + + + + char_vocab = Vocabulary() + bigram_vocab = Vocabulary() + label_vocab = Vocabulary(padding=None,unknown=None) + print(datasets.keys()) + print(len(datasets['dev'])) + print(len(datasets['test'])) + print(len(datasets['train'])) + char_vocab.from_dataset(datasets['train'],field_name='chars', + no_create_entry_dataset=[datasets['dev'],datasets['test']] ) + bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', + no_create_entry_dataset=[datasets['dev'],datasets['test']]) + label_vocab.from_dataset(datasets['train'],field_name='target') + if index_token: + char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='chars',new_field_name='chars') + bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='bigrams',new_field_name='bigrams') + label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='target',new_field_name='target') + + vocabs = {} + vocabs['char'] = char_vocab + vocabs['label'] = label_vocab + vocabs['bigram'] = bigram_vocab + vocabs['label'] = label_vocab + + embeddings = {} + if char_embedding_path is not None: + char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01, + normalize=normalize['char']) + embeddings['char'] = char_embedding + + if bigram_embedding_path is not None: + bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01, + normalize=normalize['bigram']) + embeddings['bigram'] = bigram_embedding + + return datasets,vocabs,embeddings + + + +@cache_results(_cache_fp='cache/resume_ner',_refresh=False) +def load_resume_ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, + normalize={'char':True,'bigram':True,'word':False}): + from fastNLP.io.data_loader import ConllLoader + from utils import get_bigrams + + train_path = os.path.join(path,'train.char.bmes') + dev_path = os.path.join(path,'dev.char.bmes') + test_path = os.path.join(path,'test.char.bmes') + + loader = ConllLoader(['chars','target']) + train_bundle = loader.load(train_path) + dev_bundle = loader.load(dev_path) + test_bundle = loader.load(test_path) + + + datasets = dict() + datasets['train'] = train_bundle.datasets['train'] + datasets['dev'] = dev_bundle.datasets['train'] + datasets['test'] = test_bundle.datasets['train'] + + + datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') + datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') + datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') + + datasets['train'].add_seq_len('chars') + datasets['dev'].add_seq_len('chars') + datasets['test'].add_seq_len('chars') + + + + char_vocab = Vocabulary() + bigram_vocab = Vocabulary() + label_vocab = Vocabulary(padding=None,unknown=None) + print(datasets.keys()) + print(len(datasets['dev'])) + print(len(datasets['test'])) + print(len(datasets['train'])) + char_vocab.from_dataset(datasets['train'],field_name='chars', + no_create_entry_dataset=[datasets['dev'],datasets['test']] ) + bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', + no_create_entry_dataset=[datasets['dev'],datasets['test']]) + label_vocab.from_dataset(datasets['train'],field_name='target') + if 
index_token: + char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='chars',new_field_name='chars') + bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='bigrams',new_field_name='bigrams') + label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], + field_name='target',new_field_name='target') + + vocabs = {} + vocabs['char'] = char_vocab + vocabs['label'] = label_vocab + vocabs['bigram'] = bigram_vocab + + embeddings = {} + if char_embedding_path is not None: + char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,normalize=normalize['char']) + embeddings['char'] = char_embedding + + if bigram_embedding_path is not None: + bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,normalize=normalize['bigram']) + embeddings['bigram'] = bigram_embedding + + return datasets,vocabs,embeddings + + +@cache_results(_cache_fp='need_to_defined_fp',_refresh=False) +def equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,word_embedding_path=None, + normalize={'char':True,'bigram':True,'word':False}): + from utils_ import Trie,get_skip_path + from functools import partial + w_trie = Trie() + for w in w_list: + w_trie.insert(w) + + # for k,v in datasets.items(): + # v.apply_field(partial(get_skip_path,w_trie=w_trie),'chars','skips') + + def skips2skips_l2r(chars,w_trie): + ''' + + :param lexicons: list[[int,int,str]] + :return: skips_l2r + ''' + # print(lexicons) + # print('******') + + lexicons = get_skip_path(chars,w_trie=w_trie) + + + # max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 + + result = [[] for _ in range(len(chars))] + + for lex in lexicons: + s = lex[0] + e = lex[1] + w = lex[2] + + result[e].append([s,w]) + + return result + + def skips2skips_r2l(chars,w_trie): + ''' + + :param lexicons: list[[int,int,str]] + :return: skips_l2r + ''' + # print(lexicons) + # print('******') + + lexicons = get_skip_path(chars,w_trie=w_trie) + + + # max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 + + result = [[] for _ in range(len(chars))] + + for lex in lexicons: + s = lex[0] + e = lex[1] + w = lex[2] + + result[s].append([e,w]) + + return result + + for k,v in datasets.items(): + v.apply_field(partial(skips2skips_l2r,w_trie=w_trie),'chars','skips_l2r') + + for k,v in datasets.items(): + v.apply_field(partial(skips2skips_r2l,w_trie=w_trie),'chars','skips_r2l') + + # print(v['skips_l2r'][0]) + word_vocab = Vocabulary() + word_vocab.add_word_lst(w_list) + vocabs['word'] = word_vocab + for k,v in datasets.items(): + v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_l2r','skips_l2r_source') + v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_l2r', 'skips_l2r_word') + + for k,v in datasets.items(): + v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_r2l','skips_r2l_source') + v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_r2l', 'skips_r2l_word') + + for k,v in datasets.items(): + v.apply_field(lambda x:list(map(len,x)), 'skips_l2r_word', 'lexicon_count') + v.apply_field(lambda x: + list(map(lambda y: + list(map(lambda z:word_vocab.to_index(z),y)),x)), + 'skips_l2r_word',new_field_name='skips_l2r_word') + + v.apply_field(lambda x:list(map(len,x)), 'skips_r2l_word', 'lexicon_count_back') + + v.apply_field(lambda x: + list(map(lambda y: + list(map(lambda 
z:word_vocab.to_index(z),y)),x)), + 'skips_r2l_word',new_field_name='skips_r2l_word') + + + + + + if word_embedding_path is not None: + word_embedding = StaticEmbedding(word_vocab,word_embedding_path,word_dropout=0,normalize=normalize['word']) + embeddings['word'] = word_embedding + + vocabs['char'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], + field_name='chars', new_field_name='chars') + vocabs['bigram'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], + field_name='bigrams', new_field_name='bigrams') + vocabs['label'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], + field_name='target', new_field_name='target') + + return datasets,vocabs,embeddings + + + +@cache_results(_cache_fp='cache/load_yangjie_rich_pretrain_word_list',_refresh=False) +def load_yangjie_rich_pretrain_word_list(embedding_path,drop_characters=True): + f = open(embedding_path,'r') + lines = f.readlines() + w_list = [] + for line in lines: + splited = line.strip().split(' ') + w = splited[0] + w_list.append(w) + + if drop_characters: + w_list = list(filter(lambda x:len(x) != 1, w_list)) + + return w_list + + + +# from pathes import * +# +# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path, +# yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path) +# print(datasets.keys()) +# print(vocabs.keys()) +# print(embeddings) +# yangjie_rich_pretrain_word_path +# datasets['train'].set_pad_val \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py new file mode 100644 index 00000000..f5006bde --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py @@ -0,0 +1,189 @@ +import torch.nn as nn +# from pathes import * +from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner +from fastNLP.embeddings import StaticEmbedding +from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 +from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward +import torch.optim as optim +import argparse +import torch +import sys +from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ +from fastNLP import Tester +import fitlog +from fastNLP.core.callback import FitlogCallback +from utils import set_seed +import os +from fastNLP import LRScheduler +from torch.optim.lr_scheduler import LambdaLR + +parser = argparse.ArgumentParser() +parser.add_argument('--device',default='cuda:4') +parser.add_argument('--debug',default=False) + +parser.add_argument('--norm_embed',default=True) +parser.add_argument('--batch',default=10) +parser.add_argument('--test_batch',default=1024) +parser.add_argument('--optim',default='sgd',help='adam|sgd') +parser.add_argument('--lr',default=0.045) +parser.add_argument('--model',default='lattice',help='lattice|lstm') +parser.add_argument('--skip_before_head',default=False)#in paper it's false +parser.add_argument('--hidden',default=100) +parser.add_argument('--momentum',default=0) +parser.add_argument('--bi',default=True) +parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra') +parser.add_argument('--use_bigram',default=True) + +parser.add_argument('--embed_dropout',default=0.5) +parser.add_argument('--output_dropout',default=0.5) +parser.add_argument('--epoch',default=100) +parser.add_argument('--seed',default=100) + +args = 
parser.parse_args() + +set_seed(args.seed) + +fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] +if args.model == 'lattice': + fit_msg_list.append(str(args.skip_before_head)) +fit_msg = ' '.join(fit_msg_list) +fitlog.commit(__file__,fit_msg=fit_msg) + + +fitlog.add_hyper(args) +device = torch.device(args.device) +for k,v in args.__dict__.items(): + print(k,v) + +refresh_data = False + + +from pathes import * +# ontonote4ner_cn_path = 0 +# yangjie_rich_pretrain_unigram_path = 0 +# yangjie_rich_pretrain_bigram_path = 0 +# resume_ner_path = 0 +# weibo_ner_path = 0 + +if args.dataset == 'ontonote': + datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, + _refresh=refresh_data,index_token=False, + ) +elif args.dataset == 'resume': + datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, + _refresh=refresh_data,index_token=False, + ) +elif args.dataset == 'weibo': + datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, + _refresh=refresh_data,index_token=False, + ) + +if args.dataset == 'ontonote': + args.batch = 10 + args.lr = 0.045 +elif args.dataset == 'resume': + args.batch = 1 + args.lr = 0.015 +elif args.dataset == 'weibo': + args.embed_dropout = 0.1 + args.output_dropout = 0.1 + +w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, + _refresh=refresh_data) + +cache_name = os.path.join('cache',args.dataset+'_lattice') +datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, + _refresh=refresh_data,_cache_fp=cache_name) + +print(datasets['train'][0]) +print('vocab info:') +for k,v in vocabs.items(): + print('{}:{}'.format(k,len(v))) + +for k,v in datasets.items(): + if args.model == 'lattice': + v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') + if args.skip_before_head: + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_l2r_source',LatticeLexiconPadder()) + v.set_padder('skips_r2l_word',LatticeLexiconPadder()) + v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) + else: + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_r2l_word', LatticeLexiconPadder()) + v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) + v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) + if args.bi: + v.set_input('chars','bigrams','seq_len', + 'skips_l2r_word','skips_l2r_source','lexicon_count', + 'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', + 'target', + use_1st_ins_infer_dim_type=True) + else: + v.set_input('chars','bigrams','seq_len', + 'skips_l2r_word','skips_l2r_source','lexicon_count', + 'target', + use_1st_ins_infer_dim_type=True) + v.set_target('target','seq_len') + + v['target'].set_pad_val(0) + elif args.model == 'lstm': + v.set_ignore_type('skips_l2r_word','skips_l2r_source') + v.set_padder('skips_l2r_word',LatticeLexiconPadder()) + v.set_padder('skips_l2r_source',LatticeLexiconPadder()) + v.set_input('chars','bigrams','seq_len','target', + use_1st_ins_infer_dim_type=True) + v.set_target('target','seq_len') + + v['target'].set_pad_val(0) + +print(datasets['dev']['skips_l2r_word'][100]) + + +if args.model =='lattice': + model = 
LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], + hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, + embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, + skip_batch_first=True,bidirectional=args.bi,debug=args.debug, + skip_before_head=args.skip_before_head,use_bigram=args.use_bigram + ) +elif args.model == 'lstm': + model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], + hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, + bidirectional=args.bi, + embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, + use_bigram=args.use_bigram) + + +loss = LossInForward() + +f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') +f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') +acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') +metrics = [f1_metric,f1_metric_yj,acc_metric] + +if args.optim == 'adam': + optimizer = optim.Adam(model.parameters(),lr=args.lr) +elif args.optim == 'sgd': + optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) + + + + +callbacks = [ + FitlogCallback({'test':datasets['test'],'train':datasets['train']}), + LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep)) +] + +trainer = Trainer(datasets['train'],model, + optimizer=optimizer, + loss=loss, + metrics=metrics, + dev_data=datasets['dev'], + device=device, + batch_size=args.batch, + n_epochs=args.epoch, + dev_batch_size=args.test_batch, + callbacks=callbacks) + +trainer.train() \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py new file mode 100644 index 00000000..f0f912d9 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py @@ -0,0 +1,299 @@ +import torch.nn as nn +from fastNLP.embeddings import StaticEmbedding +from fastNLP.modules import LSTM, ConditionalRandomField +import torch +from fastNLP import seq_len_to_mask +from utils import better_init_rnn + + +class LatticeLSTM_SeqLabel(nn.Module): + def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, + device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, + skip_before_head=False,use_bigram=True,vocabs=None): + if device is None: + self.device = torch.device('cpu') + else: + self.device = torch.device(device) + from modules import LatticeLSTMLayer_sup_back_V0 + super().__init__() + self.debug = debug + self.skip_batch_first = skip_batch_first + self.char_embed_size = char_embed.embedding.weight.size(1) + self.bigram_embed_size = bigram_embed.embedding.weight.size(1) + self.word_embed_size = word_embed.embedding.weight.size(1) + self.hidden_size = hidden_size + self.label_size = label_size + self.bidirectional = bidirectional + self.use_bigram = use_bigram + self.vocabs = vocabs + + if self.use_bigram: + self.input_size = self.char_embed_size + self.bigram_embed_size + else: + self.input_size = self.char_embed_size + + self.char_embed = char_embed + self.bigram_embed = bigram_embed + self.word_embed = word_embed + self.encoder = LatticeLSTMLayer_sup_back_V0(self.input_size,self.word_embed_size, + self.hidden_size, + left2right=True, + bias=bias, + device=self.device, + debug=self.debug, + 
skip_before_head=skip_before_head) + if self.bidirectional: + self.encoder_back = LatticeLSTMLayer_sup_back_V0(self.input_size, + self.word_embed_size, self.hidden_size, + left2right=False, + bias=bias, + device=self.device, + debug=self.debug, + skip_before_head=skip_before_head) + + self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) + self.crf = ConditionalRandomField(label_size, True) + + self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) + if self.crf.include_start_end_trans: + self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) + self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) + + self.loss_func = nn.CrossEntropyLoss() + self.embed_dropout = nn.Dropout(embed_dropout) + self.output_dropout = nn.Dropout(output_dropout) + + def forward(self, chars, bigrams, seq_len, target, + skips_l2r_source, skips_l2r_word, lexicon_count, + skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): + # print('skips_l2r_word_id:{}'.format(skips_l2r_word.size())) + batch = chars.size(0) + max_seq_len = chars.size(1) + # max_lexicon_count = skips_l2r_word.size(2) + + + embed_char = self.char_embed(chars) + if self.use_bigram: + + embed_bigram = self.bigram_embed(bigrams) + + embedding = torch.cat([embed_char, embed_bigram], dim=-1) + else: + + embedding = embed_char + + + embed_nonword = self.embed_dropout(embedding) + + # skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) + embed_word = self.word_embed(skips_l2r_word) + embed_word = self.embed_dropout(embed_word) + # embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1]) + + + encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) + + if self.bidirectional: + embed_word_back = self.word_embed(skips_r2l_word) + embed_word_back = self.embed_dropout(embed_word_back) + encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, + embed_word_back, lexicon_count_back) + encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) + + encoded_h = self.output_dropout(encoded_h) + + pred = self.output(encoded_h) + + mask = seq_len_to_mask(seq_len) + + if self.training: + loss = self.crf(pred, target, mask) + return {'loss': loss} + else: + pred, path = self.crf.viterbi_decode(pred, mask) + return {'pred': pred} + + # batch_size, sent_len = pred.shape[0], pred.shape[1] + # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) + # return {'pred':pred,'loss':loss} + +class LatticeLSTM_SeqLabel_V1(nn.Module): + def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, + device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, + skip_before_head=False,use_bigram=True,vocabs=None): + if device is None: + self.device = torch.device('cpu') + else: + self.device = torch.device(device) + from modules import LatticeLSTMLayer_sup_back_V1 + super().__init__() + self.count = 0 + self.debug = debug + self.skip_batch_first = skip_batch_first + self.char_embed_size = char_embed.embedding.weight.size(1) + self.bigram_embed_size = bigram_embed.embedding.weight.size(1) + self.word_embed_size = word_embed.embedding.weight.size(1) + self.hidden_size = hidden_size + self.label_size = label_size + self.bidirectional = bidirectional + self.use_bigram = use_bigram + 
self.vocabs = vocabs + + if self.use_bigram: + self.input_size = self.char_embed_size + self.bigram_embed_size + else: + self.input_size = self.char_embed_size + + self.char_embed = char_embed + self.bigram_embed = bigram_embed + self.word_embed = word_embed + self.encoder = LatticeLSTMLayer_sup_back_V1(self.input_size,self.word_embed_size, + self.hidden_size, + left2right=True, + bias=bias, + device=self.device, + debug=self.debug, + skip_before_head=skip_before_head) + if self.bidirectional: + self.encoder_back = LatticeLSTMLayer_sup_back_V1(self.input_size, + self.word_embed_size, self.hidden_size, + left2right=False, + bias=bias, + device=self.device, + debug=self.debug, + skip_before_head=skip_before_head) + + self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) + self.crf = ConditionalRandomField(label_size, True) + + self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) + if self.crf.include_start_end_trans: + self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) + self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) + + self.loss_func = nn.CrossEntropyLoss() + self.embed_dropout = nn.Dropout(embed_dropout) + self.output_dropout = nn.Dropout(output_dropout) + + def forward(self, chars, bigrams, seq_len, target, + skips_l2r_source, skips_l2r_word, lexicon_count, + skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): + + batch = chars.size(0) + max_seq_len = chars.size(1) + + + + embed_char = self.char_embed(chars) + if self.use_bigram: + + embed_bigram = self.bigram_embed(bigrams) + + embedding = torch.cat([embed_char, embed_bigram], dim=-1) + else: + + embedding = embed_char + + + embed_nonword = self.embed_dropout(embedding) + + # skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) + embed_word = self.word_embed(skips_l2r_word) + embed_word = self.embed_dropout(embed_word) + + + + encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) + + if self.bidirectional: + embed_word_back = self.word_embed(skips_r2l_word) + embed_word_back = self.embed_dropout(embed_word_back) + encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, + embed_word_back, lexicon_count_back) + encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) + + encoded_h = self.output_dropout(encoded_h) + + pred = self.output(encoded_h) + + mask = seq_len_to_mask(seq_len) + + if self.training: + loss = self.crf(pred, target, mask) + return {'loss': loss} + else: + pred, path = self.crf.viterbi_decode(pred, mask) + return {'pred': pred} + + +class LSTM_SeqLabel(nn.Module): + def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, + bidirectional=False, device=None, embed_dropout=0, output_dropout=0,use_bigram=True): + + if device is None: + self.device = torch.device('cpu') + else: + self.device = torch.device(device) + super().__init__() + self.char_embed_size = char_embed.embedding.weight.size(1) + self.bigram_embed_size = bigram_embed.embedding.weight.size(1) + self.word_embed_size = word_embed.embedding.weight.size(1) + self.hidden_size = hidden_size + self.label_size = label_size + self.bidirectional = bidirectional + self.use_bigram = use_bigram + + self.char_embed = char_embed + self.bigram_embed = bigram_embed + self.word_embed = word_embed + + if self.use_bigram: + self.input_size = self.char_embed_size + 
self.bigram_embed_size + else: + self.input_size = self.char_embed_size + + self.encoder = LSTM(self.input_size, self.hidden_size, + bidirectional=self.bidirectional) + + better_init_rnn(self.encoder.lstm) + + self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) + + self.debug = False + self.loss_func = nn.CrossEntropyLoss() + self.embed_dropout = nn.Dropout(embed_dropout) + self.output_dropout = nn.Dropout(output_dropout) + self.crf = ConditionalRandomField(label_size, True) + + def forward(self, chars, bigrams, seq_len, target): + embed_char = self.char_embed(chars) + + if self.use_bigram: + + embed_bigram = self.bigram_embed(bigrams) + + embedding = torch.cat([embed_char, embed_bigram], dim=-1) + else: + + embedding = embed_char + + embedding = self.embed_dropout(embedding) + + encoded_h, encoded_c = self.encoder(embedding, seq_len) + + encoded_h = self.output_dropout(encoded_h) + + pred = self.output(encoded_h) + + mask = seq_len_to_mask(seq_len) + + # pred = self.crf(pred) + + # batch_size, sent_len = pred.shape[0], pred.shape[1] + # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) + if self.training: + loss = self.crf(pred, target, mask) + return {'loss': loss} + else: + pred, path = self.crf.viterbi_decode(pred, mask) + return {'pred': pred} diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py new file mode 100644 index 00000000..84e21dc5 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py @@ -0,0 +1,638 @@ +import torch.nn as nn +import torch +from fastNLP.core.utils import seq_len_to_mask +from utils import better_init_rnn +import numpy as np + + +class WordLSTMCell_yangjie(nn.Module): + + """A basic LSTM cell.""" + + def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True): + """ + Most parts are copied from torch.nn.LSTMCell. + """ + + super().__init__() + self.left2right = left2right + self.debug = debug + self.input_size = input_size + self.hidden_size = hidden_size + self.use_bias = use_bias + self.weight_ih = nn.Parameter( + torch.FloatTensor(input_size, 3 * hidden_size)) + self.weight_hh = nn.Parameter( + torch.FloatTensor(hidden_size, 3 * hidden_size)) + if use_bias: + self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize parameters following the way proposed in the paper. + """ + nn.init.orthogonal(self.weight_ih.data) + weight_hh_data = torch.eye(self.hidden_size) + weight_hh_data = weight_hh_data.repeat(1, 3) + with torch.no_grad(): + self.weight_hh.set_(weight_hh_data) + # The bias is just set to zero vectors. + if self.use_bias: + nn.init.constant(self.bias.data, val=0) + + def forward(self, input_, hx): + """ + Args: + input_: A (batch, input_size) tensor containing input + features. + hx: A tuple (h_0, c_0), which contains the initial hidden + and cell state, where the size of both states is + (batch, hidden_size). + Returns: + h_1, c_1: Tensors containing the next hidden and cell state. 
+ """ + + h_0, c_0 = hx + + + + batch_size = h_0.size(0) + bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) + wh_b = torch.addmm(bias_batch, h_0, self.weight_hh) + wi = torch.mm(input_, self.weight_ih) + f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1) + c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) + + return c_1 + + def __repr__(self): + s = '{name}({input_size}, {hidden_size})' + return s.format(name=self.__class__.__name__, **self.__dict__) + + +class MultiInputLSTMCell_V0(nn.Module): + def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): + super().__init__() + self.char_input_size = char_input_size + self.hidden_size = hidden_size + self.use_bias = use_bias + + self.weight_ih = nn.Parameter( + torch.FloatTensor(char_input_size, 3 * hidden_size) + ) + + self.weight_hh = nn.Parameter( + torch.FloatTensor(hidden_size, 3 * hidden_size) + ) + + self.alpha_weight_ih = nn.Parameter( + torch.FloatTensor(char_input_size, hidden_size) + ) + + self.alpha_weight_hh = nn.Parameter( + torch.FloatTensor(hidden_size, hidden_size) + ) + + if self.use_bias: + self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) + self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) + else: + self.register_parameter('bias', None) + self.register_parameter('alpha_bias', None) + + self.debug = debug + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize parameters following the way proposed in the paper. + """ + nn.init.orthogonal(self.weight_ih.data) + nn.init.orthogonal(self.alpha_weight_ih.data) + + weight_hh_data = torch.eye(self.hidden_size) + weight_hh_data = weight_hh_data.repeat(1, 3) + with torch.no_grad(): + self.weight_hh.set_(weight_hh_data) + + alpha_weight_hh_data = torch.eye(self.hidden_size) + alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) + with torch.no_grad(): + self.alpha_weight_hh.set_(alpha_weight_hh_data) + + # The bias is just set to zero vectors. 
+ if self.use_bias: + nn.init.constant_(self.bias.data, val=0) + nn.init.constant_(self.alpha_bias.data, val=0) + + def forward(self, inp, skip_c, skip_count, hx): + ''' + + :param inp: chars B * hidden + :param skip_c: 由跳边得到的c, B * X * hidden + :param skip_count: 这个batch中每个example中当前位置的跳边的数量,用于mask + :param hx: + :return: + ''' + max_skip_count = torch.max(skip_count).item() + + + + if True: + h_0, c_0 = hx + batch_size = h_0.size(0) + + bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) + + wi = torch.matmul(inp, self.weight_ih) + wh = torch.matmul(h_0, self.weight_hh) + + + + i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1) + + i = torch.sigmoid(i).unsqueeze(1) + o = torch.sigmoid(o).unsqueeze(1) + g = torch.tanh(g).unsqueeze(1) + + + + alpha_wi = torch.matmul(inp, self.alpha_weight_ih) + alpha_wi.unsqueeze_(1) + + # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size) + alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh) + + alpha_bias_batch = self.alpha_bias.unsqueeze(0) + + alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch) + + skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]) + + skip_mask = 1 - skip_mask + + + skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size) + + skip_mask = (skip_mask).float()*1e20 + + alpha = alpha - skip_mask + + alpha = torch.exp(torch.cat([i, alpha], dim=1)) + + + + alpha_sum = torch.sum(alpha, dim=1, keepdim=True) + + alpha = torch.div(alpha, alpha_sum) + + merge_i_c = torch.cat([g, skip_c], dim=1) + + c_1 = merge_i_c * alpha + + c_1 = c_1.sum(1, keepdim=True) + # h_1 = o * c_1 + h_1 = o * torch.tanh(c_1) + + return h_1.squeeze(1), c_1.squeeze(1) + + else: + + h_0, c_0 = hx + batch_size = h_0.size(0) + + bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) + + wi = torch.matmul(inp, self.weight_ih) + wh = torch.matmul(h_0, self.weight_hh) + + i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1) + + i = torch.sigmoid(i).unsqueeze(1) + o = torch.sigmoid(o).unsqueeze(1) + g = torch.tanh(g).unsqueeze(1) + + c_1 = g + h_1 = o * c_1 + + return h_1,c_1 + +class MultiInputLSTMCell_V1(nn.Module): + def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): + super().__init__() + self.char_input_size = char_input_size + self.hidden_size = hidden_size + self.use_bias = use_bias + + self.weight_ih = nn.Parameter( + torch.FloatTensor(char_input_size, 3 * hidden_size) + ) + + self.weight_hh = nn.Parameter( + torch.FloatTensor(hidden_size, 3 * hidden_size) + ) + + self.alpha_weight_ih = nn.Parameter( + torch.FloatTensor(char_input_size, hidden_size) + ) + + self.alpha_weight_hh = nn.Parameter( + torch.FloatTensor(hidden_size, hidden_size) + ) + + if self.use_bias: + self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) + self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) + else: + self.register_parameter('bias', None) + self.register_parameter('alpha_bias', None) + + self.debug = debug + self.reset_parameters() + + def reset_parameters(self): + """ + Initialize parameters following the way proposed in the paper. 
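+
+        (Same scheme as MultiInputLSTMCell_V0: orthogonal input-to-hidden
+        weights, identity recurrent weights and zero biases.)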
+ """ + nn.init.orthogonal(self.weight_ih.data) + nn.init.orthogonal(self.alpha_weight_ih.data) + + weight_hh_data = torch.eye(self.hidden_size) + weight_hh_data = weight_hh_data.repeat(1, 3) + with torch.no_grad(): + self.weight_hh.set_(weight_hh_data) + + alpha_weight_hh_data = torch.eye(self.hidden_size) + alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) + with torch.no_grad(): + self.alpha_weight_hh.set_(alpha_weight_hh_data) + + # The bias is just set to zero vectors. + if self.use_bias: + nn.init.constant_(self.bias.data, val=0) + nn.init.constant_(self.alpha_bias.data, val=0) + + def forward(self, inp, skip_c, skip_count, hx): + ''' + + :param inp: chars B * hidden + :param skip_c: 由跳边得到的c, B * X * hidden + :param skip_count: 这个batch中每个example中当前位置的跳边的数量,用于mask + :param hx: + :return: + ''' + max_skip_count = torch.max(skip_count).item() + + + + if True: + h_0, c_0 = hx + batch_size = h_0.size(0) + + bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) + + wi = torch.matmul(inp, self.weight_ih) + wh = torch.matmul(h_0, self.weight_hh) + + + i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1) + + i = torch.sigmoid(i).unsqueeze(1) + o = torch.sigmoid(o).unsqueeze(1) + g = torch.tanh(g).unsqueeze(1) + + + + ##basic lstm start + + f = 1 - i + c_1_basic = f*c_0.unsqueeze(1) + i*g + c_1_basic = c_1_basic.squeeze(1) + + + + + + alpha_wi = torch.matmul(inp, self.alpha_weight_ih) + alpha_wi.unsqueeze_(1) + + + alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh) + + alpha_bias_batch = self.alpha_bias.unsqueeze(0) + + alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch) + + skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]) + + skip_mask = 1 - skip_mask + + + skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size) + + skip_mask = (skip_mask).float()*1e20 + + alpha = alpha - skip_mask + + alpha = torch.exp(torch.cat([i, alpha], dim=1)) + + + + alpha_sum = torch.sum(alpha, dim=1, keepdim=True) + + alpha = torch.div(alpha, alpha_sum) + + merge_i_c = torch.cat([g, skip_c], dim=1) + + c_1 = merge_i_c * alpha + + c_1 = c_1.sum(1, keepdim=True) + # h_1 = o * c_1 + c_1 = c_1.squeeze(1) + count_select = (skip_count != 0).float().unsqueeze(-1) + + + + + c_1 = c_1*count_select + c_1_basic*(1-count_select) + + + o = o.squeeze(1) + h_1 = o * torch.tanh(c_1) + + return h_1, c_1 + +class LatticeLSTMLayer_sup_back_V0(nn.Module): + def __init__(self, char_input_size, word_input_size, hidden_size, left2right, + bias=True,device=None,debug=False,skip_before_head=False): + super().__init__() + + self.skip_before_head = skip_before_head + + self.hidden_size = hidden_size + + self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias,debug) + + self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug) + + self.word_input_size = word_input_size + self.left2right = left2right + self.bias = bias + self.device = device + self.debug = debug + + def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): + ''' + + :param inp: batch * seq_len * embedding, chars + :param seq_len: batch, length of chars + :param skip_sources: batch * seq_len * X, 跳边的起点 + :param skip_words: batch * seq_len * X * embedding, 跳边的词 + :param lexicon_count: batch * seq_len, count of lexicon per example per position + :param init_state: the hx of rnn + :return: + ''' + + + if self.left2right: + + max_seq_len = max(seq_len) + batch_size = inp.size(0) + c_ = 
torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + + for i in range(max_seq_len): + max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) + h_0, c_0 = h_[:, i, :], c_[:, i, :] + + skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() + + skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) + skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) + + + index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) + index_1 = skip_source_flat + + if not self.skip_before_head: + c_x = c_[[index_0, index_1+1]] + h_x = h_[[index_0, index_1+1]] + else: + c_x = c_[[index_0,index_1]] + h_x = h_[[index_0,index_1]] + + c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) + h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) + + + + + c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) + + c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) + + h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) + + + h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) + c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) + + return h_[:,1:],c_[:,1:] + else: + mask_for_seq_len = seq_len_to_mask(seq_len) + + max_seq_len = max(seq_len) + batch_size = inp.size(0) + c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + + for i in reversed(range(max_seq_len)): + max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) + + + + h_0, c_0 = h_[:, 0, :], c_[:, 0, :] + + skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() + + skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) + skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) + + + index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) + index_1 = skip_source_flat-i + + if not self.skip_before_head: + c_x = c_[[index_0, index_1-1]] + h_x = h_[[index_0, index_1-1]] + else: + c_x = c_[[index_0,index_1]] + h_x = h_[[index_0,index_1]] + + c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) + h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) + + + + + c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) + + c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) + + h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) + + + h_1_mask = h_1.masked_fill(1-mask_for_seq_len[:,i].unsqueeze(-1),0) + c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0) + + + h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) + c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) + + return h_[:,:-1],c_[:,:-1] + +class LatticeLSTMLayer_sup_back_V1(nn.Module): + # V1与V0的不同在于,V1在当前位置完全无lexicon匹配时,会采用普通的lstm计算公式, + # 普通的lstm计算公式与杨杰实现的lattice lstm在lexicon数量为0时不同 + def __init__(self, char_input_size, word_input_size, hidden_size, left2right, + bias=True,device=None,debug=False,skip_before_head=False): + super().__init__() + + self.debug = debug + + self.skip_before_head = skip_before_head + + self.hidden_size = hidden_size + + self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug) + + 
self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug) + + self.word_input_size = word_input_size + self.left2right = left2right + self.bias = bias + self.device = device + + def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): + ''' + + :param inp: batch * seq_len * embedding, chars + :param seq_len: batch, length of chars + :param skip_sources: batch * seq_len * X, 跳边的起点 + :param skip_words: batch * seq_len * X * embedding_size, 跳边的词 + :param lexicon_count: batch * seq_len, + lexicon_count[i,j]为第i个例子以第j个位子为结尾匹配到的词的数量 + :param init_state: the hx of rnn + :return: + ''' + + + if self.left2right: + + max_seq_len = max(seq_len) + batch_size = inp.size(0) + c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + + for i in range(max_seq_len): + max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) + h_0, c_0 = h_[:, i, :], c_[:, i, :] + + #为了使rnn能够计算B*lexicon_count*embedding_size的张量,需要将其reshape成二维张量 + #为了匹配pytorch的[]取址方式,需要将reshape成二维张量 + + skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() + + skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) + skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) + + + index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) + index_1 = skip_source_flat + + + if not self.skip_before_head: + c_x = c_[[index_0, index_1+1]] + h_x = h_[[index_0, index_1+1]] + else: + c_x = c_[[index_0,index_1]] + h_x = h_[[index_0,index_1]] + + c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) + h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) + + + + c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) + + c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) + + h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) + + + h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) + c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) + + return h_[:,1:],c_[:,1:] + else: + mask_for_seq_len = seq_len_to_mask(seq_len) + + max_seq_len = max(seq_len) + batch_size = inp.size(0) + c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) + + for i in reversed(range(max_seq_len)): + max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) + + + h_0, c_0 = h_[:, 0, :], c_[:, 0, :] + + skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() + + skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) + skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) + + + index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) + index_1 = skip_source_flat-i + + if not self.skip_before_head: + c_x = c_[[index_0, index_1-1]] + h_x = h_[[index_0, index_1-1]] + else: + c_x = c_[[index_0,index_1]] + h_x = h_[[index_0,index_1]] + + c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) + h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) + + + + + c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) + + + + c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) + + h_1,c_1 = 
self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) + + + h_1_mask = h_1.masked_fill(1-mask_for_seq_len[:,i].unsqueeze(-1),0) + c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0) + + + h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) + c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) + + + + return h_[:,:-1],c_[:,:-1] + + + + diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py new file mode 100644 index 00000000..af1efaf7 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py @@ -0,0 +1,23 @@ + + +glove_100_path = 'en-glove-6b-100d' +glove_50_path = 'en-glove-6b-50d' +glove_200_path = '' +glove_300_path = 'en-glove-840b-300' +fasttext_path = 'en-fasttext' #300 +tencent_chinese_word_path = 'cn' # tencent 200 +fasttext_cn_path = 'cn-fasttext' # 300 +yangjie_rich_pretrain_unigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.uni.ite50.vec' +yangjie_rich_pretrain_bigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.bi.ite50.vec' +yangjie_rich_pretrain_word_path = '/remote-home/xnli/data/pretrain/chinese/ctb.50d.vec' + + +conll_2003_path = '/remote-home/xnli/data/corpus/multi_task/conll_2013/data_mine.pkl' +conllized_ontonote_path = '/remote-home/txsun/data/OntoNotes-5.0-NER-master/v12/english' +conllized_ontonote_pkl_path = '/remote-home/txsun/data/ontonotes5.pkl' +sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/' +# weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo' +ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER' +msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER' +resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER' +weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER' \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py new file mode 100644 index 00000000..c877d96f --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py @@ -0,0 +1,126 @@ +from utils_ import get_skip_path_trivial, Trie, get_skip_path +from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip +from pathes import * +from functools import partial +from fastNLP import cache_results +from fastNLP.embeddings.static_embedding import StaticEmbedding +import torch +import torch.nn as nn +import torch.nn.functional as F +from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans +from load_data import load_resume_ner + + +# embed = StaticEmbedding(None,embedding_dim=2) +# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, +# _refresh=True,index_token=False) +# +# w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, +# _refresh=False) +# +# datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, +# _refresh=True) +# + +def reverse_style(input_string): + target_position = input_string.index('[') + input_len = len(input_string) + output_string = input_string[target_position:input_len] + input_string[0:target_position] + # 
print('in:{}.out:{}'.format(input_string, output_string)) + return output_string + + + + + +def get_yangjie_bmeso(label_list): + def get_ner_BMESO_yj(label_list): + # list_len = len(word_list) + # assert(list_len == len(label_list)), "word list size unmatch with label list" + list_len = len(label_list) + begin_label = 'b-' + end_label = 'e-' + single_label = 's-' + whole_tag = '' + index_tag = '' + tag_list = [] + stand_matrix = [] + for i in range(0, list_len): + # wordlabel = word_list[i] + current_label = label_list[i].lower() + if begin_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i - 1)) + whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) + index_tag = current_label.replace(begin_label, "", 1) + + elif single_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i - 1)) + whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) + tag_list.append(whole_tag) + whole_tag = "" + index_tag = "" + elif end_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i)) + whole_tag = '' + index_tag = '' + else: + continue + if (whole_tag != '') & (index_tag != ''): + tag_list.append(whole_tag) + tag_list_len = len(tag_list) + + for i in range(0, tag_list_len): + if len(tag_list[i]) > 0: + tag_list[i] = tag_list[i] + ']' + insert_list = reverse_style(tag_list[i]) + stand_matrix.append(insert_list) + # print stand_matrix + return stand_matrix + + def transform_YJ_to_fastNLP(span): + span = span[1:] + span_split = span.split(']') + # print('span_list:{}'.format(span_split)) + span_type = span_split[1] + # print('span_split[0].split(','):{}'.format(span_split[0].split(','))) + if ',' in span_split[0]: + b, e = span_split[0].split(',') + else: + b = span_split[0] + e = b + + b = int(b) + e = int(e) + + e += 1 + + return (span_type, (b, e)) + yj_form = get_ner_BMESO_yj(label_list) + # print('label_list:{}'.format(label_list)) + # print('yj_from:{}'.format(yj_form)) + fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) + return fastNLP_form + + +# tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] +# span_list = get_ner_BMES(tag_list) +# print(span_list) +# yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG'] +# my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG'] +# yangjie_label = set(yangjie_label_list) +# my_label = set(my_label_list) + +a = torch.tensor([0,2,0,3]) +b = (a==0) +print(b) +print(b.float()) +from fastNLP import RandomSampler + +# f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb') +# weight_dict = torch.load(f) +# print(weight_dict.keys()) +# for k,v in weight_dict.items(): +# print("{}:{}".format(k,v.size())) \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py new file mode 100644 index 00000000..8c64c43c --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py @@ -0,0 +1,361 @@ +import torch.nn.functional as F +import 
torch +import random +import numpy as np +from fastNLP import Const +from fastNLP import CrossEntropyLoss +from fastNLP import AccuracyMetric +from fastNLP import Tester +import os +from fastNLP import logger +def should_mask(name, t=''): + if 'bias' in name: + return False + if 'embedding' in name: + splited = name.split('.') + if splited[-1]!='weight': + return False + if 'embedding' in splited[-2]: + return False + if 'c0' in name: + return False + if 'h0' in name: + return False + + if 'output' in name and t not in name: + return False + + return True +def get_init_mask(model): + init_masks = {} + for name, param in model.named_parameters(): + if should_mask(name): + init_masks[name+'.mask'] = torch.ones_like(param) + # logger.info(init_masks[name+'.mask'].requires_grad) + + return init_masks + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed+100) + torch.manual_seed(seed+200) + torch.cuda.manual_seed_all(seed+300) + +def get_parameters_size(model): + result = {} + for name,p in model.state_dict().items(): + result[name] = p.size() + + return result + +def prune_by_proportion_model(model,proportion,task): + # print('this time prune to ',proportion*100,'%') + for name, p in model.named_parameters(): + # print(name) + if not should_mask(name,task): + continue + + tensor = p.data.cpu().numpy() + index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) + # print(name,'alive count',len(index[0])) + alive = tensor[index] + # print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size())) + percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) + # tensor = p + # index = torch.nonzero(model.mask[task][name+'.mask']) + # # print('nonzero len',index) + # alive = tensor[index] + # print('alive size:',alive.shape) + # prune_by_proportion_model() + + # percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values + # print('the',(1-proportion)*len(index[0]),'th big') + # print('threshold:',percentile_value) + + prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) + # for + +def prune_by_proportion_model_global(model,proportion,task): + # print('this time prune to ',proportion*100,'%') + alive = None + for name, p in model.named_parameters(): + # print(name) + if not should_mask(name,task): + continue + + tensor = p.data.cpu().numpy() + index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) + # print(name,'alive count',len(index[0])) + if alive is None: + alive = tensor[index] + else: + alive = np.concatenate([alive,tensor[index]],axis=0) + + percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) + + for name, p in model.named_parameters(): + if should_mask(name,task): + prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) + + +def prune_by_threshold_parameter(p, mask, threshold): + p_abs = torch.abs(p) + + new_mask = (p_abs > threshold).float() + # print(mask) + mask[:]*=new_mask + + +def one_time_train_and_prune_single_task(trainer,PRUNE_PER, + optimizer_init_state_dict=None, + model_init_state_dict=None, + is_global=None, + ): + + + from fastNLP import Trainer + + + trainer.optimizer.load_state_dict(optimizer_init_state_dict) + trainer.model.load_state_dict(model_init_state_dict) + # print('metrics:',metrics.__dict__) + # print('loss:',loss.__dict__) + # print('trainer input:',task.train_set.get_input_name()) + # trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics, + # 
optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks) + + + trainer.train(load_best_model=True) + # tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False) + # print('FOR DEBUG: test train_set:',tester.test()) + # print('**'*20) + # if task.test_set: + # tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1) + # tester.test() + if is_global: + + prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task) + + else: + prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task) + + + +# def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None): +def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None): + + ''' + + :param trainer: + :param ITER: + :param PRUNE: + :param is_global: + :param save_path: should be a dictionary which will be filled with mask and state dict + :return: + ''' + + + + from fastNLP import Trainer + import torch + import math + import copy + PRUNE = args.prune + ITER = args.iter + trainer = get_trainer(args,model,train_set,dev_set,test_set,device) + optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict()) + model_init_state_dict = copy.deepcopy(trainer.model.state_dict()) + if save_path is not None: + if not os.path.exists(save_path): + os.makedirs(save_path) + # if not os.path.exists(os.path.join(save_path, 'model_init.pkl')): + # f = open(os.path.join(save_path, 'model_init.pkl'), 'wb') + # torch.save(trainer.model.state_dict(),f) + + + mask_count = 0 + model = trainer.model + task = trainer.model.now_task + for name, p in model.mask[task].items(): + mask_count += torch.sum(p).item() + init_mask_count = mask_count + logger.info('init mask count:{}'.format(mask_count)) + # logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count, + # mask_count / init_mask_count * 100)) + + prune_per_iter = math.pow(PRUNE, 1 / ITER) + + + for i in range(ITER): + trainer = get_trainer(args,model,train_set,dev_set,test_set,device) + one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict) + if save_path is not None: + f = open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb') + torch.save(model.mask[task],f) + + mask_count = 0 + for name, p in model.mask[task].items(): + mask_count += torch.sum(p).item() + logger.info('{}th traning mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100)) + + +def get_appropriate_cuda(task_scale='s'): + if task_scale not in {'s','m','l'}: + logger.info('task scale wrong!') + exit(2) + import pynvml + pynvml.nvmlInit() + total_cuda_num = pynvml.nvmlDeviceGetCount() + for i in range(total_cuda_num): + logger.info(i) + handle = pynvml.nvmlDeviceGetHandleByIndex(i) # 这里的0是GPU id + memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle) + logger.info(i, 'mem:', memInfo.used / memInfo.total, 'util:',utilizationInfo.gpu) + if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu <0.2: + logger.info(i,memInfo.used / memInfo.total) + return 'cuda:'+str(i) + + if task_scale=='s': + max_memory=2000 + elif task_scale=='m': + max_memory=6000 + else: + max_memory = 9000 + + max_id = -1 + for i in range(total_cuda_num): + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # 这里的0是GPU id + memInfo = 
pynvml.nvmlDeviceGetMemoryInfo(handle) + utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle) + if max_memory < memInfo.free: + max_memory = memInfo.free + max_id = i + + if id == -1: + logger.info('no appropriate gpu, wait!') + exit(2) + + return 'cuda:'+str(max_id) + + # if memInfo.used / memInfo.total < 0.5: + # return + +def print_mask(mask_dict): + def seq_mul(*X): + res = 1 + for x in X: + res*=x + return res + + for name,p in mask_dict.items(): + total_size = seq_mul(*p.size()) + unmasked_size = len(np.nonzero(p)) + + print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%') + + + print() + + +def check_words_same(dataset_1,dataset_2,field_1,field_2): + if len(dataset_1[field_1]) != len(dataset_2[field_2]): + logger.info('CHECK: example num not same!') + return False + + for i, words in enumerate(dataset_1[field_1]): + if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]): + logger.info('CHECK {} th example length not same'.format(i)) + logger.info('1:{}'.format(dataset_1[field_1][i])) + logger.info('2:'.format(dataset_2[field_2][i])) + return False + + # for j,w in enumerate(words): + # if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]: + # print('CHECK', i, 'th example has words different!') + # print('1:',dataset_1[field_1][i]) + # print('2:',dataset_2[field_2][i]) + # return False + + logger.info('CHECK: totally same!') + + return True + +def get_now_time(): + import time + from datetime import datetime, timezone, timedelta + dt = datetime.utcnow() + # print(dt) + tzutc_8 = timezone(timedelta(hours=8)) + local_dt = dt.astimezone(tzutc_8) + result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute, + local_dt.second)) + + return result + + +def get_bigrams(words): + result = [] + for i,w in enumerate(words): + if i!=len(words)-1: + result.append(words[i]+words[i+1]) + else: + result.append(words[i]+'') + + return result + +def print_info(*inp,islog=False,sep=' '): + from fastNLP import logger + if islog: + print(*inp,sep=sep) + else: + inp = sep.join(map(str,inp)) + logger.info(inp) + +def better_init_rnn(rnn,coupled=False): + import torch.nn as nn + if coupled: + repeat_size = 3 + else: + repeat_size = 4 + # print(list(rnn.named_parameters())) + if hasattr(rnn,'num_layers'): + for i in range(rnn.num_layers): + nn.init.orthogonal(getattr(rnn,'weight_ih_l'+str(i)).data) + weight_hh_data = torch.eye(rnn.hidden_size) + weight_hh_data = weight_hh_data.repeat(1, repeat_size) + with torch.no_grad(): + getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data) + nn.init.constant(getattr(rnn,'bias_ih_l'+str(i)).data, val=0) + nn.init.constant(getattr(rnn,'bias_hh_l'+str(i)).data, val=0) + + if rnn.bidirectional: + for i in range(rnn.num_layers): + nn.init.orthogonal(getattr(rnn, 'weight_ih_l' + str(i)+'_reverse').data) + weight_hh_data = torch.eye(rnn.hidden_size) + weight_hh_data = weight_hh_data.repeat(1, repeat_size) + with torch.no_grad(): + getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data) + nn.init.constant(getattr(rnn, 'bias_ih_l' + str(i)+'_reverse').data, val=0) + nn.init.constant(getattr(rnn, 'bias_hh_l' + str(i)+'_reverse').data, val=0) + + + else: + nn.init.orthogonal(rnn.weight_ih.data) + weight_hh_data = torch.eye(rnn.hidden_size) + weight_hh_data = weight_hh_data.repeat(repeat_size,1) + with torch.no_grad(): + rnn.weight_hh.set_(weight_hh_data) + # The bias is just set to zero vectors. 
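+        # Single-cell branch (no num_layers attribute): the same orthogonal /
+        # identity / zero-bias scheme as the multi-layer branch above, with the
+        # identity tiled along the first (gate) dimension of weight_hh.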
+ print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn))) + if rnn.bias: + nn.init.constant(rnn.bias_ih.data, val=0) + nn.init.constant(rnn.bias_hh.data, val=0) + + # print(list(rnn.named_parameters())) + + + + + + diff --git a/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py new file mode 100644 index 00000000..dfc05486 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py @@ -0,0 +1,405 @@ +import collections +from fastNLP import cache_results +def get_skip_path(chars,w_trie): + sentence = ''.join(chars) + result = w_trie.get_lexicon(sentence) + + return result + +# @cache_results(_cache_fp='cache/get_skip_path_trivial',_refresh=True) +def get_skip_path_trivial(chars,w_list): + chars = ''.join(chars) + w_set = set(w_list) + result = [] + # for i in range(len(chars)): + # result.append([]) + for i in range(len(chars)-1): + for j in range(i+2,len(chars)+1): + if chars[i:j] in w_set: + result.append([i,j-1,chars[i:j]]) + + return result + + +class TrieNode: + def __init__(self): + self.children = collections.defaultdict(TrieNode) + self.is_w = False + +class Trie: + def __init__(self): + self.root = TrieNode() + + def insert(self,w): + + current = self.root + for c in w: + current = current.children[c] + + current.is_w = True + + def search(self,w): + ''' + + :param w: + :return: + -1:not w route + 0:subroute but not word + 1:subroute and word + ''' + current = self.root + + for c in w: + current = current.children.get(c) + + if current is None: + return -1 + + if current.is_w: + return 1 + else: + return 0 + + def get_lexicon(self,sentence): + result = [] + for i in range(len(sentence)): + current = self.root + for j in range(i, len(sentence)): + current = current.children.get(sentence[j]) + if current is None: + break + + if current.is_w: + result.append([i,j,sentence[i:j+1]]) + + return result + +from fastNLP.core.field import Padder +import numpy as np +import torch +from collections import defaultdict +class LatticeLexiconPadder(Padder): + + def __init__(self, pad_val=0, pad_val_dynamic=False,dynamic_offset=0, **kwargs): + ''' + + :param pad_val: + :param pad_val_dynamic: if True, pad_val is the seq_len + :param kwargs: + ''' + self.pad_val = pad_val + self.pad_val_dynamic = pad_val_dynamic + self.dynamic_offset = dynamic_offset + + def __call__(self, contents, field_name, field_ele_dtype, dim: int): + # 与autoPadder中 dim=2 的情况一样 + max_len = max(map(len, contents)) + + max_len = max(max_len,1)#avoid 0 size dim which causes cuda wrong + + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + + max_word_len = max(max_word_len,1) + if self.pad_val_dynamic: + # print('pad_val_dynamic:{}'.format(max_len-1)) + + array = np.full((len(contents), max_len, max_word_len), max_len-1+self.dynamic_offset, + dtype=field_ele_dtype) + + else: + array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + array[i, j, :len(content_ii)] = content_ii + array = torch.tensor(array) + + return array + +from fastNLP.core.metrics import MetricBase + +def get_yangjie_bmeso(label_list,ignore_labels=None): + def get_ner_BMESO_yj(label_list): + def reverse_style(input_string): + target_position = input_string.index('[') + input_len = len(input_string) + output_string = input_string[target_position:input_len] + 
input_string[0:target_position] + # print('in:{}.out:{}'.format(input_string, output_string)) + return output_string + + # list_len = len(word_list) + # assert(list_len == len(label_list)), "word list size unmatch with label list" + list_len = len(label_list) + begin_label = 'b-' + end_label = 'e-' + single_label = 's-' + whole_tag = '' + index_tag = '' + tag_list = [] + stand_matrix = [] + for i in range(0, list_len): + # wordlabel = word_list[i] + current_label = label_list[i].lower() + if begin_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i - 1)) + whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) + index_tag = current_label.replace(begin_label, "", 1) + + elif single_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i - 1)) + whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) + tag_list.append(whole_tag) + whole_tag = "" + index_tag = "" + elif end_label in current_label: + if index_tag != '': + tag_list.append(whole_tag + ',' + str(i)) + whole_tag = '' + index_tag = '' + else: + continue + if (whole_tag != '') & (index_tag != ''): + tag_list.append(whole_tag) + tag_list_len = len(tag_list) + + for i in range(0, tag_list_len): + if len(tag_list[i]) > 0: + tag_list[i] = tag_list[i] + ']' + insert_list = reverse_style(tag_list[i]) + stand_matrix.append(insert_list) + # print stand_matrix + return stand_matrix + + def transform_YJ_to_fastNLP(span): + span = span[1:] + span_split = span.split(']') + # print('span_list:{}'.format(span_split)) + span_type = span_split[1] + # print('span_split[0].split(','):{}'.format(span_split[0].split(','))) + if ',' in span_split[0]: + b, e = span_split[0].split(',') + else: + b = span_split[0] + e = b + + b = int(b) + e = int(e) + + e += 1 + + return (span_type, (b, e)) + yj_form = get_ner_BMESO_yj(label_list) + # print('label_list:{}'.format(label_list)) + # print('yj_from:{}'.format(yj_form)) + fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) + return fastNLP_form +class SpanFPreRecMetric_YJ(MetricBase): + r""" + 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` + + 在序列标注问题中,以span的方式计算F, pre, rec. + 比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例) + ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 + 最后得到的metric结果为:: + + { + 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 + 'pre': xxx, + 'rec':xxx + } + + 若only_gross=False, 即还会返回各个label的metric统计值:: + + { + 'f': xxx, + 'pre': xxx, + 'rec':xxx, + 'f-label': xxx, + 'pre-label': xxx, + 'rec-label':xxx, + ... + } + + :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), + 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. + :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 + :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 + :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes + :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 + 个label + :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 + label的f1, pre, rec + :param str f_type: `micro` 或 `macro` . 
`micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : + 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) + :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . + 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + """ + def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, + only_gross=True, f_type='micro', beta=1): + from fastNLP.core import Vocabulary + from fastNLP.core.metrics import _bmes_tag_to_spans,_bio_tag_to_spans,\ + _bioes_tag_to_spans,_bmeso_tag_to_spans + from collections import defaultdict + + encoding_type = encoding_type.lower() + + if not isinstance(tag_vocab, Vocabulary): + raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) + if f_type not in ('micro', 'macro'): + raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) + + self.encoding_type = encoding_type + # print('encoding_type:{}'self.encoding_type) + if self.encoding_type == 'bmes': + self.tag_to_span_func = _bmes_tag_to_spans + elif self.encoding_type == 'bio': + self.tag_to_span_func = _bio_tag_to_spans + elif self.encoding_type == 'bmeso': + self.tag_to_span_func = _bmeso_tag_to_spans + elif self.encoding_type == 'bioes': + self.tag_to_span_func = _bioes_tag_to_spans + elif self.encoding_type == 'bmesoyj': + self.tag_to_span_func = get_yangjie_bmeso + # self.tag_to_span_func = + else: + raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") + + self.ignore_labels = ignore_labels + self.f_type = f_type + self.beta = beta + self.beta_square = self.beta ** 2 + self.only_gross = only_gross + + super().__init__() + self._init_param_map(pred=pred, target=target, seq_len=seq_len) + + self.tag_vocab = tag_vocab + + self._true_positives = defaultdict(int) + self._false_positives = defaultdict(int) + self._false_negatives = defaultdict(int) + + def evaluate(self, pred, target, seq_len): + from fastNLP.core.utils import _get_func_signature + """evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 + :param target: [batch, seq_len], 真实值 + :param seq_len: [batch] 文本长度标记 + :return: + """ + if not isinstance(pred, torch.Tensor): + raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") + if not isinstance(target, torch.Tensor): + raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(target)}.") + + if not isinstance(seq_len, torch.Tensor): + raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(seq_len)}.") + + if pred.size() == target.size() and len(target.size()) == 2: + pass + elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: + num_classes = pred.size(-1) + pred = pred.argmax(dim=-1) + if (target >= num_classes).any(): + raise ValueError("A gold label passed to SpanBasedF1Metric contains an " + "id >= {}, the number of classes.".format(num_classes)) + else: + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") + + batch_size = pred.size(0) + pred = pred.tolist() + target = target.tolist() + for i in range(batch_size): + pred_tags = pred[i][:int(seq_len[i])] + gold_tags = target[i][:int(seq_len[i])] + + pred_str_tags = 
[self.tag_vocab.to_word(tag) for tag in pred_tags] + gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] + + pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) + gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) + + for span in pred_spans: + if span in gold_spans: + self._true_positives[span[0]] += 1 + gold_spans.remove(span) + else: + self._false_positives[span[0]] += 1 + for span in gold_spans: + self._false_negatives[span[0]] += 1 + + def get_metric(self, reset=True): + """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" + evaluate_result = {} + if not self.only_gross or self.f_type == 'macro': + tags = set(self._false_negatives.keys()) + tags.update(set(self._false_positives.keys())) + tags.update(set(self._true_positives.keys())) + f_sum = 0 + pre_sum = 0 + rec_sum = 0 + for tag in tags: + tp = self._true_positives[tag] + fn = self._false_negatives[tag] + fp = self._false_positives[tag] + f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) + f_sum += f + pre_sum += pre + rec_sum += rec + if not self.only_gross and tag != '': # tag!=''防止无tag的情况 + f_key = 'f-{}'.format(tag) + pre_key = 'pre-{}'.format(tag) + rec_key = 'rec-{}'.format(tag) + evaluate_result[f_key] = f + evaluate_result[pre_key] = pre + evaluate_result[rec_key] = rec + + if self.f_type == 'macro': + evaluate_result['f'] = f_sum / len(tags) + evaluate_result['pre'] = pre_sum / len(tags) + evaluate_result['rec'] = rec_sum / len(tags) + + if self.f_type == 'micro': + f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), + sum(self._false_negatives.values()), + sum(self._false_positives.values())) + evaluate_result['f'] = f + evaluate_result['pre'] = pre + evaluate_result['rec'] = rec + + if reset: + self._true_positives = defaultdict(int) + self._false_positives = defaultdict(int) + self._false_negatives = defaultdict(int) + + for key, value in evaluate_result.items(): + evaluate_result[key] = round(value, 6) + + return evaluate_result + + def _compute_f_pre_rec(self, tp, fn, fp): + """ + + :param tp: int, true positive + :param fn: int, false negative + :param fp: int, false positive + :return: (f, pre, rec) + """ + pre = tp / (fp + tp + 1e-13) + rec = tp / (fn + tp + 1e-13) + f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) + + return f, pre, rec + + + +
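+
+# A minimal usage sketch (not part of the original scripts): SpanFPreRecMetric_YJ
+# plugs into fastNLP's Tester just like the stock SpanFPreRecMetric, with
+# encoding_type='bmesoyj' selecting the yangjie-style BMESO decoding defined
+# above. The vocabulary / dataset / batch-size names below are assumptions taken
+# from the accompanying training scripts, not guaranteed key names.
+#
+#     metric = SpanFPreRecMetric_YJ(tag_vocab=vocabs['label'], pred='pred',
+#                                   target='target', seq_len='seq_len',
+#                                   encoding_type='bmesoyj')
+#     tester = Tester(datasets['test'], model, metrics=metric,
+#                     batch_size=args.test_batch, device=device)
+#     print(tester.test())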