|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
-
- from mmdet.models import Accuracy, build_loss
-
-
def test_ce_loss():
    """Unit tests for ``CrossEntropyLoss`` config validation and values."""
    # use_mask and use_sigmoid cannot be true at the same time
    with pytest.raises(AssertionError):
        bad_cfg = dict(
            type='CrossEntropyLoss',
            use_mask=True,
            use_sigmoid=True,
            loss_weight=1.0)
        build_loss(bad_cfg)

    # shared fixture: a confidently wrong softmax prediction for class 1
    pred = torch.Tensor([[100, -100]])
    target = torch.Tensor([1]).long()

    # test loss with class weights
    weighted_ce = build_loss(
        dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=[0.8, 0.2],
            loss_weight=1.0))
    assert torch.allclose(weighted_ce(pred, target), torch.tensor(40.))

    # same prediction without class weights
    plain_ce = build_loss(
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    assert torch.allclose(plain_ce(pred, target), torch.tensor(200.))
-
-
def test_varifocal_loss():
    """Unit tests for ``VarifocalLoss`` config validation and values."""
    # only sigmoid version of VarifocalLoss is implemented
    with pytest.raises(AssertionError):
        build_loss(
            dict(type='VarifocalLoss', use_sigmoid=False, loss_weight=1.0))

    # alpha should be greater than 0
    with pytest.raises(AssertionError):
        build_loss(
            dict(
                type='VarifocalLoss',
                alpha=-0.75,
                gamma=2.0,
                use_sigmoid=True,
                loss_weight=1.0))

    vfl_cfg = dict(
        type='VarifocalLoss',
        use_sigmoid=True,
        alpha=0.75,
        gamma=2.0,
        iou_weighted=True,
        reduction='mean',
        loss_weight=1.0)

    # pred and target should be of the same size
    vfl = build_loss(vfl_cfg)
    with pytest.raises(AssertionError):
        vfl(torch.Tensor([[100.0, -100.0]]), torch.Tensor([[1.0]]))

    # a saturated correct prediction yields (numerically) zero loss
    vfl = build_loss(vfl_cfg)
    assert torch.allclose(
        vfl(torch.Tensor([[100.0, -100.0]]), torch.Tensor([[1.0, 0.0]])),
        torch.tensor(0.0))

    # element-wise weights: the zero weight masks out the lossy element
    vfl = build_loss(vfl_cfg)
    assert torch.allclose(
        vfl(
            torch.Tensor([[0.0, 100.0]]),
            torch.Tensor([[1.0, 1.0]]),
            torch.Tensor([0.0, 1.0])),
        torch.tensor(0.0))
-
-
def test_kd_loss():
    """Unit tests for ``KnowledgeDistillationKLDivLoss``."""
    # temperature T should be greater than 1
    with pytest.raises(AssertionError):
        build_loss(
            dict(
                type='KnowledgeDistillationKLDivLoss', loss_weight=1.0,
                T=0.5))

    kd_cfg = dict(type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=1)

    # pred and target should be of the same size
    kd = build_loss(kd_cfg)
    with pytest.raises(AssertionError):
        kd(torch.Tensor([[100, -100]]), torch.Tensor([1]).long())

    # matching distributions give zero divergence
    kd = build_loss(kd_cfg)
    assert torch.allclose(
        kd(torch.Tensor([[100.0, 100.0]]), torch.Tensor([[1.0, 1.0]])),
        torch.tensor(0.0))

    # per-sample weights: the zero weight masks out the first row
    kd = build_loss(kd_cfg)
    assert torch.allclose(
        kd(
            torch.Tensor([[100.0, -100.0], [100.0, 100.0]]),
            torch.Tensor([[1.0, 0.0], [1.0, 1.0]]),
            torch.Tensor([0.0, 1.0])),
        torch.tensor(0.0))
-
-
def test_seesaw_loss():
    """Unit tests for ``SeesawLoss`` config validation and values."""
    # only softmax version of Seesaw Loss is implemented
    with pytest.raises(AssertionError):
        build_loss(dict(type='SeesawLoss', use_sigmoid=True, loss_weight=1.0))

    # cls_score.size(-1) must equal num_classes + 2 (= 4 here)
    seesaw = build_loss(
        dict(type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2))
    short_label = torch.Tensor([1]).long()
    for bad_pred in (torch.Tensor([[-100, 100]]),
                     torch.Tensor([[-100, 100, -100]])):
        with pytest.raises(AssertionError):
            seesaw(bad_pred, short_label)

    # calculation with neither p nor q active
    seesaw = build_loss(
        dict(type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2))
    losses = seesaw(
        torch.Tensor([[-100, 100, -100, 100]]),
        torch.Tensor([1]).long())
    assert torch.allclose(losses['loss_cls_objectness'], torch.tensor(200.))
    assert torch.allclose(losses['loss_cls_classes'], torch.tensor(0.))

    # calculation with p but without q; cum_samples is seeded before the call
    seesaw = build_loss(
        dict(type='SeesawLoss', p=1.0, q=0.0, loss_weight=1.0, num_classes=2))
    seesaw.cum_samples[0] = torch.exp(torch.Tensor([20]))
    losses = seesaw(
        torch.Tensor([[-100, 100, -100, 100]]),
        torch.Tensor([0]).long())
    assert torch.allclose(losses['loss_cls_objectness'], torch.tensor(200.))
    assert torch.allclose(losses['loss_cls_classes'], torch.tensor(180.))

    # calculation with q but without p
    seesaw = build_loss(
        dict(type='SeesawLoss', p=0.0, q=1.0, loss_weight=1.0, num_classes=2))
    losses = seesaw(
        torch.Tensor([[-100, 100, -100, 100]]),
        torch.Tensor([0]).long())
    assert torch.allclose(losses['loss_cls_objectness'], torch.tensor(200.))
    assert torch.allclose(losses['loss_cls_classes'],
                          torch.tensor(200.) + torch.tensor(100.).log())

    # return_dict=False path plus the accuracy/activation helpers
    seesaw = build_loss(
        dict(
            type='SeesawLoss',
            p=0.0,
            q=1.0,
            loss_weight=1.0,
            num_classes=2,
            return_dict=False))
    pred = torch.Tensor([[100, -100, 100, -100]])
    label = torch.Tensor([0]).long()
    total = seesaw(pred, label)
    acc = seesaw.get_accuracy(pred, label)
    act = seesaw.get_activation(pred)
    assert torch.allclose(total, torch.tensor(0.))
    assert torch.allclose(acc['acc_objectness'], torch.tensor(100.))
    assert torch.allclose(acc['acc_classes'], torch.tensor(100.))
    assert torch.allclose(act, torch.tensor([1., 0., 0.]))
-
-
def test_accuracy():
    """Unit tests for the ``Accuracy`` metric."""
    # empty pred yields zero accuracy
    empty_pred = torch.empty(0, 4)
    empty_label = torch.empty(0)
    assert Accuracy(topk=1)(empty_pred, empty_label).item() == 0

    pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                         [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                         [0.0, 0.0, 0.99, 0]])
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()

    # top-1
    assert Accuracy(topk=1)(pred, true_label).item() == 100

    # top-1 with score thresh=0.8: only two rows pass the threshold
    assert Accuracy(topk=1, thresh=0.8)(pred, true_label).item() == 40

    # top-2
    second_best = torch.Tensor([3, 2, 0, 0, 2]).long()
    assert Accuracy(topk=2)(pred, second_best).item() == 100

    # top-1 and top-2 requested together
    for acc in Accuracy(topk=(1, 2))(pred, true_label):
        assert acc.item() == 100

    # topk larger than the number of predicted classes
    with pytest.raises(AssertionError):
        Accuracy(topk=5)(pred, true_label)

    # topk of the wrong type
    with pytest.raises(AssertionError):
        Accuracy(topk='wrong type')(pred, true_label)

    # label size larger than pred size
    with pytest.raises(AssertionError):
        Accuracy()(pred, torch.Tensor([2, 3, 0, 1, 2, 0]).long())

    # pred with an unexpected extra dimension
    with pytest.raises(AssertionError):
        Accuracy()(pred[:, :, None], true_label)