Browse Source

[FIX] update test cases

pull/1/head
Tony-HYX 1 year ago
parent
commit
51017bb42d
3 changed files with 14 additions and 13 deletions
  1. +2
    -2
      tests/test_abl_model.py
  2. +4
    -5
      tests/test_basic_nn.py
  3. +8
    -6
      tests/test_reasoning.py

+ 2
- 2
tests/test_abl_model.py View File

@@ -52,8 +52,8 @@ class TestABLModel(object):
"""Test the train method of the ABLModel class."""
model = ABLModel(base_model_instance)
list_data_instance.abduced_idx = [[1, 2], [3, 4], [5, 6]]
loss = model.train(list_data_instance)
assert isinstance(loss, float), "Training should return a float value indicating the loss."
ins = model.train(list_data_instance)
assert ins == model.base_model, "Training should return the base model instance."

def test_ablmodel_save_load(self, base_model_instance, tmp_path):
"""Test the save method of the ABLModel class."""


+ 4
- 5
tests/test_basic_nn.py View File

@@ -5,7 +5,6 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class TestBasicNN(object):
@pytest.fixture
def sample_data(self):
@@ -40,12 +39,12 @@ class TestBasicNN(object):

# Test fit with direct data
X, y = sample_data
loss = basic_nn_instance.fit(X=list(X), y=list(y))
assert isinstance(loss, float)
ins = basic_nn_instance.fit(X=list(X), y=list(y))
assert ins == basic_nn_instance

# Test fit with DataLoader
loss = basic_nn_instance.fit(data_loader=sample_data_loader_with_label)
assert isinstance(loss, float)
ins = basic_nn_instance.fit(data_loader=sample_data_loader_with_label)
assert ins == basic_nn_instance

# Test invalid fit method input
with pytest.raises(ValueError):


+ 8
- 6
tests/test_reasoning.py View File

@@ -14,19 +14,21 @@ class TestKBBase(object):
def test_logic_forward(self, kb_add):
result = kb_add.logic_forward([1, 2])
assert result == 3
with pytest.raises(TypeError):
kb_add.logic_forward([1, 2], [0.1, -0.2, 0.2, -0.3])

def test_revise_at_idx(self, kb_add):
result = kb_add.revise_at_idx([0, 2], 2, [])
result = kb_add.revise_at_idx([0, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
assert result == [[0, 2]]
result = kb_add.revise_at_idx([1, 2], 2, [])
result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
assert result == []
result = kb_add.revise_at_idx([1, 2], 2, [0, 1])
result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0, 1])
assert result == [[0, 2], [1, 1], [2, 0]]

def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([0, 1], 1, max_revision_num=2, require_more_revision=0)
result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
assert result == [[0, 1]]
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2, require_more_revision=0)
result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
assert result == [[1, 0]]


@@ -42,7 +44,7 @@ class TestGroundKB(object):

def test_abduce_candidates_ground(self, kb_add_ground):
result = kb_add_ground.abduce_candidates(
[1, 2], 1, max_revision_num=2, require_more_revision=0
[1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
)
assert result == [(1, 0)]



Loading…
Cancel
Save