Browse Source

1.为element测试添加torch标签 2.解决torch版本导致的int张量无法对整数使用除法的问题

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
4cc8a29926
2 changed files with 3 additions and 2 deletions
  1. +2
    -2
      fastNLP/modules/torch/generator/seq2seq_generator.py
  2. +1
    -0
      tests/core/metrics/test_element_cal_element.py

+ 2
- 2
fastNLP/modules/torch/generator/seq2seq_generator.py View File

@@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
 next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
 _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
 next_tokens = _tokens.gather(dim=1, index=ids)  # (batch_size, 2*num_beams)
-from_which_beam = torch.floor(ids / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
+from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
 else:
 scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
 _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
 _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
 next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
-from_which_beam = torch.floor(ids / vocab_size).long()  # (batch_size, 2*num_beams)
+from_which_beam = torch.floor(ids.float() / vocab_size).long()  # (batch_size, 2*num_beams)
 next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)


# 接下来需要组装下一个batch的结果。 # 接下来需要组装下一个batch的结果。


+ 1
- 0
tests/core/metrics/test_element_cal_element.py View File

@@ -27,6 +27,7 @@ class MyMetric(Metric):


 class TestElemnt:

+    @pytest.mark.torch
     def test_case_v1(self):
         pred = torch.tensor([1, 1, 1, 1])
         metric = MyMetric()


Loading…
Cancel
Save