From 4cc8a29926b6b5cb34fcfe9a3054da5a75ef0500 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sat, 4 Jun 2022 08:27:42 +0000
Subject: [PATCH] 1. Add the torch mark to the element test 2. Fix int tensors
 failing to be divided by an integer on some torch versions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/modules/torch/generator/seq2seq_generator.py | 4 ++--
 tests/core/metrics/test_element_cal_element.py       | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/fastNLP/modules/torch/generator/seq2seq_generator.py b/fastNLP/modules/torch/generator/seq2seq_generator.py
index b54eea28..1d9a0b65 100755
--- a/fastNLP/modules/torch/generator/seq2seq_generator.py
+++ b/fastNLP/modules/torch/generator/seq2seq_generator.py
@@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
             _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
             next_tokens = _tokens.gather(dim=1, index=ids)  # (batch_size, 2*num_beams)
-            from_which_beam = torch.floor(ids / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
         else:
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
             next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
-            from_which_beam = torch.floor(ids / vocab_size).long()  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids.float() / vocab_size).long()  # (batch_size, 2*num_beams)
             next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)
 
             # Next, assemble the results for the next batch.
diff --git a/tests/core/metrics/test_element_cal_element.py b/tests/core/metrics/test_element_cal_element.py
index 340e2a43..9ecd4ddf 100644
--- a/tests/core/metrics/test_element_cal_element.py
+++ b/tests/core/metrics/test_element_cal_element.py
@@ -27,6 +27,7 @@ class MyMetric(Metric):
 
 class TestElemnt:
 
+    @pytest.mark.torch
     def test_case_v1(self):
         pred = torch.tensor([1, 1, 1, 1])
         metric = MyMetric()
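
For context, a minimal sketch of the incompatibility the seq2seq_generator change works around. On some torch 1.x releases, the `/` operator between an integer tensor and a Python int warns about or refuses integer division rather than returning a float tensor, so the old torch.floor(ids / vocab_size) pattern is not portable across versions. The values of ids and vocab_size below are illustrative stand-ins for the real tensors in _beam_search_generate:

    import torch

    ids = torch.tensor([[3, 7, 11]])  # Long tensor of flat top-k indices
    vocab_size = 4

    # Old pattern: on some torch 1.x releases, `/` between two integer
    # operands raises a RuntimeError (or a deprecation warning) instead of
    # returning a float tensor, so torch.floor(ids / vocab_size) is not
    # portable across versions.

    # Patched pattern: cast to float first, then floor and cast back; this
    # computes the same integer quotient on every version.
    from_which_beam = torch.floor(ids.float() / vocab_size).long()
    print(from_which_beam)  # tensor([[0, 1, 2]])

On torch >= 1.8 the same quotient can also be written as torch.div(ids, vocab_size, rounding_mode='floor'), but the float cast keeps the code working on the older releases this patch targets as well.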