From 5d9e064ec27595a1e09d7ddcbb27280c685b0701 Mon Sep 17 00:00:00 2001 From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com> Date: Mon, 8 Jul 2019 01:11:46 +0800 Subject: [PATCH 1/4] text_classfication --- .../text_classification/data/SSTLoader.py | 90 ++++++++++++++++++- .../text_classification/train_awdlstm.py | 41 +-------- .../text_classification/train_lstm.py | 43 ++------- .../text_classification/train_lstm_att.py | 41 +-------- 4 files changed, 102 insertions(+), 113 deletions(-) diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py index b570994e..d8403b7a 100644 --- a/reproduction/text_classification/data/SSTLoader.py +++ b/reproduction/text_classification/data/SSTLoader.py @@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary from fastNLP import DataSet from fastNLP import Instance from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader - +import csv +from typing import Union, Dict class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -97,3 +98,90 @@ class SSTLoader(DataSetLoader): return info +class sst2Loader(DataSetLoader): + ''' + 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', + ''' + def __init__(self): + super(sst2Loader, self).__init__() + + def _load(self, path: str) -> DataSet: + ds = DataSet() + all_count=0 + csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') + skip_row = 0 + for idx,row in enumerate(csv_reader): + if idx<=skip_row: + continue + target = row[1] + words = row[0].split() + ds.append(Instance(words=words,target=target)) + all_count+=1 + print("all count:", all_count) + return ds + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + + # 就分隔为char形式 + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + return info + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", + "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} + datainfo=sst2Loader().process(datapath,char_level_op=True) + #print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py index ce3e52bc..e67bd25b 100644 --- a/reproduction/text_classification/train_awdlstm.py +++ b/reproduction/text_classification/train_awdlstm.py @@ -8,9 +8,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' import torch.nn as nn -from data.SSTLoader import SSTLoader from data.IMDBLoader import IMDBLoader -from data.yelpLoader import yelpLoader from fastNLP.modules.encoder.embedding import StaticEmbedding from model.awd_lstm import AWDLSTMSentiment @@ -41,18 +39,9 @@ opt=Config # load data -dataloaders = { - "IMDB":IMDBLoader(), - "YELP":yelpLoader(), - "SST-5":SSTLoader(subtree=True,fine_grained=True), - "SST-3":SSTLoader(subtree=True,fine_grained=False) -} - -if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: - raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") - -dataloader = dataloaders[opt.task_name] +dataloader=IMDBLoader() datainfo=dataloader.process(opt.datapath) + # print(datainfo.datasets["train"]) # print(datainfo) @@ -71,32 +60,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T def train(datainfo, model, optimizer, loss, metrics, opt): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, n_epochs=opt.train_epoch, save_path=opt.save_model_path) trainer.train() -def test(datainfo, metrics, opt): - # load model - model = ModelLoader.load_pytorch_model(opt.load_model_path) - print("model loaded!") - - # Tester - tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) - acc = tester.test() - print("acc=",acc) - - - -parser = argparse.ArgumentParser() -parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') - - -args = parser.parse_args() -if args.mode == 'train': +if __name__ == "__main__": train(datainfo, model, optimizer, loss, metrics, opt) -elif args.mode == 'test': - test(datainfo, metrics, opt) -else: - print('no mode specified for model!') - parser.print_help() diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py index b320e79c..b89abc14 100644 --- a/reproduction/text_classification/train_lstm.py +++ b/reproduction/text_classification/train_lstm.py @@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' import torch.nn as nn -from data.SSTLoader import SSTLoader from data.IMDBLoader import IMDBLoader -from data.yelpLoader import yelpLoader from fastNLP.modules.encoder.embedding import StaticEmbedding from model.lstm import BiLSTMSentiment @@ -38,18 +36,9 @@ opt=Config # load data -dataloaders = { - "IMDB":IMDBLoader(), - "YELP":yelpLoader(), - "SST-5":SSTLoader(subtree=True,fine_grained=True), - "SST-3":SSTLoader(subtree=True,fine_grained=False) -} - -if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: - raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") - -dataloader = dataloaders[opt.task_name] +dataloader=IMDBLoader() datainfo=dataloader.process(opt.datapath) + # print(datainfo.datasets["train"]) # print(datainfo) @@ -68,32 +57,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T def train(datainfo, model, optimizer, loss, metrics, opt): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, n_epochs=opt.train_epoch, save_path=opt.save_model_path) trainer.train() -def test(datainfo, metrics, opt): - # load model - model = ModelLoader.load_pytorch_model(opt.load_model_path) - print("model loaded!") - - # Tester - tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) - acc = tester.test() - print("acc=",acc) - - - -parser = argparse.ArgumentParser() -parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') - - -args = parser.parse_args() -if args.mode == 'train': - train(datainfo, model, optimizer, loss, metrics, opt) -elif args.mode == 'test': - test(datainfo, metrics, opt) -else: - print('no mode specified for model!') - parser.print_help() +if __name__ == "__main__": + train(datainfo, model, optimizer, loss, metrics, opt) \ No newline at end of file diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py index 8db27d09..b4d37525 100644 --- a/reproduction/text_classification/train_lstm_att.py +++ b/reproduction/text_classification/train_lstm_att.py @@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' import torch.nn as nn -from data.SSTLoader import SSTLoader from data.IMDBLoader import IMDBLoader -from data.yelpLoader import yelpLoader from fastNLP.modules.encoder.embedding import StaticEmbedding from model.lstm_self_attention import BiLSTM_SELF_ATTENTION @@ -40,18 +38,9 @@ opt=Config # load data -dataloaders = { - "IMDB":IMDBLoader(), - "YELP":yelpLoader(), - "SST-5":SSTLoader(subtree=True,fine_grained=True), - "SST-3":SSTLoader(subtree=True,fine_grained=False) -} - -if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: - raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") - -dataloader = dataloaders[opt.task_name] +dataloader=IMDBLoader() datainfo=dataloader.process(opt.datapath) + # print(datainfo.datasets["train"]) # print(datainfo) @@ -70,32 +59,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T def train(datainfo, model, optimizer, loss, metrics, opt): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, n_epochs=opt.train_epoch, save_path=opt.save_model_path) trainer.train() -def test(datainfo, metrics, opt): - # load model - model = ModelLoader.load_pytorch_model(opt.load_model_path) - print("model loaded!") - - # Tester - tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) - acc = tester.test() - print("acc=",acc) - - - -parser = argparse.ArgumentParser() -parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') - - -args = parser.parse_args() -if args.mode == 'train': +if __name__ == "__main__": train(datainfo, model, optimizer, loss, metrics, opt) -elif args.mode == 'test': - test(datainfo, metrics, opt) -else: - print('no mode specified for model!') - parser.print_help() From 46c82a7daac64c10ce425fe9c1201569dc02a75d Mon Sep 17 00:00:00 2001 From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com> Date: Mon, 8 Jul 2019 01:29:42 +0800 Subject: [PATCH 2/4] text_classfication --- .../text_classification/results_LSTM.xlsx | Bin 9944 -> 0 bytes .../text_classification/train_awdlstm.py | 4 ++-- reproduction/text_classification/train_lstm.py | 4 ++-- .../text_classification/train_lstm_att.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 reproduction/text_classification/results_LSTM.xlsx diff --git a/reproduction/text_classification/results_LSTM.xlsx b/reproduction/text_classification/results_LSTM.xlsx deleted file mode 100644 index 0d7b841b12b43ee346c4d9db17032fade8c0b40a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9944 zcmeHNWmFtlw{0APHPE=by9Rd+1a}G2xVu{jmf!?;OK^8dkYEYHgIj>0!ChY`GxIXT zSV;*Y zQ45hC)h0rUXS_C5!#Y0QjUTPKibgeJC$)<0upZ5^AYS<_jgGkJsx8Zv&Qq#q$Zw*` zoeza>WwWJ%3(7R3YGYq==is64@57JRBRpxEHDR%)xTm1T;jVkcxzTx~v?X2|E+RqQ z&hV`S`f_?lXrN06T7ljT&xFtOxhgY({GwJ{n+gY~CRJs+3$X@E;&~K3s9L$X?ut}y zdvFYtV^D30&o4nLYE=gH){cV|g)q^$+P==7=ZFW(zITKjq(H%dru&($dEKO zat7PFu(JHv|2IAVgLCq4uU;0f__Ui1C3Ii*CanK_av>HduHYdi+e)tPA1F7EULTcD zMY!;emI$a$@Ek@apxyswaBe{`YP+BO>?>EL&<93FpCxe5>_EmD2`BSkZ!M{&Z5ah1=OUt z+V}F%sz%4$4_?0QnCA8Oe|(9eEV(5PcI2H)GQxcRheWt^HX>k znc2QSmP~(#b?eKll+mwB!G(FjF)Y(ZopTzXTgP?yG1IGu7pc3n|6m~aCF$EmNRRdJ zB;mXrhQ)#b05;(P0Cb3FJZ)J$9Gq=T92{(Z^kZcj1`g?LAiwOo`@kEe&XCHyB5Wkd z(g`<}R(n?2*ESi;*le|m2_@SBOWgEZ4Q;NZh_)=a(ZlWM`Yz55cGf#Vav!7wq%QiU zaBPa{yjrvJ(ua`+q?5)9nB|ytd8A^V;N03_5M|;nC^2; zOxZ=FJUAbF5`H0_!!*AcxhV<*4>Hga4rGfQc-GlVfBWO%JQmI;^%R?><^<9ZmXbSTw}hEt4w`G*JdfoV%`SrmhYGM z*wFP!2i#Ad@osrt4zAF5>#V_LIb3^f&UD16`S5tC5n4b+J0*Tdkh-zW<8EER{miqi z!Q5NdVC1D2IRsKPLFc;TOY4KuPVYQ_o!k<9LE}^xa29mv1{ad=xrwrcZ>V;%)R#_&?kIbnXhQS+ut?#FV^t`Dm%`tEvaH~hb?izQlIGI`!*`>c-O z4P3OcGci(}TF7k*-4s+zTT6p1t#E+{aP-k-nTMt0q>$|Rcm4zkCng_5T3R5T3;_TU z3X(s6^fQ0YqQCkaCkTYI%N}VAdwcVI2dVRL( z%)dfeCn{ASDd86K4Rs|8>z5%7ppSQ1&r5C9*k5rKzaFH2H+RJJ@nqk8ZJ=**3G$`? z9Wr<6;B66D0N^_s0DuR%;?I)nVhIMjy0HE{u>YvN8CrAE`Fub>BddF;%l77wd2Eg} z{Xl3$v~Y|ss+y3fF%wz*hnwgqRGtnhg)jG6l7(Q8~Fh^6_D<2cJ3-d0siYP3gF3qSfYi>@HKWOxyAN8cvh?29yIIEVV2q>jy zXiF>TftbSLp2y;AI5Ng4>Bvf-Pl|wxIUN_O6I8OW^4XGOEW^8lXl*IuU3#6tl2i)C z!I#OjhKcpMNrDu~DwqB&hL~~KWYw$mTG_AZs%y;dyg$7tXa)HOcuJ&clVbWK9BfP( zGqmB-uSS3h`RE7}i8(hqZc{m3EXU|hZ$+q9QHcjw^l1-RxG$RNz4E1P9C42WRBh+O zx7Op2IhPRDT79m)f*hU(G@ZwI?~2su?>;uuZmppbfL$x)jcJE+N_trmc=v5rx1oYr z%EN&P9AN_Np-ie6&IInKT|A0Kjt5JP#B-_KbbVn`gsI6Ajkwo?3sOK_=E9hra;vK< z*<32Dw?Ob#nk4;7!09hYk0?+lv#g#x=zHI642I@@lxr7I;D!|!jOeS2KZQ;>fp^Q6 zuL!BEU@kv=|v(d!SjJ;=0eVEj@lTkV%8!z>j;C(?^h`WivdJp!6{3tSUMXpo@g0&7Xdw zbO?1CCiHIma_YH?3ZsV-f)~yPYz&NUJmma=fhzQrvD@t?s*gD z{=Ca{)=GN;i9?(5v%g+wjk)G3rA9~2+ z@|}E9Vf(O`5WWEqksWAEW)DT|KtcqY<(Y5eKD7XPH;THR#b38jRvHd{;YMjgO`foW^M<}nO1}TRX7{OsBr7tWD(|V`sS9Bx9 z&s(sjyFA3j`gE5vmnZtjaz_xb!l~5#T*m!OxAnE*m{WXLm&o`)g|S558{68rvEs7S z^jbaqPQcgJfIHr>3g!xJ>3TC8MJLsgDtI1LWAWzn$Q@JU3IRc zL_q?9RFHuOK3fiP3JP*@*mxrM7KaHNnfK|yXM=g?=LN;Mzg}5 zf{Cl1UXpA#RRdqPTfcsqgMVma8>HsRW2zGjJ8yox%_Xt1%$qyevhiXPHg3XAH8F5b zpcX&%nFWMu*|%pYijEaj#Nqjxqe1^us7v zmc{x`C$jYx-Wf{yyuiQ~u;p&ey<*&b^nA<8;52opFI~w`^}VbMOUrQ1)ZXcH=+dAgg_ryOwM%dDJ%<8JU%K$diojHgHU|VCix5czr6A{z zg{yiBbYgsep3Zmiw%ac`5O~5x(#t(<95Dp)r)Pb^60}|Qo7Q241Ge_PICtK+@c&|W^+8@Jo0M##c1e)nno_q%;+pQ%>g zQ*i@p&VKi1VXe|rvkenJX*wYc#>ki^-Ab2~xFQKlbwZ^2^m*Gw9Cm#iJs23?;n*4{ z4w?=s3gL)RXq;2-`qPZ#uG0(wEJ*5x$k`OTh=RDoa1DRVPV|6rdqTD(0C{;;_x#YvqFGsoB^Go5b*_Lnz< zHw&F(@yz54w1+42_~T%Abg>Ayp){1`L`S}(;`PxCOup`UL>-*~T$&`oRKWw9_vSJhDNu{^@z4I_M$*7psjea&`zIcxv^%KfY-Q2nKqtCXsnf4;f|y8 z4%n0p%1u0bt#O5&N#6U z9z+V_d?3K6$N)?8?_sfSUa1I$GZ-@Bso#X|Yv_tKs%*iF1s{2bpLMcD-f7TYwR~|U*}_7xUh-F z=>Kl2q%OZoYb#8pR%BO3tW?N%o8RKmt_NWpDVpx4+c*?=9C zhiFOrGdEoHcCTSazf0>PQ56_aRoH*6!?}wf@{q%ITJ1`3{MzxgYdmM29go&3eUgQ# z)zNDA@FLye8{u43Hg9k~a*0imS@10-IwgaP23xcILHhBh)Mx616v$n#3y>Da2t=M{ z3rE=0uAkT{7K;JMcVH`{ip?wZHf$JV{R$Wuj0@FufiP3b8IEmR_8P)Dn=H~purPE+z z0^c%s?q;g*gUypgR~8wgb`)I2__Td@i<~N&#`*(Ym7Bj%MC4@lolXr;Ep_Ns#v}GN zW^z}BYVIy7veC73wLb2hQt3tv646{?K53h44Bp+!c;27e$bPm|qyE<*-`_o?&r zR#H1ptUg&;<;16ze>t;JhkK_0Mmez~zG*q9AaMr2@yB7I?Xf#ac!T9GUf6*#mj3|LB3Q3V1E$VceH}Sk!`p~sI$w` zSi9?zXNfcyeUp>Jtl7(GThwA7QCC9;1PB0ZFTahbW+~21D>T_P_<=Udf|Ja;hc5lu zJHrkcb*rKR<4+7LXI^kHO9rE_CN~`4R!4~u9@m7X!*^Kb>jaBre9BhDrPWz}(stQ{ zB*W!gH0rw|xpL@R-Ctz5OdLQY2zP$~tYgKc+X<^p>vqY#HuE{MK+p?DNQm~I!xQ!C z_ak`KjvKfCc&7h2)A0$-XEVXb45p*!ANVq|C3U!#Nra9%5A>Jzvq%(#)9`8NES%p= zD&^*4O9pR|#yfh*H-)`}+hYAK6YNOxVnKnZ5DbtZDCp0mvNU!Eo2j`vTiILuky{BW z3f*kLz9qPK68%0&6~W;J_%0H7e87zU_sJuMQ=;AxLmIn{yj^rY?p{y2?mf;|S*%aT z1PQ$n^g$U=UplHiOD?WgbOar72A-K_$h^IPjjKHQ9u*m7+2T{TU6#pqs8wA6h#mNPL-0RoLGHK<7#JPWQ*+wX2Q>GLWX>EHR^QZ{JjKr zKdAC?cIHr7VYf^m&0M)XE0>w&rXR*vB(D$iUp*}AF4R1Vaq6V!blR)fC=+R*NS^yN zxXfG2!Do0CfmneZiWoaqvrYiE-<{#7JVRs&KO$IK@V|Cd>^Mhl8g?_dN2eG!ntJzJ z?+Ms(D0D(v#33G}{KI=cl%78a?|(=?zlQKXg`bd^&d4q{kkmQMorI@P!j~5)5^B5@ zO&S|8CRelYbIGyipj)3>BJlh5b;sd#p7d+MVnTDYIHrjf=of3IxRzB1KC{g4#H#tW z=Rru!g0aTqEFEbuN=ot(eX-iu$ki`YsBZDP{d;DrqmhbZg}=?nMR>6dT)re6DJ!wI zJ2FD3V^yl(XIk^qdd@GTs&ls+?MEtTpa1T&?9p(D+~e41!g;cf`=!1g1)7NMc${U~ z_@TiT*?szWe1`P9qO7~gbaf4R^H&l80PPf>qSsmZ??nIUfh!<2ZOM z*;XG5ybnF%R(~2NSk5)BslYrjs%`V#P6NBhX+&c~Y=ND_8hh1658q9(-1E3nFX_vX zU|L2|jJEZg!#D3h?LMs6RE)K)PFXSn*}H@K&Ar1=OE0^M7J@_pP_V%NznybXa+E_v zPE!kTS`XQ8-3>5lgxm&+k$=lS3%cQxT96h~h+>EVQ4Gx-OqHD-99>vV9h|{`&K>_( zD1-!-Z@i*XHybGATy~!PNZ85W(yE;5HNXuH6H|=LJ?VB3aBS}}z!8s!=Tp|k!7PE9gsLatGrh1 zu_tp!x}#MxxStTOWuwTCm0X2#LYFo5JzySlfVlME(a3Lbf{_M^Mj(V{0skJ2#*U8v z4MvE|{@gO+JMHIw%oM_1p@r^6W!Te&ifXD9FnUi{0pz^3!2CE23c)=D16kl?V)$KD z_t`nu_aymsQAJ8~Y7SHZU#DbLzxvKim%T`NlR=AN zhZtmw5Rv(Dm9jcqSX2^*o)S6vktSs*{)XJt$XhP*nK;$UX}SU~`=&~U{FSktYj~og zaY2#QDk(?mL|S8qe)v9}m;2nz?RytvL)N&6+L14TWuuR*THxnw^uG=P`!I8cU z-b~8b&+p9q1V{SDr1B2%L>rv%M=VI9g0|NYoZf+0nHyk?UNUi<&?1}rnzmC6ItD~&&Pk?`}8~Drc$MFp$N&dEq;GyBexzt~#pCHQHgPGNb z#{Zm1`eh0Lpgj6%{C`XixxaI|7Yd?yZI-wznTA8$CVY}ATbCDCdiKl62=)n(g*NA DxVh?K diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py index e67bd25b..007b2910 100644 --- a/reproduction/text_classification/train_awdlstm.py +++ b/reproduction/text_classification/train_awdlstm.py @@ -33,9 +33,9 @@ class Config(): task_name = "IMDB" datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} - load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" save_model_path="./result_IMDB_test/" -opt=Config + +opt=Config() # load data diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py index b89abc14..4ecc61a1 100644 --- a/reproduction/text_classification/train_lstm.py +++ b/reproduction/text_classification/train_lstm.py @@ -30,9 +30,9 @@ class Config(): task_name = "IMDB" datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} - load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" save_model_path="./result_IMDB_test/" -opt=Config + +opt=Config() # load data diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py index b4d37525..a6f0dd03 100644 --- a/reproduction/text_classification/train_lstm_att.py +++ b/reproduction/text_classification/train_lstm_att.py @@ -32,9 +32,9 @@ class Config(): task_name = "IMDB" datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} - load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" save_model_path="./result_IMDB_test/" -opt=Config + +opt=Config() # load data From 8156f3c69e8e79eb5050f20fea046092e9d3ad4f Mon Sep 17 00:00:00 2001 From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com> Date: Mon, 8 Jul 2019 05:14:36 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- reproduction/text_classification/README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md index b058fbb2..4b8f44bd 100644 --- a/reproduction/text_classification/README.md +++ b/reproduction/text_classification/README.md @@ -3,20 +3,20 @@ char_cnn :论文链接[Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf) dpcnn:论文链接[Deep Pyramid Convolutional Neural Networks for TextCategorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf) HAN:论文链接[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) +LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding]() +AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models]() #待补充 -awd_lstm: -lstm_self_attention(BCN?): -awd-sltm: # 数据集及复现结果汇总 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) -model name | yelp_p | sst-2|IMDB| -:---: | :---: | :---: | :---: -char_cnn | 93.80/95.12 | - |- | -dpcnn | 95.50/97.36 | - |- | -HAN |- | - |-| -BCN| - |- |-| -awd-lstm| - |- |-| +model name | yelp_p | yelp_f | sst-2|IMDB| +:---: | :---: | :---: | :---: |----- |:---: +char_cnn | 93.80/95.12 | - | - |- | +dpcnn | 95.50/97.36 | - | - |- | +HAN |- | - | - |-| +LSTM| 95.74/- |- |- |88.52/-| +AWD-LSTM| 95.96/- |- |- |88.91/-| +LSTM+self_attention| 96.34/- | - | - |89.53/-| From 8f78bf5250e7f183575a7dd6603aa1315668b217 Mon Sep 17 00:00:00 2001 From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com> Date: Mon, 8 Jul 2019 05:17:50 +0800 Subject: [PATCH 4/4] =?UTF-8?q?readme=E6=A0=BC=E5=BC=8F=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- reproduction/text_classification/README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md index 4b8f44bd..08c893b7 100644 --- a/reproduction/text_classification/README.md +++ b/reproduction/text_classification/README.md @@ -1,22 +1,28 @@ # text_classification任务模型复现 这里使用fastNLP复现以下模型: + char_cnn :论文链接[Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf) + dpcnn:论文链接[Deep Pyramid Convolutional Neural Networks for TextCategorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf) + HAN:论文链接[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) + LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding]() + AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models]() + #待补充 # 数据集及复现结果汇总 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) -model name | yelp_p | yelp_f | sst-2|IMDB| -:---: | :---: | :---: | :---: |----- |:---: -char_cnn | 93.80/95.12 | - | - |- | -dpcnn | 95.50/97.36 | - | - |- | -HAN |- | - | - |-| -LSTM| 95.74/- |- |- |88.52/-| -AWD-LSTM| 95.96/- |- |- |88.91/-| -LSTM+self_attention| 96.34/- | - | - |89.53/-| +model name | yelp_p | yelp_f | sst-2|IMDB +:---: | :---: | :---: | :---: |----- +char_cnn | 93.80/95.12 | - | - |- +dpcnn | 95.50/97.36 | - | - |- +HAN |- | - | - |- +LSTM| 95.74/- |- |- |88.52/- +AWD-LSTM| 95.96/- |- |- |88.91/- +LSTM+self_attention| 96.34/- | - | - |89.53/-