@@ -14,7 +14,7 @@ | |||
</div> | |||
# Abductive Learning (ABL) Kit | |||
# ABL Kit: A Python Toolkit for Abductive Learning | |||
**ABL Kit** is an efficient Python toolkit for **Abductive Learning (ABL)**. | |||
ABL is a novel paradigm that integrates machine learning and | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
This module contains the base class for the Bridge part. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, List, Optional, Tuple, Union | |||
@@ -31,11 +37,11 @@ class BaseBridge(metaclass=ABCMeta): | |||
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: | |||
if not isinstance(model, ABLModel): | |||
raise TypeError( | |||
"Expected an instance of ABLModel, but received type: {}".format(type(model)) | |||
f"Expected an instance of ABLModel, but received type: {type(model)}" | |||
) | |||
if not isinstance(reasoner, Reasoner): | |||
raise TypeError( | |||
"Expected an instance of Reasoner, but received type: {}".format(type(reasoner)) | |||
f"Expected an instance of Reasoner, but received type: {type(reasoner)}" | |||
) | |||
self.model = model | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
This module contains a simple implementation of the Bridge part. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import os.path as osp | |||
from typing import Any, List, Optional, Tuple, Union | |||
@@ -221,7 +227,7 @@ class SimpleBridge(BaseBridge): | |||
Labeled data should be in the same format as ``train_data``. The only difference is | |||
that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be | |||
utilized to train the model. Defaults to None. | |||
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 | |||
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long | |||
Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` | |||
and ``Y`` can be either None or not, which depends on the evaluation metircs in | |||
``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate | |||
@@ -327,10 +333,11 @@ class SimpleBridge(BaseBridge): | |||
Parameters | |||
---------- | |||
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be | |||
either None or not, which depends on the evaluation metircs in ``self.metric_list``. | |||
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long | |||
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` | |||
object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` | |||
and ``Y`` can be either None or not, which depends on the evaluation metircs in | |||
``self.metric_list``. | |||
""" | |||
val_data_examples = self.data_preprocess("val", val_data) | |||
self._valid(val_data_examples) | |||
@@ -346,10 +353,11 @@ class SimpleBridge(BaseBridge): | |||
Parameters | |||
---------- | |||
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 | |||
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long | |||
Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object | |||
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` | |||
can be either None or not, which depends on the evaluation metircs in ``self.metric_list``. | |||
can be either None or not, which depends on the evaluation metircs in | |||
``self.metric_list``. | |||
""" | |||
print_log("Test start:", logger="current") | |||
test_data_examples = self.data_preprocess("test", test_data) | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
This module contains the base class used for evaluation. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import logging | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, List, Optional | |||
@@ -1,3 +1,10 @@ | |||
""" | |||
This module contains the ReasoningMetric, which is used for evaluating the model performance | |||
on tasks that need reasoning. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from typing import Optional | |||
from ...reasoning import KBBase | |||
@@ -7,7 +14,7 @@ from .base_metric import BaseMetric | |||
class ReasoningMetric(BaseMetric): | |||
""" | |||
A metrics class for evaluating the model performance on tasks need reasoning. | |||
A metrics class for evaluating the model performance on tasks that need reasoning. | |||
This class is designed to calculate the accuracy of the reasoing results. Reasoning | |||
results are generated by first using the learning part to predict pseudo-labels | |||
@@ -34,6 +41,7 @@ class ReasoningMetric(BaseMetric): | |||
super().__init__(prefix) | |||
self.kb = kb | |||
# pylint: disable=protected-access | |||
def process(self, data_examples: ListData) -> None: | |||
""" | |||
Process a batch of data examples. | |||
@@ -1,4 +1,8 @@ | |||
from typing import Optional | |||
""" | |||
This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import numpy as np | |||
@@ -20,9 +24,6 @@ class SymbolAccuracy(BaseMetric): | |||
metrics of different tasks. Inherits from BaseMetric. Default to None. | |||
""" | |||
def __init__(self, prefix: Optional[str] = None) -> None: | |||
super().__init__(prefix) | |||
def process(self, data_examples: ListData) -> None: | |||
""" | |||
Processes a batch of data examples. | |||
@@ -1,6 +1,8 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
# Modified from | |||
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py | |||
""" | |||
Copyright (c) OpenMMLab. All rights reserved. | |||
Modified from | |||
https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa: E501 pylint: disable=line-too-long | |||
""" | |||
import copy | |||
from typing import Any, Iterator, Optional, Tuple, Type, Union | |||
@@ -1,6 +1,8 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
# Modified from | |||
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |||
""" | |||
Copyright (c) OpenMMLab. All rights reserved. | |||
Modified from | |||
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long | |||
""" | |||
from typing import List, Union | |||
@@ -54,7 +56,7 @@ class ListData(BaseDataElement): | |||
``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``. | |||
This design is inspired by and extends the functionalities of the ``BaseDataElement`` | |||
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 | |||
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 pylint: disable=line-too-long | |||
Examples: | |||
>>> from ablkit.data.structures import ListData | |||
@@ -72,7 +74,7 @@ class ListData(BaseDataElement): | |||
DATA FIELDS | |||
Y: [1, 2, 3] | |||
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] | |||
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 | |||
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long | |||
) at 0x7f3bbf1991c0> | |||
>>> print(data_examples[:1]) | |||
<ListData( | |||
@@ -1,3 +1,10 @@ | |||
""" | |||
This module contains the class ABLModel, which provides a unified interface for different | |||
machine learning models. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import pickle | |||
from typing import Any, Dict | |||
@@ -99,21 +106,20 @@ class ABLModel: | |||
method = getattr(model, operation) | |||
method(*args, **kwargs) | |||
else: | |||
if f"{operation}_path" not in kwargs.keys(): | |||
if f"{operation}_path" not in kwargs: | |||
raise ValueError(f"'{operation}_path' should not be None") | |||
else: | |||
try: | |||
if operation == "save": | |||
with open(kwargs["save_path"], "wb") as file: | |||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
elif operation == "load": | |||
with open(kwargs["load_path"], "rb") as file: | |||
self.base_model = pickle.load(file) | |||
except (OSError, pickle.PickleError): | |||
raise NotImplementedError( | |||
f"{type(model).__name__} object doesn't have the {operation} method \ | |||
and the default pickle-based {operation} method failed." | |||
) | |||
try: | |||
if operation == "save": | |||
with open(kwargs["save_path"], "wb") as file: | |||
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
elif operation == "load": | |||
with open(kwargs["load_path"], "rb") as file: | |||
self.base_model = pickle.load(file) | |||
except (OSError, pickle.PickleError) as exc: | |||
raise NotImplementedError( | |||
f"{type(model).__name__} object doesn't have the {operation} method \ | |||
and the default pickle-based {operation} method failed." | |||
) from exc | |||
def save(self, *args, **kwargs) -> None: | |||
""" | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
This module contains the class BasicNN, which servers as a wrapper for PyTorch NN models. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from __future__ import annotations | |||
import logging | |||
@@ -474,7 +480,7 @@ class BasicNN: | |||
raise ValueError("X should not be None.") | |||
if y is None: | |||
y = [0] * len(X) | |||
if not (len(y) == len(X)): | |||
if not len(y) == len(X): | |||
raise ValueError("X and y should have equal length.") | |||
dataset = ClassificationDataset(X, y, transform=self.train_transform) | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
Implementation of PyTorch dataset class used for classification. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from typing import Any, Callable, List, Tuple, Optional | |||
import torch | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
Implementation of PyTorch dataset class used for Prediction. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from typing import Any, Callable, List, Tuple, Optional | |||
import torch | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
Implementation of PyTorch dataset class used for regression. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from typing import Any, List, Tuple | |||
from torch.utils.data import Dataset | |||
@@ -1,3 +1,10 @@ | |||
""" | |||
This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers | |||
for different kinds of knowledge bases. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import bisect | |||
import inspect | |||
import logging | |||
@@ -394,7 +401,7 @@ class GroundKB(KBBase): | |||
base. The second element is a list of reasoning results corresponding to each | |||
candidate, i.e., the outcome of the ``logic_forward`` function. | |||
""" | |||
if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list: | |||
if not self.GKB or len(pseudo_label) not in self.GKB_len_list: | |||
return [], [] | |||
all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y) | |||
@@ -478,7 +485,7 @@ class PrologKB(KBBase): | |||
super().__init__(pseudo_label_list) | |||
try: | |||
import pyswip | |||
import pyswip # pylint: disable=import-outside-toplevel | |||
except (IndexError, ImportError): | |||
print( | |||
"A Prolog-based knowledge base is in use. Please install SWI-Prolog using the" | |||
@@ -493,7 +500,7 @@ class PrologKB(KBBase): | |||
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.") | |||
self.prolog.consult(self.pl_file) | |||
def logic_forward(self, pseudo_label: List[Any]) -> Any: | |||
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: | |||
""" | |||
Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the | |||
returned ``Res`` as the reasoning results. To use this default function, there must be | |||
@@ -504,11 +511,15 @@ class PrologKB(KBBase): | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo-labels of an example. | |||
x : List[Any] | |||
The corresponding input example. If the information from the input | |||
is not required in the reasoning process, then this parameter will not have | |||
any effect. | |||
""" | |||
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] | |||
result = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))[0]["Res"] | |||
if result == "true": | |||
return True | |||
elif result == "false": | |||
if result == "false": | |||
return False | |||
return result | |||
@@ -517,7 +528,7 @@ class PrologKB(KBBase): | |||
pseudo_label: List[Any], | |||
revision_idx: List[int], | |||
) -> List[Any]: | |||
import re | |||
import re # pylint: disable=import-outside-toplevel | |||
revision_pseudo_label = pseudo_label.copy() | |||
revision_pseudo_label = flatten(revision_pseudo_label) | |||
@@ -533,7 +544,7 @@ class PrologKB(KBBase): | |||
self, | |||
pseudo_label: List[Any], | |||
y: Any, | |||
x: List[Any], | |||
x: List[Any], # pylint: disable=unused-argument | |||
revision_idx: List[int], | |||
) -> str: | |||
""" | |||
@@ -563,7 +574,7 @@ class PrologKB(KBBase): | |||
query_string = "logic_forward(" | |||
query_string += self._revision_pseudo_label(pseudo_label, revision_idx) | |||
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None) | |||
query_string += ",%s)." % y if not key_is_none_flag else ")." | |||
query_string += f",{y})." if not key_is_none_flag else ")." | |||
return query_string | |||
def revise_at_idx( | |||
@@ -1,3 +1,10 @@ | |||
""" | |||
This module contains the class Reasoner, which is used for minimizing the inconsistency | |||
between the knowledge base and learning models. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
import inspect | |||
from typing import Any, Callable, List, Optional, Union | |||
@@ -251,25 +258,21 @@ class Reasoner: | |||
def zoopt_budget(self, symbol_num: int) -> int: | |||
""" | |||
Set the budget for ZOOpt optimization. The function, in its default implementation, | |||
returns a fixed budget value of 100. However, it can be adjusted to return other fixed | |||
values, or a dynamic budget based on the number of symbols, if desired. For example, | |||
one might choose to set the budget as 100 times ``symbol_num``. | |||
Set the budget for ZOOpt optimization. The budget can be dynamic relying on | |||
the number of symbols considered, e.g., the default implementation shown below. | |||
Alternatively, it can be a fixed value, such as simply setting it to 100. | |||
Parameters | |||
---------- | |||
symbol_num : int | |||
The number of symbols to be considered in the ZOOpt optimization process. Although this | |||
parameter can be used to compute a dynamic optimization budget, by default it is not | |||
utilized in the calculation. | |||
The number of symbols to be considered in the ZOOpt optimization process. | |||
Returns | |||
------- | |||
int | |||
The budget for ZOOpt optimization. By default, this is a fixed value of 100, | |||
irrespective of the symbol_num value. | |||
The budget for ZOOpt optimization. | |||
""" | |||
return 100 | |||
return 10 * symbol_num | |||
def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int: | |||
""" | |||
@@ -288,19 +291,18 @@ class Reasoner: | |||
if max_revision == -1: | |||
return symbol_num | |||
elif isinstance(max_revision, float): | |||
if not (0 <= max_revision <= 1): | |||
if isinstance(max_revision, float): | |||
if not 0 <= max_revision <= 1: | |||
raise ValueError( | |||
"If max_revision is a float, it must be between 0 and 1, " | |||
+ f"but got {max_revision}" | |||
) | |||
return round(symbol_num * max_revision) | |||
else: | |||
if max_revision < 0: | |||
raise ValueError( | |||
f"If max_revision is an int, it must be non-negative, but got {max_revision}" | |||
) | |||
return max_revision | |||
if max_revision < 0: | |||
raise ValueError( | |||
f"If max_revision is an int, it must be non-negative, but got {max_revision}" | |||
) | |||
return max_revision | |||
def abduce(self, data_example: ListData) -> List[Any]: | |||
""" | |||
@@ -1,14 +1,15 @@ | |||
# Python module wrapper for _functools C module | |||
# to allow utilities written in Python to be added | |||
# to the functools module. | |||
# Written by Nick Coghlan <ncoghlan at gmail.com>, | |||
# Raymond Hettinger <python at rcn.com>, | |||
# and Łukasz Langa <lukasz at langa.pl>. | |||
# Copyright (C) 2006-2013 Python Software Foundation. | |||
# See C source code for _functools credits/copyright | |||
# Modified from | |||
# https://github.com/python/cpython/blob/3.12/Lib/functools.py | |||
""" | |||
Python module wrapper for _functools C module | |||
to allow utilities written in Python to be added | |||
to the functools module. | |||
Written by Nick Coghlan <ncoghlan at gmail.com>, | |||
Raymond Hettinger <python at rcn.com>, | |||
and Łukasz Langa <lukasz at langa.pl>. | |||
Copyright (C) 2006-2013 Python Software Foundation. | |||
See C source code for _functools credits/copyright | |||
Modified from | |||
https://github.com/python/cpython/blob/3.12/Lib/functools.py | |||
""" | |||
from typing import Callable, Generic, TypeVar | |||
@@ -18,30 +19,58 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields | |||
class Cache(Generic[K, T]): | |||
def __init__(self, func: Callable[[K], T]): | |||
"""Create cache | |||
""" | |||
A generic caching mechanism that stores the results of a function call and | |||
retrieves them to avoid repeated calculations. | |||
:param func: Function this cache evaluates | |||
:param cache: If true, do in memory caching. | |||
:param cache_root: If not None, cache to files at the provided path. | |||
:param key_func: Convert the key into a hashable object if needed | |||
""" | |||
This class implements a dictionary-based cache with a circular doubly linked | |||
list to manage the cache entries efficiently. It is designed to be generic, | |||
allowing for caching of any callable function. | |||
Parameters | |||
---------- | |||
func : Callable[[K], T] | |||
The function to be cached. This function takes an argument of type K and | |||
returns a value of type T. | |||
""" | |||
def __init__(self, func: Callable[[K], T]): | |||
self.func = func | |||
self.has_init = False | |||
self.cache = False | |||
self.cache_dict = {} | |||
self.key_func = None | |||
self.max_size = 0 | |||
self.hits, self.misses = 0, 0 | |||
self.full = False | |||
self.root = [] # root of the circular doubly linked list | |||
self.root[:] = [self.root, self.root, None, None] | |||
def __getitem__(self, obj, *args) -> T: | |||
return self.get_from_dict(obj, *args) | |||
def clear_cache(self): | |||
"""Invalidate entire cache.""" | |||
""" | |||
Invalidate the entire cache. | |||
""" | |||
self.cache_dict.clear() | |||
def _init_cache(self, obj): | |||
def init_cache(self, obj): | |||
""" | |||
Initialize the cache settings. | |||
Parameters | |||
---------- | |||
obj : Any | |||
The object containing settings for cache initialization. | |||
""" | |||
if self.has_init: | |||
return | |||
self.cache = True | |||
self.cache_dict = dict() | |||
self.cache_dict = {} | |||
self.key_func = obj.key_func | |||
self.max_size = obj.cache_size | |||
@@ -53,9 +82,23 @@ class Cache(Generic[K, T]): | |||
self.has_init = True | |||
def get_from_dict(self, obj, *args) -> T: | |||
"""Implements dict based cache.""" | |||
""" | |||
Retrieve a value from the cache or compute it using ``self.func``. | |||
Parameters | |||
---------- | |||
obj : Any | |||
The object to which the cached method/function belongs. | |||
*args : Any | |||
Arguments used in key generation for cache retrieval or function computation. | |||
Returns | |||
------- | |||
T | |||
The value from the cache or computed by the function. | |||
""" | |||
# x is not used in cache key | |||
pred_pseudo_label, y, x, *res_args = args | |||
pred_pseudo_label, y, _x, *res_args = args | |||
cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args) | |||
link = self.cache_dict.get(cache_key) | |||
if link is not None: | |||
@@ -96,15 +139,23 @@ class Cache(Generic[K, T]): | |||
def abl_cache(): | |||
""" | |||
Decorator to enable caching for a function. | |||
Returns | |||
------- | |||
Callable | |||
The wrapped function with caching capability. | |||
""" | |||
def decorator(func): | |||
cache_instance = Cache(func) | |||
def wrapper(obj, *args): | |||
if obj.use_cache: | |||
cache_instance._init_cache(obj) | |||
cache_instance.init_cache(obj) | |||
return cache_instance.get_from_dict(obj, *args) | |||
else: | |||
return func(obj, *args) | |||
return func(obj, *args) | |||
return wrapper | |||
@@ -1,6 +1,8 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
# Modified from | |||
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py | |||
""" | |||
Copyright (c) OpenMMLab. All rights reserved. | |||
Modified from | |||
https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py | |||
""" | |||
import logging | |||
import os | |||
@@ -132,13 +134,13 @@ class ABLFormatter(logging.Formatter): | |||
Formatted result. | |||
""" | |||
if record.levelno == logging.ERROR: | |||
self._style._fmt = self.err_format | |||
self._style._fmt = self.err_format # pylint: disable=protected-access | |||
elif record.levelno == logging.WARNING: | |||
self._style._fmt = self.warn_format | |||
self._style._fmt = self.warn_format # pylint: disable=protected-access | |||
elif record.levelno == logging.INFO: | |||
self._style._fmt = self.info_format | |||
self._style._fmt = self.info_format # pylint: disable=protected-access | |||
elif record.levelno == logging.DEBUG: | |||
self._style._fmt = self.debug_format | |||
self._style._fmt = self.debug_format # pylint: disable=protected-access | |||
result = logging.Formatter.format(self, record) | |||
return result | |||
@@ -215,7 +217,7 @@ class ABLLogger(Logger, ManagerMixin): | |||
self.handlers.append(stream_handler) | |||
if log_file is None: | |||
import time | |||
import time # pylint: disable=import-outside-toplevel | |||
local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) | |||
@@ -234,10 +236,20 @@ class ABLLogger(Logger, ManagerMixin): | |||
@property | |||
def log_file(self): | |||
"""Get the file path of the log. | |||
Returns: | |||
str: Path of the log. | |||
""" | |||
return self._log_file | |||
@property | |||
def log_dir(self): | |||
"""Get the directory where the log is stored. | |||
Returns: | |||
str: Directory where the log is stored. | |||
""" | |||
return self._log_dir | |||
@classmethod | |||
@@ -284,11 +296,11 @@ class ABLLogger(Logger, ManagerMixin): | |||
level : Union[int, str] | |||
The logging level to set. | |||
""" | |||
self.level = logging._checkLevel(level) | |||
self.level = logging._checkLevel(level) # pylint: disable=protected-access | |||
_accquire_lock() | |||
# The same logic as ``logging.Manager._clear_cache``. | |||
for logger in ABLLogger._instance_dict.values(): | |||
logger._cache.clear() | |||
logger._cache.clear() # pylint: disable=protected-access | |||
_release_lock() | |||
@@ -1,4 +1,7 @@ | |||
# Copyright (c) OpenMMLab. All rights reserved. | |||
""" | |||
Copyright (c) OpenMMLab. All rights reserved. | |||
""" | |||
import inspect | |||
import threading | |||
import warnings | |||
@@ -72,7 +75,7 @@ class ManagerMixin(metaclass=ManagerMeta): | |||
name (str): Name of the instance. Defaults to ''. | |||
""" | |||
def __init__(self, name: str = "", **kwargs): | |||
def __init__(self, name: str = ""): | |||
assert isinstance(name, str) and name, "name argument must be an non-empty string." | |||
self._instance_name = name | |||
@@ -1,3 +1,9 @@ | |||
""" | |||
Implementation of utilities used in ablkit. | |||
Copyright (c) 2024 LAMDA. All rights reserved. | |||
""" | |||
from typing import List, Any, Union, Tuple, Optional | |||
import numpy as np | |||
@@ -198,6 +204,6 @@ def tab_data_to_tuple( | |||
return None | |||
if len(X) != len(y): | |||
raise ValueError( | |||
"The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)) | |||
f"The length of X and y should be the same, but got {len(X)} and {len(y)}." | |||
) | |||
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) |