"""Abductive Learning on the Handwritten Formula (HWF) example.

Defines two knowledge bases that evaluate arithmetic formulas made of the
digits 1-9 and the operators ``+ - * /``, and a command-line entry point that
wires together the learning part, the reasoning part and the bridge.
"""

import argparse
import os.path as osp

import numpy as np
import torch
from torch import nn

from abl.bridge import SimpleBridge
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.learning import ABLModel, BasicNN
from abl.reasoning import GroundKB, KBBase, Reasoner
from abl.utils import ABLLogger, print_log

from datasets import get_dataset
from models.nn import SymbolNet

# Symbols allowed at even (operand) and odd (operator) positions of a formula.
_DIGITS = ("1", "2", "3", "4", "5", "6", "7", "8", "9")
_OPERATORS = ("+", "-", "*", "/")


def _default_pseudo_labels():
    """Return a fresh list of all pseudo-labels (digits followed by operators)."""
    return list(_DIGITS + _OPERATORS)


def _is_valid_formula(formula):
    """Return True if ``formula`` alternates digit, operator, ..., digit.

    A valid formula has odd length, a digit at every even index and an
    operator at every odd index. An empty formula is invalid.
    """
    if len(formula) % 2 == 0:
        return False
    for index, symbol in enumerate(formula):
        allowed = _DIGITS if index % 2 == 0 else _OPERATORS
        if symbol not in allowed:
            return False
    return True


def _evaluate_formula(formula):
    """Evaluate ``formula``; return ``np.inf`` when it is not well-formed."""
    if not _is_valid_formula(formula):
        return np.inf
    # NOTE: eval() is only reached after every symbol has been checked against
    # the digit/operator whitelist above, so the evaluated string can contain
    # nothing but single digits and arithmetic operators.
    return eval("".join(formula))


class HwfKB(KBBase):
    """Knowledge base for HWF that deduces the value of a formula on demand."""

    def __init__(self, pseudo_label_list=None, max_err=1e-10):
        """Initialise the knowledge base.

        Args:
            pseudo_label_list: Symbols the knowledge base works with. Defaults
                to the digits "1".."9" followed by "+", "-", "*", "/".
            max_err: Tolerance forwarded to ``KBBase``.
        """
        # Build the default per instance so no list is shared between objects.
        if pseudo_label_list is None:
            pseudo_label_list = _default_pseudo_labels()
        super().__init__(pseudo_label_list, max_err)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` is a well-formed digit/operator sequence."""
        return _is_valid_formula(formula)

    # Implement the deduction function
    def logic_forward(self, formula):
        """Return the value of ``formula``, or ``np.inf`` if it is malformed."""
        return _evaluate_formula(formula)


class HwfGroundKB(GroundKB):
    """Ground-knowledge-base variant of :class:`HwfKB` for HWF."""

    def __init__(self, pseudo_label_list=None, GKB_len_list=None, max_err=1e-10):
        """Initialise the ground knowledge base.

        Args:
            pseudo_label_list: Symbols the knowledge base works with. Defaults
                to the digits "1".."9" followed by "+", "-", "*", "/".
            GKB_len_list: Formula lengths forwarded to ``GroundKB``. Defaults
                to ``[1, 3, 5, 7]``.
            max_err: Tolerance forwarded to ``GroundKB``.
        """
        # Build the defaults per instance so no list is shared between objects.
        if pseudo_label_list is None:
            pseudo_label_list = _default_pseudo_labels()
        if GKB_len_list is None:
            GKB_len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, GKB_len_list, max_err)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` is a well-formed digit/operator sequence."""
        return _is_valid_formula(formula)

    # Implement the deduction function
    def logic_forward(self, formula):
        """Return the value of ``formula``, or ``np.inf`` if it is malformed."""
        return _evaluate_formula(formula)


def main():
    """Parse command-line options, then train and test ABL on the HWF data."""
    parser = argparse.ArgumentParser(description="Handwritten Formula example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="number of epochs in each learning loop iteration (default : 3)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=128, help="base model batch size (default : 128)"
    )
    parser.add_argument(
        "--loops", type=int, default=5, help="number of loop iterations (default : 5)"
    )
    parser.add_argument(
        "--segment_size", type=int, default=1000, help="segment size (default : 1000)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    parser.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    parser.add_argument(
        "--max-err",
        type=float,
        default=1e-10,
        help="max tolerance during abductive reasoning (default : 1e-10)",
    )

    args = parser.parse_args()

    ### Working with Data
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    ### Building the Learning Part
    # Build necessary components for BasicNN
    cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build BasicNN
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    ### Building the Reasoning Part
    # Build knowledge base; forward --max-err so the option actually takes effect.
    kb_class = HwfGroundKB if args.ground else HwfKB
    kb = kb_class(max_err=args.max_err)

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    ### Building Evaluation Metrics
    metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

    ### Bridge Learning and Reasoning
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Announce the run on the current logger
    print_log("Abductive Learning on the HWF example.", logger="current")

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()