From ace8aafdb2582474967ab2cfe944d600d0eddff1 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 6 Apr 2020 01:40:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBertEmbedding=E5=9C=A8include?= =?UTF-8?q?=5Fcls=5Fsep=3DTrue=E6=97=B6=EF=BC=8CSEP=E4=BD=8D=E7=BD=AE?= =?UTF-8?q?=E7=9A=84=E5=80=BC=E4=B8=8D=E6=AD=A3=E7=A1=AE=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/embeddings/bert_embedding.py | 8 ++--- fastNLP/modules/encoder/bert.py | 2 +- .../embedding/small_bert/config.json | 2 +- .../small_bert/small_pytorch_model.bin | Bin 37965 -> 38083 bytes .../embedding/small_bert/vocab.txt | 1 + test/embeddings/test_bert_embedding.py | 33 +++++++++++++++++- 6 files changed, 39 insertions(+), 7 deletions(-) diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 3bd448aa..660e803e 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -224,9 +224,9 @@ class BertWordPieceEncoder(nn.Module): 第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。 :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) """ - with torch.no_grad(): - sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len - if token_type_ids is None: + if token_type_ids is None: + with torch.no_grad(): + sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) token_type_ids = sep_mask_cumsum.fmod(2) if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 @@ -462,7 +462,7 @@ class _WordBertModel(nn.Module): outputs[l_index, :, 0] = pooled_cls else: outputs[l_index, :, 0] = output_layer[:, 0] - outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] + outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, word_pieces_lengths + s_shift] # 3. 最终的embedding结果 return outputs diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 3496c5f6..4523163b 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -1011,7 +1011,7 @@ class _WordPieceBertModel(nn.Module): if word_pieces[0] != self._cls_index: word_pieces.insert(0, self._cls_index) if word_pieces[-1] != self._sep_index: - word_pieces.insert(-1, self._sep_index) + word_pieces.append(self._sep_index) return word_pieces for index, dataset in enumerate(datasets): diff --git a/test/data_for_tests/embedding/small_bert/config.json b/test/data_for_tests/embedding/small_bert/config.json index 3e516872..da4cda35 100644 --- a/test/data_for_tests/embedding/small_bert/config.json +++ b/test/data_for_tests/embedding/small_bert/config.json @@ -9,5 +9,5 @@ "num_attention_heads": 4, "num_hidden_layers": 2, "type_vocab_size": 2, - "vocab_size": 20 + "vocab_size": 21 } \ No newline at end of file diff --git a/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin b/test/data_for_tests/embedding/small_bert/small_pytorch_model.bin index fe968fb5d64a87b224d0ed9d793e6bf3aeb70971..a0811def26d6ae1820f4b9ba03914e233d163047 100644 GIT binary patch delta 6678 zcmaJ_cU%<59;Zn!2UrjrHpBwn9e2kOy&Z@p3KrZ1yx0*@P!VTbH5&Cmg|$arP3*mQ zLo|tcJH{5fF&bkROKk7uX<|vD(f5A4v+xO*-1Cp$e)r4x{eEX|c4ubKZ1Ot1*=z2F zp+fG6QH7-c7-Jyzii@-Rd3bn?5(-K7kX-DY10W|au;cK;Y~g)I24Ww37_>2%^ycIw zlfht2G@H#vT+QwcY|hBh*tctjXEy=+^+@+fuaT}R#s2BKbk9;8AocgG6B5{*$3X%P z=5dI<#k_fu(nilZo}u$@)fzA|n`9T{3b`Z@hvh+(Jny`}SsAQm3)bcNcJ=TAi*Usv zMzM%s7Ln8Q6#eP~uEFD))6|<LW9s#pdIJH+@VFsUZzSTzJZ@r-QPg*Ng=aMVXLWsIk|oIk>an1%1NCN#x~>Gr zrF)j(=1jSTfLrpo6;qBEaRQH9OFHi+8G1LOR)fWAG@7tM6E$k0iHfL+iJAp$;c*fZ zwTd{I$L~nzgQ6sTwQy;hcSvTNe-<|w3|6Z-$%5M|!>#};4VDw z%JjR5xI2$CK;M+v!@a1%s<&dkBKiB3$v4FZB!c?WvACacg1HCmHp4l-2=3 z>%eHG17n!_SOI^;<8e%Vyof*M@dUe33iD|so%F4qInlke!3@%qD$+YD(-*Py#Y}&) zfT!?yD$_3!aVd{QZGeKoA%yi{ntPi>lSz-KYdx5u^q`!K|mD$g?sg+WQ!H=f%GXL zzE>45!Cxxrzhc5)3wR%o_cP%GB6jjvw#UKfL`fq9L!_0FzIIIA>nFq85+AHc|BX64 z=_#f75Oe}olEcj8h=7ms_!u+!R>a47e8L{9nM7oqbgOUHo2+_#N*n#t+UTEA@}FhW z=LGy6kIysd??rrp#}{E>jQkJoMGY2<318Agf7C=TE23AJ=v4t< zLd)ydzf1PGT+*$O6tFjDxqBz0GI?%mPIok?yNc62=Jb<*@ALQpb9yM^M?8KEPnIR~ zU+zUs`a~oCcf~XNGaGZ5G5=vR<_S}ND&Svu{3}y_CgSHjej$~IR(aRI{68jUiyr@` zCI4MZ{)dwM71Mt$;6Hi%hUvc*@qc;ZA>9aVwct%)SL6k+f9V zR#I_T8%I1~x0V)tvPDl4lqOoYTLZ4_wjOR+w+(=puMvwAw*jbhyDeOqy$z5+IH){9P?>!SAd#f1?7L9y zmMT-AO+>KNJv2~yt&1Ab4*u*uOQI2`L3^bM%hCa$mZc*=CChtoWm!4_5=dvLJV}Gf zvZMnNNf(vvDiwv>9NhrBw=|hdR?=N*!YnfYYL-0!DwaH4ndSR{1o8n?o@7F0mRW#A z(o<#aP%9@LdjkFw8Z14Hh;(Rgf)AmE`#5GxqLK7cny_V%^aiZtQC33I2X5Gm=nIHt z>IX<5{h{(?092MK8<0o_s%#Eadn;*SNEGCVatwl7_W>+PW+TZ}ny4!!84OVMFCjw! zD&vs{SLU4$NFYO@@?;oP<~=n_)*Qml37a1i0BZANB0yz4Cc%|OZ~ziW5mcTOLuC;r0}{y;m7OZh4XG!Msu5l>9ZG<9 z>z<^CwTF}{O_;j~P;;LKP;s9QSLQwgkU(Za<;g6l%zZW>k<3xqPoUacr~@FGA|fG2 zO~+i|+y^iunM`D!(u4(=4^Rsr0aOAkfGZ2I5RgDVh02r9pt1mq0EuL=$}WLwkE)ms zBMS?MOMW$Vj-^1m&tOb8TFEk{iT0LU4wy{^cqUfB4SOb50%Ac{0TRe+s61H%l?7P~ zNF?i2c0E*k&5Eh8ux6xV12FD0m|*)wHY!cnIBWu_jl*Vu$~bI+E3@AUNFdvw@?<+y zX1@cFNOr31E~ujuioE(_k2v;e!oj{xD2Ge+9O(JiFBX+YPOqx0cB1okWz|!Al@Ag6 zieq=RzqYMBFZxP$oY23q$_6vcl~vCaI#*Ub3xro;9R{OG-NHvyp$$g1@s4z5)tz&7S61ES z)^KIjZEDTGvC6)(Rt4+OzD${wy&#UNpetBqqXr*@s{BYsv$Ac%(XOnrk->FbSaq*~ z>$W}LI>-g*EF<68^xV~#+Wv#>wTv=tw!ZEI_vX0<}uB`G_$BkTBW!(GV#xAt_ z@`KM*Rr{QzypeEIS5|pP;8<5yd1`U9+=Q~eGctUni5&u^Cn=m1U%Y7To;f~JZm;QO zSgI5CR!{3n{8hmig zkrmGUpR7eYx93xvki32|b=Mi8{{2Zxf02Lfc*)tld^SC|RYx1*Z@D^6YtpN|ICOH% zY}7KgKb<&iAJ_fJ2lT?DWO?0KUz+x;1??IB5^eJ=bxu!CM^9tla&IOzM8SFWQ)VP@ zM_of|JGlG&Dd*+o9(2~Jd^DqcpPYHP8oCoSfQsunI)~k;fqG`{puQ($?xS4+wAHVl zQeQ!&bGHO=10!P4^Nq97AF~eA8^@|spK0f8c>7~#^s=qgmXw8lv(Ba^4-U{E^XFWe zC4}Q*yi+E6zv8l5J(FA4oXcURD}XYa|& zf`+1;b+v7eC#{p`O!PtxGk%npc}(Zh6Q0x5d&dsw5-&In=MOrU{_LF+H?xG}zA5KA zPShd4=NCEMprds1moK?nhX&IuYhU`vV}V?7@-}y6ghVeLSV(nEzm=y9dc^(i$ZAO& zY^jYhkEPJ3xt=LWo}Fp^HK7y--L(zBeUsZfw-B|mo1BIIXVCPJo7~gJ(P+WR#oV1s z#q`#~RJ8xpUX**fIrrT!*W@qH9HbBbES4jCjH2HioIpw17xK{!IVk^du+#n|4fzil zPs>m5qpzPIM|Tas*w$Q%r;bf8%29MJ6MZ#oDSFeSHQlxR8tVA?7`odxjJw^u9oL{; z8rp37n(ivAkM8dHUJlt6iarq&>Cw>5R7OWR`FRYw;V~59J)7x+HxJPB9U0t#w$HhS z1vYfZ*NcWPHPVLnTR8bGH_-iyJiQv$9+kbVFMlwrVaj&tt%Ez3^O74rGKr3!7LC$& z70Vm_rqJ^Xdm(S@8M)DvW@xuHK)!v=%(c>8M!WpKql2%{MB5MDL6>i&B4_P1+Is2_ zT!-Dm=!<(!y5)RRG!EaECr?-}|0>;~hekD}rPGGmitb-Vi#qJ*ZWP7Rd*4T(kcLIH zUZ1bzALsL{(Bic-=>CH@(c&+zq4)!BQ2vEaupR%+4{dAr>>2~3J%DcGiMK5R;pXy&$(5$nMl#uFIRxc~iM|xa{ e$`VpP^O3r??JU86S$n3#UsV*U%-J%o#{U4=71BHa delta 6564 zcmZ`;dwdPY9#7)LNu)_mUIcj|;!)zfUpuD_^{ODOR1X!;grrI`V+pr7C|0Xs_n+uo|HrFu(AsnYiTc4w1DK8`SUV^qa4?5MhGqt2#%Agk=30iku*hCTLlr z2)5#IJd0aX^)@_i%i;ugH(5O?C~;WgGu7=DgTrAp+i_C9HZujplV$NBSsXIR?%PrE z_8jiO;#4Z$k;iE)?j)=XN*mhQuddDJFybzts{`GxvaW6sem*mB5$;AcyK^|5#XYE| zp2r3j8{J95s-VQ|@de|iVUu5Ji^F8aW>wmvN?T=V8;rKuw09tv=+Un_)8q_&Eh^(yf2UYvADlGT`*|dgM8R< z#rw;CbLbsLJ075V&s4n+l)Vq4@{l=;#Y3q4P#&`^9;V9g(&)0Y{VAIqdYciyQty6v zgZo#hdpDII!QqiCevQhH;&Bd(M+;HG#;h@Zg>6QI6^~Udk5es=mn~1Anz3(yVEP6AZp_(sN&Cit0&!X~w@pi zzo5bWTh#qRD!+)si&@N5`6WCqVewL^jdkcUf5JwG89MUss^#UXd?a zMl}T3M52-nE&MXD=t*vH~ps`~+tKV&G zsKr087XM$e{ZFa-1rA?i@n=;15|2M;@nxkfcSE&;HVT%YD}KezcDo*5RmHzh#lMur z|4qfe;_%lj{tp%ZhR4@f{H?pIYF#VL42#UV?pNKUH=6N(RrMRH`ggMW_f-8Rhi|d? z2daLX$3L?8e=u|6g*#zM!+-LtYcN>t_)h(_yi2PN!{ui>T<%fv`y8%g@h?>T0goTD z_*Z36wgmZBIm%1?k)Iqe@lCK${HA(;ta^VUd;guv|H0uuSrQN!5TIlrfq-}t1d%1p z1VeP9FeD<s|hV*zm_ z4kAlhLZmrc0pdx#!nTI!*28{jd>-M_^QKNH%oSeHhC;=ogfS7}o;L8iA1?;jrkL%d zt=vK_LIU81A|%2MEkY6?YA6{HM^YfNq#ZJVS*c44>KU@!vcsSR){RIL8LzHfOz6iSSQ54 zI!K^CS_@;f8V_OMH@{|VFzE=AThMyA0IKzP0iaxu7vV~MWB}qwPlzn(1(EuA2@p?u zD{LQ#Zfku#f{ zP`2|5T&bPmfH?9hM3%TAQad96@nodJz6Ng(n`ac@XJ(Xl^m=%2lN`APT{FpOz;a4? z*(77&hK`J}fN0urfH*Q9B1j@gz@SCqi_0tFMAk5v~&kN5{~aG6{Zrwi&D5 z?y&0hBwuc!tg)m3pyFi_c^#nq_>u`DLl zLx6gAdW*qofv*rUO>Uvqf&;9SA68;A9d4Axn9Kk~Jro1t$V`YVnFW!0_!}Ueys5CW zA-dBVJjk=e6&pKc4$yvm*kH3r=E^PTDDVJON5MRRavA2sm0DN;h$C-7WXVE^)WRY_ zJXx%;yigjRD0t$c%EFp^$P(cEW_G{^n3TvZXy&B=)y&HP%9;NTSDJY_Adakn$dZ*1 zY35QuJXxi%Wf0w6>mR;hYR-v^_XzNlKN~g!?Da{x+=ALz4N$eQ2B2)?ZMae!?*QV+ zT8J!J2a(!%7Z6X@D{KWscYM8#qQb&A@`SjSk)92}`}JUg4;iviZlS(1-vdks15Auf za6>1?Wmu6^+KGUJ})Q-2Xt1?Cp_{2MZ>v4;ImFG zf*O-BOXLj(Zq|fUHaVImq&%N+;{ytK;aKIt4JRS$@rXx09|2Y#cvxG{Qc@WPIHU7l-n(TJQrYddYC%i@KEInU`E)a1Xx(}7_Z=6co=JmUZ}d-J52+Na9dar9ACcj6_qep8hsoMO?LFgGe*?lGYbH>XJ# zI!qA9pFUW#IcB>QP~I1n&6$NZ+^CbLEr=6eznINbjn89Z5^}}0yT_xFvInSW)=$#h zqh2ZE;#D!{z;VfZ{c9$C?K{#FZ#S3e{@3FDNjGX{Up^#Oln?bS=(|iZE}hQQo*02X zp8f-4t3HqVO*KfZ&is>E;EDR5xT&l=I$WTW#zkJ04t>=LHNW=V(T>IK(D(Z0z76LF zOQBcJi&y7gmDX4*B<8cj&O0|wqxr3Ni$%X}XRZ}jFqhijl58KR`BsuMl$IK{*8K7V&}1vBTUbnOf=}qY351k87co@fX_a&67|lG^-a@vly*9H zI-iU?F3#LO0aacvN5!)aF^9YMKx1OGzUK{<(%mW1uIg`lJ116LLBF0mi>{6+lnTc* zcSY>!;XJ%+wUqlL!h`xPZR^q=+K*0tz@h@*4KX!sA2TdLY!%riE`(&m$;)T9RvS?M)OEm0HuH^-0m#{%z;6E>Ymono! zoI~R)L?)p!*H@JNgX8mnXz5JAF*Nwki_#2S!d%SGWA2aioN*Q&eIj}{)uP88z4n-o zMu}raE)Wwc3z%FfRGM$xD}M7Z+I535pu`nvD1AeYB;5=}I@1`&po5b#q z519_!ngvxZ@tU0_^QQ<$y0K9)vz%#%ii)FExx)o58sJ87b>T70xf>q}X; z4;7XFDE7L%P}=lLca$;s3KQQq6ZPqPP2$Iu*Dxr-x2dz4aiknVv3pXaqTusV%BXLp zm1i$9-xW6Z&027mS@H5Isjfq+7(IClZZ=pG-j-iBJv978o%TaAyTan#- zvnF_Bt$6gy6H?!#wdfx&dYI%{Tf_+6YbYu3mbi6Dito3a5>$KiSLv%0gP30iPZurS zK4e;qL#X-rBVukuxtO=`9<#YTdNk^vq;tjACZmE&L&R%!U6HrhxS<+vxu-{*@MThe z@0NmhG~TzShUfxy!lwiJd$aSRG=d@7sR;;}_o5(nDDfT`m8F5E@V_3;6