Browse Source

Merge branch 'Dev' into main

main
troyyyyy 1 year ago
parent
commit
12d87686ff
19 changed files with 251 additions and 103 deletions
  1. +1
    -1
      README.md
  2. +8
    -2
      ablkit/bridge/base_bridge.py
  3. +15
    -7
      ablkit/bridge/simple_bridge.py
  4. +6
    -0
      ablkit/data/evaluation/base_metric.py
  5. +9
    -1
      ablkit/data/evaluation/reasoning_metric.py
  6. +5
    -4
      ablkit/data/evaluation/symbol_accuracy.py
  7. +5
    -3
      ablkit/data/structures/base_data_element.py
  8. +7
    -5
      ablkit/data/structures/list_data.py
  9. +20
    -14
      ablkit/learning/abl_model.py
  10. +7
    -1
      ablkit/learning/basic_nn.py
  11. +6
    -0
      ablkit/learning/torch_dataset/classification_dataset.py
  12. +6
    -0
      ablkit/learning/torch_dataset/prediction_dataset.py
  13. +6
    -0
      ablkit/learning/torch_dataset/regression_dataset.py
  14. +19
    -8
      ablkit/reasoning/kb.py
  15. +20
    -18
      ablkit/reasoning/reasoner.py
  16. +77
    -26
      ablkit/utils/cache.py
  17. +22
    -10
      ablkit/utils/logger.py
  18. +5
    -2
      ablkit/utils/manager.py
  19. +7
    -1
      ablkit/utils/utils.py

+ 1
- 1
README.md View File

@@ -14,7 +14,7 @@

</div>

# Abductive Learning (ABL) Kit
# ABL Kit: A Python Toolkit for Abductive Learning

**ABL Kit** is an efficient Python toolkit for **Abductive Learning (ABL)**.
ABL is a novel paradigm that integrates machine learning and


+ 8
- 2
ablkit/bridge/base_bridge.py View File

@@ -1,3 +1,9 @@
"""
This module contains the base class for the Bridge part.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Tuple, Union

@@ -31,11 +37,11 @@ class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(
"Expected an instance of ABLModel, but received type: {}".format(type(model))
f"Expected an instance of ABLModel, but received type: {type(model)}"
)
if not isinstance(reasoner, Reasoner):
raise TypeError(
"Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
f"Expected an instance of Reasoner, but received type: {type(reasoner)}"
)

self.model = model


+ 15
- 7
ablkit/bridge/simple_bridge.py View File

@@ -1,3 +1,9 @@
"""
This module contains a simple implementation of the Bridge part.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import os.path as osp
from typing import Any, List, Optional, Tuple, Union

@@ -221,7 +227,7 @@ class SimpleBridge(BaseBridge):
Labeled data should be in the same format as ``train_data``. The only difference is
that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
utilized to train the model. Defaults to None.
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long
Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
and ``Y`` can be either None or not, which depends on the evaluation metircs in
``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate
@@ -327,10 +333,11 @@ class SimpleBridge(BaseBridge):

Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be
either None or not, which depends on the evaluation metircs in ``self.metric_list``.
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
and ``Y`` can be either None or not, which depends on the evaluation metircs in
``self.metric_list``.
"""
val_data_examples = self.data_preprocess("val", val_data)
self._valid(val_data_examples)
@@ -346,10 +353,11 @@ class SimpleBridge(BaseBridge):

Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501 pylint: disable=line-too-long
Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y``
can be either None or not, which depends on the evaluation metircs in ``self.metric_list``.
can be either None or not, which depends on the evaluation metircs in
``self.metric_list``.
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)


+ 6
- 0
ablkit/data/evaluation/base_metric.py View File

@@ -1,3 +1,9 @@
"""
This module contains the base class used for evaluation.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional


+ 9
- 1
ablkit/data/evaluation/reasoning_metric.py View File

@@ -1,3 +1,10 @@
"""
This module contains the ReasoningMetric, which is used for evaluating the model performance
on tasks that need reasoning.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Optional

from ...reasoning import KBBase
@@ -7,7 +14,7 @@ from .base_metric import BaseMetric

class ReasoningMetric(BaseMetric):
"""
A metrics class for evaluating the model performance on tasks need reasoning.
A metrics class for evaluating the model performance on tasks that need reasoning.

This class is designed to calculate the accuracy of the reasoing results. Reasoning
results are generated by first using the learning part to predict pseudo-labels
@@ -34,6 +41,7 @@ class ReasoningMetric(BaseMetric):
super().__init__(prefix)
self.kb = kb

# pylint: disable=protected-access
def process(self, data_examples: ListData) -> None:
"""
Process a batch of data examples.


+ 5
- 4
ablkit/data/evaluation/symbol_accuracy.py View File

@@ -1,4 +1,8 @@
from typing import Optional
"""
This module contains the class SymbolAccuracy, which is used for evaluating symbol-level accuracy.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import numpy as np

@@ -20,9 +24,6 @@ class SymbolAccuracy(BaseMetric):
metrics of different tasks. Inherits from BaseMetric. Default to None.
"""

def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_examples: ListData) -> None:
"""
Processes a batch of data examples.


+ 5
- 3
ablkit/data/structures/base_data_element.py View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py
"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa: E501 pylint: disable=line-too-long
"""

import copy
from typing import Any, Iterator, Optional, Tuple, Type, Union


+ 7
- 5
ablkit/data/structures/list_data.py View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long
"""

from typing import List, Union

@@ -54,7 +56,7 @@ class ListData(BaseDataElement):
``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.

This design is inspired by and extends the functionalities of the ``BaseDataElement``
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 pylint: disable=line-too-long

Examples:
>>> from ablkit.data.structures import ListData
@@ -72,7 +74,7 @@ class ListData(BaseDataElement):
DATA FIELDS
Y: [1, 2, 3]
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long
) at 0x7f3bbf1991c0>
>>> print(data_examples[:1])
<ListData(


+ 20
- 14
ablkit/learning/abl_model.py View File

@@ -1,3 +1,10 @@
"""
This module contains the class ABLModel, which provides a unified interface for different
machine learning models.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import pickle
from typing import Any, Dict

@@ -99,21 +106,20 @@ class ABLModel:
method = getattr(model, operation)
method(*args, **kwargs)
else:
if f"{operation}_path" not in kwargs.keys():
if f"{operation}_path" not in kwargs:
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except (OSError, pickle.PickleError):
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method \
and the default pickle-based {operation} method failed."
)
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except (OSError, pickle.PickleError) as exc:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method \
and the default pickle-based {operation} method failed."
) from exc

def save(self, *args, **kwargs) -> None:
"""


+ 7
- 1
ablkit/learning/basic_nn.py View File

@@ -1,3 +1,9 @@
"""
This module contains the class BasicNN, which servers as a wrapper for PyTorch NN models.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from __future__ import annotations

import logging
@@ -474,7 +480,7 @@ class BasicNN:
raise ValueError("X should not be None.")
if y is None:
y = [0] * len(X)
if not (len(y) == len(X)):
if not len(y) == len(X):
raise ValueError("X and y should have equal length.")

dataset = ClassificationDataset(X, y, transform=self.train_transform)


+ 6
- 0
ablkit/learning/torch_dataset/classification_dataset.py View File

@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for classification.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch


+ 6
- 0
ablkit/learning/torch_dataset/prediction_dataset.py View File

@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for Prediction.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, Callable, List, Tuple, Optional

import torch


+ 6
- 0
ablkit/learning/torch_dataset/regression_dataset.py View File

@@ -1,3 +1,9 @@
"""
Implementation of PyTorch dataset class used for regression.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import Any, List, Tuple

from torch.utils.data import Dataset


+ 19
- 8
ablkit/reasoning/kb.py View File

@@ -1,3 +1,10 @@
"""
This module contains the classes KBBase, GroundKB, and PrologKB, which provide wrappers
for different kinds of knowledge bases.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import bisect
import inspect
import logging
@@ -394,7 +401,7 @@ class GroundKB(KBBase):
base. The second element is a list of reasoning results corresponding to each
candidate, i.e., the outcome of the ``logic_forward`` function.
"""
if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list:
if not self.GKB or len(pseudo_label) not in self.GKB_len_list:
return [], []

all_candidates, all_reasoning_results = self._find_candidate_GKB(pseudo_label, y)
@@ -478,7 +485,7 @@ class PrologKB(KBBase):
super().__init__(pseudo_label_list)

try:
import pyswip
import pyswip # pylint: disable=import-outside-toplevel
except (IndexError, ImportError):
print(
"A Prolog-based knowledge base is in use. Please install SWI-Prolog using the"
@@ -493,7 +500,7 @@ class PrologKB(KBBase):
raise FileNotFoundError(f"The Prolog file {self.pl_file} does not exist.")
self.prolog.consult(self.pl_file)

def logic_forward(self, pseudo_label: List[Any]) -> Any:
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
"""
Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the
returned ``Res`` as the reasoning results. To use this default function, there must be
@@ -504,11 +511,15 @@ class PrologKB(KBBase):
----------
pseudo_label : List[Any]
Pseudo-labels of an example.
x : List[Any]
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
any effect.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"]
result = list(self.prolog.query(f"logic_forward({pseudo_label}, Res)."))[0]["Res"]
if result == "true":
return True
elif result == "false":
if result == "false":
return False
return result

@@ -517,7 +528,7 @@ class PrologKB(KBBase):
pseudo_label: List[Any],
revision_idx: List[int],
) -> List[Any]:
import re
import re # pylint: disable=import-outside-toplevel

revision_pseudo_label = pseudo_label.copy()
revision_pseudo_label = flatten(revision_pseudo_label)
@@ -533,7 +544,7 @@ class PrologKB(KBBase):
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
x: List[Any], # pylint: disable=unused-argument
revision_idx: List[int],
) -> str:
"""
@@ -563,7 +574,7 @@ class PrologKB(KBBase):
query_string = "logic_forward("
query_string += self._revision_pseudo_label(pseudo_label, revision_idx)
key_is_none_flag = y is None or (isinstance(y, list) and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
query_string += f",{y})." if not key_is_none_flag else ")."
return query_string

def revise_at_idx(


+ 20
- 18
ablkit/reasoning/reasoner.py View File

@@ -1,3 +1,10 @@
"""
This module contains the class Reasoner, which is used for minimizing the inconsistency
between the knowledge base and learning models.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

import inspect
from typing import Any, Callable, List, Optional, Union

@@ -251,25 +258,21 @@ class Reasoner:

def zoopt_budget(self, symbol_num: int) -> int:
"""
Set the budget for ZOOpt optimization. The function, in its default implementation,
returns a fixed budget value of 100. However, it can be adjusted to return other fixed
values, or a dynamic budget based on the number of symbols, if desired. For example,
one might choose to set the budget as 100 times ``symbol_num``.
Set the budget for ZOOpt optimization. The budget can be dynamic relying on
the number of symbols considered, e.g., the default implementation shown below.
Alternatively, it can be a fixed value, such as simply setting it to 100.

Parameters
----------
symbol_num : int
The number of symbols to be considered in the ZOOpt optimization process. Although this
parameter can be used to compute a dynamic optimization budget, by default it is not
utilized in the calculation.
The number of symbols to be considered in the ZOOpt optimization process.

Returns
-------
int
The budget for ZOOpt optimization. By default, this is a fixed value of 100,
irrespective of the symbol_num value.
The budget for ZOOpt optimization.
"""
return 100
return 10 * symbol_num

def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
"""
@@ -288,19 +291,18 @@ class Reasoner:

if max_revision == -1:
return symbol_num
elif isinstance(max_revision, float):
if not (0 <= max_revision <= 1):
if isinstance(max_revision, float):
if not 0 <= max_revision <= 1:
raise ValueError(
"If max_revision is a float, it must be between 0 and 1, "
+ f"but got {max_revision}"
)
return round(symbol_num * max_revision)
else:
if max_revision < 0:
raise ValueError(
f"If max_revision is an int, it must be non-negative, but got {max_revision}"
)
return max_revision
if max_revision < 0:
raise ValueError(
f"If max_revision is an int, it must be non-negative, but got {max_revision}"
)
return max_revision

def abduce(self, data_example: ListData) -> List[Any]:
"""


+ 77
- 26
ablkit/utils/cache.py View File

@@ -1,14 +1,15 @@
# Python module wrapper for _functools C module
# to allow utilities written in Python to be added
# to the functools module.
# Written by Nick Coghlan <ncoghlan at gmail.com>,
# Raymond Hettinger <python at rcn.com>,
# and Łukasz Langa <lukasz at langa.pl>.
# Copyright (C) 2006-2013 Python Software Foundation.
# See C source code for _functools credits/copyright
# Modified from
# https://github.com/python/cpython/blob/3.12/Lib/functools.py

"""
Python module wrapper for _functools C module
to allow utilities written in Python to be added
to the functools module.
Written by Nick Coghlan <ncoghlan at gmail.com>,
Raymond Hettinger <python at rcn.com>,
and Łukasz Langa <lukasz at langa.pl>.
Copyright (C) 2006-2013 Python Software Foundation.
See C source code for _functools credits/copyright
Modified from
https://github.com/python/cpython/blob/3.12/Lib/functools.py
"""

from typing import Callable, Generic, TypeVar

@@ -18,30 +19,58 @@ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields


class Cache(Generic[K, T]):
def __init__(self, func: Callable[[K], T]):
"""Create cache
"""
A generic caching mechanism that stores the results of a function call and
retrieves them to avoid repeated calculations.

:param func: Function this cache evaluates
:param cache: If true, do in memory caching.
:param cache_root: If not None, cache to files at the provided path.
:param key_func: Convert the key into a hashable object if needed
"""
This class implements a dictionary-based cache with a circular doubly linked
list to manage the cache entries efficiently. It is designed to be generic,
allowing for caching of any callable function.

Parameters
----------
func : Callable[[K], T]
The function to be cached. This function takes an argument of type K and
returns a value of type T.
"""

def __init__(self, func: Callable[[K], T]):
self.func = func
self.has_init = False

self.cache = False
self.cache_dict = {}
self.key_func = None
self.max_size = 0

self.hits, self.misses = 0, 0
self.full = False
self.root = [] # root of the circular doubly linked list
self.root[:] = [self.root, self.root, None, None]

def __getitem__(self, obj, *args) -> T:
return self.get_from_dict(obj, *args)

def clear_cache(self):
"""Invalidate entire cache."""
"""
Invalidate the entire cache.
"""
self.cache_dict.clear()

def _init_cache(self, obj):
def init_cache(self, obj):
"""
Initialize the cache settings.

Parameters
----------
obj : Any
The object containing settings for cache initialization.
"""
if self.has_init:
return

self.cache = True
self.cache_dict = dict()
self.cache_dict = {}
self.key_func = obj.key_func
self.max_size = obj.cache_size

@@ -53,9 +82,23 @@ class Cache(Generic[K, T]):
self.has_init = True

def get_from_dict(self, obj, *args) -> T:
"""Implements dict based cache."""
"""
Retrieve a value from the cache or compute it using ``self.func``.

Parameters
----------
obj : Any
The object to which the cached method/function belongs.
*args : Any
Arguments used in key generation for cache retrieval or function computation.

Returns
-------
T
The value from the cache or computed by the function.
"""
# x is not used in cache key
pred_pseudo_label, y, x, *res_args = args
pred_pseudo_label, y, _x, *res_args = args
cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
link = self.cache_dict.get(cache_key)
if link is not None:
@@ -96,15 +139,23 @@ class Cache(Generic[K, T]):


def abl_cache():
"""
Decorator to enable caching for a function.

Returns
-------
Callable
The wrapped function with caching capability.
"""

def decorator(func):
cache_instance = Cache(func)

def wrapper(obj, *args):
if obj.use_cache:
cache_instance._init_cache(obj)
cache_instance.init_cache(obj)
return cache_instance.get_from_dict(obj, *args)
else:
return func(obj, *args)
return func(obj, *args)

return wrapper



+ 22
- 10
ablkit/utils/logger.py View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py
"""
Copyright (c) OpenMMLab. All rights reserved.
Modified from
https://github.com/open-mmlab/mmengine/blob/main/mmengine/logging/logger.py
"""

import logging
import os
@@ -132,13 +134,13 @@ class ABLFormatter(logging.Formatter):
Formatted result.
"""
if record.levelno == logging.ERROR:
self._style._fmt = self.err_format
self._style._fmt = self.err_format # pylint: disable=protected-access
elif record.levelno == logging.WARNING:
self._style._fmt = self.warn_format
self._style._fmt = self.warn_format # pylint: disable=protected-access
elif record.levelno == logging.INFO:
self._style._fmt = self.info_format
self._style._fmt = self.info_format # pylint: disable=protected-access
elif record.levelno == logging.DEBUG:
self._style._fmt = self.debug_format
self._style._fmt = self.debug_format # pylint: disable=protected-access

result = logging.Formatter.format(self, record)
return result
@@ -215,7 +217,7 @@ class ABLLogger(Logger, ManagerMixin):
self.handlers.append(stream_handler)

if log_file is None:
import time
import time # pylint: disable=import-outside-toplevel

local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())

@@ -234,10 +236,20 @@ class ABLLogger(Logger, ManagerMixin):

@property
def log_file(self):
"""Get the file path of the log.

Returns:
str: Path of the log.
"""
return self._log_file

@property
def log_dir(self):
"""Get the directory where the log is stored.

Returns:
str: Directory where the log is stored.
"""
return self._log_dir

@classmethod
@@ -284,11 +296,11 @@ class ABLLogger(Logger, ManagerMixin):
level : Union[int, str]
The logging level to set.
"""
self.level = logging._checkLevel(level)
self.level = logging._checkLevel(level) # pylint: disable=protected-access
_accquire_lock()
# The same logic as ``logging.Manager._clear_cache``.
for logger in ABLLogger._instance_dict.values():
logger._cache.clear()
logger._cache.clear() # pylint: disable=protected-access
_release_lock()




+ 5
- 2
ablkit/utils/manager.py View File

@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""
Copyright (c) OpenMMLab. All rights reserved.
"""

import inspect
import threading
import warnings
@@ -72,7 +75,7 @@ class ManagerMixin(metaclass=ManagerMeta):
name (str): Name of the instance. Defaults to ''.
"""

def __init__(self, name: str = "", **kwargs):
def __init__(self, name: str = ""):
assert isinstance(name, str) and name, "name argument must be an non-empty string."
self._instance_name = name



+ 7
- 1
ablkit/utils/utils.py View File

@@ -1,3 +1,9 @@
"""
Implementation of utilities used in ablkit.

Copyright (c) 2024 LAMDA. All rights reserved.
"""

from typing import List, Any, Union, Tuple, Optional

import numpy as np
@@ -198,6 +204,6 @@ def tab_data_to_tuple(
return None
if len(X) != len(y):
raise ValueError(
"The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))
f"The length of X and y should be the same, but got {len(X)} and {len(y)}."
)
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

Loading…
Cancel
Save