From c707cd4336ce8f02ddf996271e2acd6e562641f1 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 3 Nov 2020 16:41:22 +0800 Subject: [PATCH 1/2] =?UTF-8?q?CRF=E7=BB=B4=E7=89=B9=E6=AF=94=E8=A7=A3?= =?UTF-8?q?=E7=A0=81BUG=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/pipe/cws.py | 2 +- fastNLP/modules/decoder/crf.py | 51 +++-- test/data_for_tests/modules/decoder/crf.json | 1 + test/modules/decoder/test_CRF.py | 202 +++++++++++++------ 4 files changed, 179 insertions(+), 77 deletions(-) create mode 100644 test/data_for_tests/modules/decoder/crf.json diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index c3aab4e6..3849a34b 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -122,7 +122,7 @@ def _find_and_replace_digit_spans(line): otherwise unkdgt """ new_line = '' - pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])' + pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])' prev_end = 0 for match in re.finditer(pattern, line): start, end = match.span() diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index b5ffa35d..0a05f6f4 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -198,8 +198,18 @@ class ConditionalRandomField(nn.Module): constrain = torch.zeros(num_tags + 2, num_tags + 2) else: constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float) + has_start = False + has_end = False for from_tag_id, to_tag_id in allowed_transitions: constrain[from_tag_id, to_tag_id] = 0 + if from_tag_id==num_tags: + has_start = True + if to_tag_id==num_tags+1: + has_end = True + if not has_start: + constrain[num_tags, :].fill_(0) + if not has_end: + constrain[:, num_tags+1].fill_(0) self._constrain = nn.Parameter(constrain, requires_grad=False) initial_parameter(self, initial_method) @@ -290,14 +300,15 @@ class ConditionalRandomField(nn.Module): scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 """ - batch_size, seq_len, n_tags = logits.size() + batch_size, max_len, n_tags = logits.size() + seq_len = mask.long().sum(1) logits = logits.transpose(0, 1).data # L, B, H mask = mask.transpose(0, 1).data.eq(True) # L, B flip_mask = mask.eq(False) # dp - vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) - vscore = logits[0] + vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long) + vscore = logits[0] # bsz x n_tags transitions = self._constrain.data.clone() transitions[:n_tags, :n_tags] += self.trans_m.data if self.include_start_end_trans: @@ -305,36 +316,44 @@ class ConditionalRandomField(nn.Module): transitions[:n_tags, n_tags + 1] += self.end_scores.data vscore += transitions[n_tags, :n_tags] + trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data - for i in range(1, seq_len): + end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags + + # 针对长度为1的句子 + vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \ + .masked_fill(seq_len.ne(1).view(-1, 1), 0) + for i in range(1, max_len): prev_score = vscore.view(batch_size, n_tags, 1) cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score - score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) + score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag + # 需要考虑当前位置是该序列的最后一个 + score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0) + best_score, best_dst = score.max(1) vpath[i] = best_dst - vscore = best_score - - if self.include_start_end_trans: - vscore += transitions[:n_tags, n_tags + 1].view(1, -1) + # 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况 + vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + vscore.masked_fill(mask[i].view(batch_size, 1), 0) # backtrace batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) - seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) - lens = (mask.long().sum(0) - 1) + seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device) + lens = (seq_len - 1) # idxes [L, B], batched idx from seq_len-1 to 0 - idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len + idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len - ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) + ans = logits.new_empty((max_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags - for i in range(seq_len - 1): + for i in range(max_len - 1): last_tags = vpath[idxes[i], batch_idx, last_tags] ans[idxes[i + 1], batch_idx] = last_tags ans = ans.transpose(0, 1) if unpad: paths = [] - for idx, seq_len in enumerate(lens): - paths.append(ans[idx, :seq_len + 1].tolist()) + for idx, max_len in enumerate(lens): + paths.append(ans[idx, :max_len + 1].tolist()) else: paths = ans return paths, ans_score diff --git a/test/data_for_tests/modules/decoder/crf.json b/test/data_for_tests/modules/decoder/crf.json new file mode 100644 index 00000000..ff2d6689 --- /dev/null +++ b/test/data_for_tests/modules/decoder/crf.json @@ -0,0 +1 @@ +{"bio_logits": [[[-1.8154915571212769, -1.3753865957260132, -10001.513671875, -1.619813084602356, -10001.79296875], [-1.742034673690796, -1.5048011541366577, -2.042131185531616, -1.2594754695892334, -1.6648437976837158], [-1.5522804260253906, -1.2926381826400757, -1.8607124090194702, -1.6692707538604736, -1.7734650373458862], [-1.6101375818252563, -1.3285458087921143, -1.7735439538955688, -1.5734118223190308, -1.8438279628753662], [-1.6522153615951538, -1.2640260457992554, -1.9092718362808228, -1.6192445755004883, -1.7168875932693481], [-1.4932769536972046, -1.4628725051879883, -1.9623159170150757, -1.497014045715332, -1.7177777290344238], [-1.8419824838638306, -2.1428799629211426, -1.4285861253738403, -1.2972710132598877, -1.5546820163726807], [-1.671349048614502, -1.4115079641342163, -1.624293565750122, -1.537371277809143, -1.8563929796218872], [-1.5080815553665161, -1.3281997442245483, -1.7912147045135498, -1.5656323432922363, -1.980512022972107], [-2.0562098026275635, -1.4711416959762573, -1.5297126770019531, -1.7554184198379517, -1.3744999170303345]], [[-1.3193378448486328, -1.997290849685669, -10002.0751953125, -1.3334847688674927, -10001.5712890625], [-1.229069471359253, -1.2702847719192505, -2.0717740058898926, -1.9828989505767822, -1.8136863708496094], [-1.8161871433258057, -1.4339262247085571, -1.4476666450500488, -1.8693819046020508, -1.562330722808838], [-1.897119402885437, -1.5767627954483032, -1.54145348072052, -1.6185026168823242, -1.4649395942687988], [-1.8498220443725586, -1.264282464981079, -1.7192784547805786, -1.8041315078735352, -1.530255913734436], [-1.1517643928527832, -1.6473538875579834, -1.5833101272583008, -1.9973593950271606, -1.894622802734375], [-1.7796387672424316, -1.8036197423934937, -1.2666513919830322, -1.4641741514205933, -1.8736846446990967], [-1.555580496788025, -1.5448863506317139, -1.609066128730774, -1.5487936735153198, -1.8138916492462158], [-1.8701002597808838, -2.0567376613616943, -1.6318782567977905, -1.2336504459381104, -1.4643338918685913], [-1.6615228652954102, -1.9764257669448853, -1.277781367301941, -1.3614437580108643, -1.990394949913025]], [[-1.74202299118042, -1.659791111946106, -10001.9951171875, -1.0417697429656982, -10001.9248046875], [-1.2423228025436401, -1.7404581308364868, -1.7569608688354492, -1.5077661275863647, -1.9528108835220337], [-1.7840592861175537, -1.50230872631073, -1.4460601806640625, -1.9473626613616943, -1.4641118049621582], [-1.6109998226165771, -2.0336639881134033, -1.3807575702667236, -1.221280574798584, -2.0938124656677246], [-1.8956525325775146, -1.6966334581375122, -1.8089725971221924, -1.9510140419006348, -1.020185947418213], [-1.7131900787353516, -1.7260419130325317, -2.161870241165161, -1.2767468690872192, -1.3956587314605713], [-1.7567639350891113, -1.1352611780166626, -1.7109652757644653, -1.8825695514678955, -1.7534843683242798], [-1.826012372970581, -1.9964908361434937, -1.7898284196853638, -1.2279980182647705, -1.413594365119934], [-1.522060513496399, -1.56121826171875, -1.5711766481399536, -1.4620665311813354, -2.0226776599884033], [-1.3122025728225708, -2.0931777954101562, -1.8858696222305298, -1.831908106803894, -1.2184979915618896]], [[-1.3956559896469116, -1.8315693140029907, -10001.48046875, -1.844576358795166, -10001.5771484375], [-1.562046766281128, -1.7216087579727173, -1.5044764280319214, -1.4362742900848389, -1.8867106437683105], [-1.5304349660873413, -1.5527287721633911, -1.5590341091156006, -1.6369349956512451, -1.7899152040481567], [-1.6007282733917236, -2.054649829864502, -1.9757367372512817, -1.4219664335250854, -1.2371348142623901], [-1.841418981552124, -1.8178046941757202, -1.5939710140228271, -1.2179311513900757, -1.7144266366958618], [-1.6715152263641357, -1.5060933828353882, -1.6629694700241089, -1.633326530456543, -1.5827515125274658], [-1.9413940906524658, -1.853175163269043, -1.6390701532363892, -1.2217824459075928, -1.5564061403274536], [-1.746218204498291, -1.7089520692825317, -1.6738371849060059, -1.627657175064087, -1.344780445098877], [-1.1776174306869507, -1.629957675933838, -1.79096519947052, -1.7566864490509033, -1.853833556175232], [-1.4880272150039673, -1.4722591638565063, -1.631064534187317, -1.9562634229660034, -1.5718109607696533]]], "bio_scores": [-1.3754, -4.5403, -8.7047, -12.8693], "bio_path": [[1], [3, 0, 1, 1], [3, 0, 1, 3, 4, 3, 1, 3], [0, 1, 1, 0, 3, 0, 3, 0, 1, 0]], "bio_trans_m": [[-0.095858134329319, 0.01011368352919817, -0.33539193868637085, -0.20200660824775696, 0.136741504073143], [0.5436117649078369, 0.37222158908843994, -0.15174923837184906, 0.10455792397260666, -0.35702475905418396], [0.3681447505950928, -0.6996435523033142, -0.002348324516788125, 0.5087339282035828, -0.08750446885824203], [0.6505969762802124, 0.0064192176796495914, -0.10901711881160736, -0.24849674105644226, -0.1375938355922699], [-0.019853945821523666, -0.9098508954048157, 0.06740495562553406, 0.2244909256696701, -0.29204151034355164]], "bio_seq_lens": [1, 4, 8, 10], "bmes_logits": [[[-10002.5830078125, -20002.54296875, -10001.9765625, -2.033155679702759, -10001.712890625, -20001.68359375, -10002.4130859375, -2.1159744262695312], [-1.870416283607483, -2.2075278759002686, -1.9922529458999634, -2.1696650981903076, -2.4956214427948, -2.1040704250335693, -2.065218925476074, -1.869700312614441], [-1.8947919607162476, -2.398089647293091, -2.1316606998443604, -1.6458176374435425, -2.001098871231079, -2.362668514251709, -2.513232707977295, -1.9884836673736572], [-1.5058399438858032, -2.3359181880950928, -2.382275342941284, -2.4573683738708496, -1.7870502471923828, -2.342841148376465, -2.1982951164245605, -2.0483522415161133], [-2.0845396518707275, -2.0447516441345215, -1.7635326385498047, -1.9375617504119873, -2.530120611190796, -1.8380637168884277, -2.099860906600952, -2.666682481765747], [-2.299673557281494, -2.3165550231933594, -1.9403637647628784, -1.8729832172393799, -1.8798956871032715, -1.8799573183059692, -2.2314014434814453, -2.39471173286438], [-1.9613308906555176, -2.136000633239746, -2.1178860664367676, -2.1553683280944824, -1.7840471267700195, -2.4148807525634766, -2.4621479511260986, -1.817263126373291], [-2.056917428970337, -2.5026133060455322, -1.9233015775680542, -2.0078444480895996, -2.064028024673462, -1.776533842086792, -2.3748488426208496, -2.114560127258301], [-2.3671767711639404, -1.7896978855133057, -2.416537284851074, -2.26574444770813, -2.2460145950317383, -1.7739624977111816, -1.9555294513702393, -2.045677661895752], [-2.3571174144744873, -1.820650577545166, -2.2781612873077393, -1.9325084686279297, -1.863953948020935, -2.2260994911193848, -2.5020244121551514, -1.8891260623931885]], [[-2.0461926460266113, -10002.0625, -10001.712890625, -2.251368761062622, -2.2985825538635254, -10002.146484375, -10002.0185546875, -2.225799560546875], [-1.9879356622695923, -2.4706358909606934, -2.3151662349700928, -1.5818747282028198, -2.329188346862793, -2.1170380115509033, -2.159011125564575, -1.9593485593795776], [-2.2397706508636475, -2.2388737201690674, -1.826286792755127, -2.444268226623535, -1.7793290615081787, -2.402519941329956, -1.8540253639221191, -2.09319806098938], [-1.7938345670700073, -2.525993585586548, -1.9962739944458008, -1.9414381980895996, -2.5183513164520264, -2.5057737827301025, -1.7933388948440552, -1.925837755203247], [-2.2330663204193115, -2.098536491394043, -1.9872602224349976, -1.7660422325134277, -2.5269722938537598, -1.9648237228393555, -1.80750572681427, -2.551790475845337], [-1.802718162536621, -2.4936702251434326, -1.846991777420044, -2.6299049854278564, -1.8180453777313232, -2.010246992111206, -1.9285591840744019, -2.5121750831604004], [-1.7665618658065796, -2.2445054054260254, -1.822519063949585, -2.5471863746643066, -2.719733715057373, -1.9708809852600098, -1.7871110439300537, -2.2026400566101074], [-2.2046854496002197, -2.375577926635742, -1.9162014722824097, -2.397550344467163, -1.9547137022018433, -1.759222149848938, -1.818831443786621, -2.4931435585021973], [-1.9187703132629395, -2.5046753883361816, -1.871201515197754, -2.3421711921691895, -2.372368335723877, -1.883248209953308, -1.8868682384490967, -2.0830271244049072], [-2.406679630279541, -1.7564219236373901, -2.340674877166748, -1.8392919301986694, -2.3711328506469727, -1.913435935974121, -2.221808433532715, -2.019878625869751]], [[-10001.7607421875, -20002.30078125, -10001.9677734375, -1.7931804656982422, -10002.2451171875, -20002.15234375, -10002.208984375, -2.4127495288848877], [-2.162931442260742, -2.121459484100342, -2.4020097255706787, -2.5620131492614746, -1.7713403701782227, -2.1945695877075195, -1.8392865657806396, -1.8513271808624268], [-2.2151875495910645, -1.9279260635375977, -2.24403977394104, -2.1955597400665283, -2.2283377647399902, -1.7366830110549927, -2.634793519973755, -1.757084608078003], [-1.813708782196045, -1.93169105052948, -2.2419192790985107, -2.307635545730591, -2.19914174079895, -2.070988178253174, -2.0030927658081055, -2.1678688526153564], [-2.118651866912842, -1.867727518081665, -2.312565326690674, -2.274792194366455, -1.9973562955856323, -2.000102996826172, -1.8425841331481934, -2.3635623455047607], [-2.435579538345337, -1.7167878150939941, -2.3040761947631836, -1.657408595085144, -2.462364912033081, -2.2767324447631836, -1.7957141399383545, -2.425132989883423], [-1.806656837463379, -1.7759110927581787, -2.5295629501342773, -1.9216285943984985, -2.2615668773651123, -1.8556532859802246, -2.4842538833618164, -2.3384106159210205], [-1.9859262704849243, -1.6575560569763184, -2.2854154109954834, -1.9267034530639648, -2.5214226245880127, -2.0166244506835938, -2.479127883911133, -2.0595011711120605], [-2.0371243953704834, -2.2420313358306885, -2.0946967601776123, -2.2463889122009277, -1.8954271078109741, -1.942257285118103, -2.0445871353149414, -2.1946396827697754], [-2.0210611820220947, -2.362877130508423, -1.9862446784973145, -1.8275481462478638, -2.140009880065918, -1.869648814201355, -2.6818318367004395, -2.0021097660064697]], [[-1.986312985420227, -10002.50390625, -10002.0361328125, -1.908732295036316, -2.21740984916687, -10002.1318359375, -10002.1044921875, -1.87873113155365], [-1.9292036294937134, -2.163956880569458, -2.3703503608703613, -1.939669132232666, -1.8776776790618896, -2.4469380378723145, -2.423905611038208, -1.7453217506408691], [-2.0289347171783447, -2.520860195159912, -2.5013701915740967, -2.078547477722168, -1.9699862003326416, -1.8206181526184082, -1.7796630859375, -2.1984922885894775], [-1.8523262739181519, -1.978093147277832, -2.558772087097168, -2.498471260070801, -1.9756053686141968, -1.8080697059631348, -1.9115748405456543, -2.357147216796875], [-2.314960479736328, -2.2433876991271973, -1.6113512516021729, -2.19716477394104, -1.78402578830719, -2.343987226486206, -2.3425848484039307, -2.084155797958374], [-2.002289056777954, -2.2630276679992676, -1.887984275817871, -2.044983386993408, -2.217646360397339, -1.9103771448135376, -2.154231548309326, -2.2321436405181885], [-2.199540853500366, -2.063075065612793, -1.813851237297058, -2.3199379444122314, -1.7984188795089722, -2.4952447414398193, -2.4516515731811523, -1.7922154664993286], [-2.509786367416382, -1.79443359375, -1.8561275005340576, -2.2977330684661865, -2.2080044746398926, -1.7294546365737915, -2.4617154598236084, -2.0944302082061768], [-2.491340160369873, -2.403804063796997, -1.8452543020248413, -1.6882175207138062, -2.5513625144958496, -2.294516086578369, -1.9522627592086792, -1.8124374151229858], [-2.1524035930633545, -2.2049806118011475, -2.3353655338287354, -2.317572832107544, -2.2914233207702637, -1.8211665153503418, -1.69517982006073, -2.0270023345947266]]], "bmes_scores": [-2.0332, -6.1623, -1.7932, -16.7561], "bmes_path": [[3], [7, 3, 4, 6], [3], [3, 4, 5, 6, 7, 3, 4, 5, 6, 7]], "bmes_trans_m": [[0.47934335470199585, -0.2151593416929245, -0.12467780709266663, -0.44244644045829773, 0.16480575501918793, -0.006573359947651625, -1.187401294708252, -0.17424514889717102], [-0.03494556248188019, -0.8173441290855408, -0.2682552933692932, 0.18933893740177155, 0.2203899323940277, 0.3905894160270691, -0.007638207171112299, 0.19527725875377655], [-0.2779119908809662, -0.37053248286247253, 0.34394705295562744, -0.26433902978897095, -0.0001995275670196861, -0.39156094193458557, -0.035449881106615067, 0.02454843744635582], [-0.01391045656055212, 0.3419516384601593, -0.48559853434562683, -0.5893992781639099, 0.9119477272033691, 0.1731061041355133, -0.15039317309856415, 0.1523006409406662], [0.4866299033164978, 0.28264448046684265, -0.25895795226097107, 0.0404033362865448, -0.060920555144548416, 0.12364576756954193, 0.1294233351945877, 0.2434755265712738], [-0.04159824922680855, 0.25353407859802246, 0.12913571298122406, -0.036356933414936066, -0.18522876501083374, -0.5329958200454712, 0.2505933344364166, 0.26512718200683594], [-0.2509276270866394, 0.3572998046875, 0.01873799040913582, -0.30620086193084717, -0.09893298894166946, -0.37399813532829285, -0.6530448198318481, -0.17514197528362274], [-0.29702028632164, 0.680363118648529, -0.6010262370109558, 0.17669369280338287, 0.45010149478912354, -0.1026386097073555, 0.34120017290115356, -0.04910941794514656]], "bmes_seq_lens": [1, 4, 1, 10]} \ No newline at end of file diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 85173669..55548a41 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -132,68 +132,150 @@ class TestCRF(unittest.TestCase): self.assertSetEqual(expected_res, set( allowed_transitions(vocab, include_start_end=True))) + # def test_case2(self): + # # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 + # pass + # import torch + # from fastNLP import seq_len_to_mask + # + # labels = ['O'] + # for label in ['X', 'Y']: + # for tag in 'BI': + # labels.append('{}-{}'.format(tag, label)) + # id2label = {idx: label for idx, label in enumerate(labels)} + # num_tags = len(id2label) + # max_len = 10 + # batch_size = 4 + # bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() + # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions + # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), + # include_start_end_transitions=False) + # bio_trans_m = allen_CRF.transitions + # bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) + # bio_seq_lens[0] = 1 + # bio_seq_lens[-1] = max_len + # mask = seq_len_to_mask(bio_seq_lens) + # allen_res = allen_CRF.viterbi_tags(bio_logits, mask) + # + # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + # include_start_end=True)) + # fast_CRF.trans_m = bio_trans_m + # fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) + # bio_scores = [round(score, 4) for _, score in allen_res] + # # score equal + # self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # # seq equal + # bio_path = [_ for _, score in allen_res] + # self.assertListEqual(bio_path, fast_res[0]) + # + # labels = [] + # for label in ['X', 'Y']: + # for tag in 'BMES': + # labels.append('{}-{}'.format(tag, label)) + # id2label = {idx: label for idx, label in enumerate(labels)} + # num_tags = len(id2label) + # + # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions + # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), + # include_start_end_transitions=False) + # bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() + # bmes_trans_m = allen_CRF.transitions + # bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) + # bmes_seq_lens[0] = 1 + # bmes_seq_lens[-1] = max_len + # mask = seq_len_to_mask(bmes_seq_lens) + # allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) + # + # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + # encoding_type='BMES', + # include_start_end=True)) + # fast_CRF.trans_m = bmes_trans_m + # fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) + # # score equal + # bmes_scores = [round(score, 4) for _, score in allen_res] + # self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # # seq equal + # bmes_path = [_ for _, score in allen_res] + # self.assertListEqual(bmes_path, fast_res[0]) + # + # data = { + # 'bio_logits': bio_logits.tolist(), + # 'bio_scores': bio_scores, + # 'bio_path': bio_path, + # 'bio_trans_m': bio_trans_m.tolist(), + # 'bio_seq_lens': bio_seq_lens.tolist(), + # 'bmes_logits': bmes_logits.tolist(), + # 'bmes_scores': bmes_scores, + # 'bmes_path': bmes_path, + # 'bmes_trans_m': bmes_trans_m.tolist(), + # 'bmes_seq_lens': bmes_seq_lens.tolist(), + # } + # + # with open('weights.json', 'w') as f: + # import json + # json.dump(data, f) def test_case2(self): - # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 - pass - # import torch - # from fastNLP.modules.decoder.crf import seq_len_to_byte_mask - # - # labels = ['O'] - # for label in ['X', 'Y']: - # for tag in 'BI': - # labels.append('{}-{}'.format(tag, label)) - # id2label = {idx: label for idx, label in enumerate(labels)} - # num_tags = len(id2label) - # - # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions - # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), - # include_start_end_transitions=False) - # batch_size = 3 - # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() - # trans_m = allen_CRF.transitions - # seq_lens = torch.randint(1, 20, size=(batch_size,)) - # seq_lens[-1] = 20 - # mask = seq_len_to_byte_mask(seq_lens) - # allen_res = allen_CRF.viterbi_tags(logits, mask) - # - # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions - # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) - # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) - # # score equal - # self.assertListEqual([score for _, score in allen_res], fast_res[1]) - # # seq equal - # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) - # - # - # labels = [] - # for label in ['X', 'Y']: - # for tag in 'BMES': - # labels.append('{}-{}'.format(tag, label)) - # id2label = {idx: label for idx, label in enumerate(labels)} - # num_tags = len(id2label) - # - # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions - # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), - # include_start_end_transitions=False) - # batch_size = 3 - # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() - # trans_m = allen_CRF.transitions - # seq_lens = torch.randint(1, 20, size=(batch_size,)) - # seq_lens[-1] = 20 - # mask = seq_len_to_byte_mask(seq_lens) - # allen_res = allen_CRF.viterbi_tags(logits, mask) - # - # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions - # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, - # encoding_type='BMES')) - # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) - # # score equal - # self.assertListEqual([score for _, score in allen_res], fast_res[1]) - # # seq equal - # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) + # 测试CRF是否正常work。 + import json + import torch + from fastNLP import seq_len_to_mask + + with open('test/data_for_tests/modules/decoder/crf.json', 'r') as f: + data = json.load(f) + + bio_logits = torch.FloatTensor(data['bio_logits']) + bio_scores = data['bio_scores'] + bio_path = data['bio_path'] + bio_trans_m = torch.FloatTensor(data['bio_trans_m']) + bio_seq_lens = torch.LongTensor(data['bio_seq_lens']) + + bmes_logits = torch.FloatTensor(data['bmes_logits']) + bmes_scores = data['bmes_scores'] + bmes_path = data['bmes_path'] + bmes_trans_m = torch.FloatTensor(data['bmes_trans_m']) + bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens']) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + num_tags = len(id2label) + + mask = seq_len_to_mask(bio_seq_lens) + + from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + include_start_end=True)) + fast_CRF.trans_m.data = bio_trans_m + fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) + # score equal + self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # seq equal + self.assertListEqual(bio_path, fast_res[0]) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + num_tags = len(id2label) + + mask = seq_len_to_mask(bmes_seq_lens) + + from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions + fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, + encoding_type='BMES', + include_start_end=True)) + fast_CRF.trans_m.data = bmes_trans_m + fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) + # score equal + self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) + # seq equal + self.assertListEqual(bmes_path, fast_res[0]) def test_case3(self): # 测试crf的loss不会出现负数 From 850561728b40fede38713952c5f06e605bc19146 Mon Sep 17 00:00:00 2001 From: willqvq Date: Fri, 6 Nov 2020 17:46:53 +0800 Subject: [PATCH 2/2] update the .Jenkinsfile --- .Jenkinsfile | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.Jenkinsfile b/.Jenkinsfile index b7dd29a0..7c0a64fd 100644 --- a/.Jenkinsfile +++ b/.Jenkinsfile @@ -36,10 +36,13 @@ pipeline { } } post { - always { - sh 'post' + failure { + sh 'post 1' + } + success { + sh 'post 0' + sh 'post github' } - } } \ No newline at end of file