"""Command-line entry point for the Handwritten Equation Decipherment (HED) example."""

import argparse
import os.path as osp

import torch
import torch.nn as nn

from ablkit.learning import ABLModel, BasicNN
from ablkit.utils import ABLLogger, print_log

from bridge import HedBridge
from consistency_metric import ConsistencyMetric
from datasets import get_dataset, split_equation
from models.nn import SymbolNet
from reasoning import HedKB, HedReasoner

# Command-line options as (flags, add_argument keyword arguments), listed in
# the order they are registered on the parser.
_CLI_OPTIONS = (
    (
        ("--no-cuda",),
        {"action": "store_true", "default": False, "help": "disables CUDA training"},
    ),
    (
        ("--epochs",),
        {
            "type": int,
            "default": 1,
            "help": "number of epochs in each learning loop iteration (default : 1)",
        },
    ),
    (
        ("--lr",),
        {"type": float, "default": 1e-3, "help": "base model learning rate (default : 0.001)"},
    ),
    (
        ("--weight-decay",),
        {"type": float, "default": 1e-4, "help": "weight decay (default : 0.0001)"},
    ),
    (
        ("--batch-size",),
        {"type": int, "default": 32, "help": "base model batch size (default : 32)"},
    ),
    (
        ("--segment_size",),
        {"type": int, "default": 1000, "help": "segment size (default : 1000)"},
    ),
    (
        ("--save_interval",),
        {"type": int, "default": 1, "help": "save interval (default : 1)"},
    ),
    (
        ("--max-revision",),
        {
            "type": int,
            "default": 10,
            "help": "maximum revision in reasoner (default : 10)",
        },
    ),
)


def _build_parser():
    """Return the argument parser carrying every option of the HED example."""
    cli = argparse.ArgumentParser(description="Handwritten Equation Decipherment example")
    for flags, options in _CLI_OPTIONS:
        cli.add_argument(*flags, **options)
    return cli


def main():
    """Run the HED example end to end.

    Parses the command line, loads the datasets, assembles the learning part
    (SymbolNet wrapped in BasicNN and ABLModel), the reasoning part (HedKB and
    HedReasoner) and the evaluation metric, then pretrains, trains and tests
    through a HedBridge. Model weights go to a "weights" directory under the
    current logger's log directory.

    NOTE(review): ``--segment_size`` and ``--save_interval`` are parsed but
    never read in this script -- confirm whether they were meant to be
    forwarded to ``HedBridge.train``.
    """
    args = _build_parser().parse_args()

    # Announce the run through the current logger.
    print_log("Abductive Learning on the HED example.", logger="current")

    ### Data
    print_log("Working with Data.", logger="current")
    full_train_set = get_dataset(train=True)
    # presumably a 3:1 train/validation split -- confirm against split_equation
    train_set, val_set = split_equation(full_train_set, 3, 1)
    test_set = get_dataset(train=False)

    ### Learning part
    print_log("Building the Learning Part.", logger="current")
    symbol_net = SymbolNet(num_classes=4)
    criterion = nn.CrossEntropyLoss()
    rmsprop = torch.optim.RMSprop(
        symbol_net.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # CUDA is used only when it was not disabled on the command line and
    # torch reports it as available.
    if not args.no_cuda and torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    target_device = torch.device(device_name)

    # BasicNN wraps the raw network; ABLModel wraps the BasicNN.
    abl_model = ABLModel(
        BasicNN(
            symbol_net,
            criterion,
            rmsprop,
            device=target_device,
            batch_size=args.batch_size,
            num_epochs=args.epochs,
            stop_loss=None,
        )
    )

    ### Reasoning part
    print_log("Building the Reasoning Part.", logger="current")
    knowledge_base = HedKB()
    hed_reasoner = HedReasoner(
        knowledge_base,
        dist_func="hamming",
        use_zoopt=True,
        max_revision=args.max_revision,
    )

    ### Evaluation metrics
    print_log("Building Evaluation Metrics.", logger="current")
    metrics = [ConsistencyMetric(kb=knowledge_base)]

    ### Bridge learning and reasoning
    print_log("Bridge Learning and Reasoning.", logger="current")
    hed_bridge = HedBridge(abl_model, hed_reasoner, metrics)

    # Weights are stored in a "weights" directory under the logger's log directory.
    weights_path = osp.join(ABLLogger.get_current_instance().log_dir, "weights")

    hed_bridge.pretrain(weights_path)
    hed_bridge.train(train_set, val_set, save_dir=weights_path)
    hed_bridge.test(test_set)


if __name__ == "__main__":
    main()