|
|
@@ -311,9 +311,17 @@ class TestLosserError(unittest.TestCase): |
|
|
|
print(los(pred_dict=pred_dict, target_dict=target_dict)) |
|
|
|
|
|
|
|
# |
|
|
|
def test_AccuracyMetric2(self): |
|
|
|
def test_losser2(self): |
|
|
|
# (2) with corrupted size |
|
|
|
pred_dict = {"pred": torch.zeros(16, 3, 4)} |
|
|
|
pred_dict = {"pred": torch.zeros(16, 3)} |
|
|
|
target_dict = {'target': torch.zeros(16, 3).long()} |
|
|
|
los = loss.CrossEntropyLoss() |
|
|
|
|
|
|
|
print(los(pred_dict=pred_dict, target_dict=target_dict)) |
|
|
|
|
|
|
|
def test_losser3(self): |
|
|
|
# (2) with corrupted size |
|
|
|
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0} |
|
|
|
target_dict = {'target': torch.zeros(16, 3).long()} |
|
|
|
los = loss.CrossEntropyLoss() |
|
|
|
|
|
|
|