Browse Source

[FIX] change optional in docstring

pull/1/head
troyyyyy 1 year ago
parent
commit
132437c7ae
2 changed files with 14 additions and 13 deletions
  1. +6
    -6
      abl/reasoning/kb.py
  2. +8
    -7
      abl/utils/logger.py

+ 6
- 6
abl/reasoning/kb.py View File

@@ -35,7 +35,7 @@ class KBBase(ABC):
operations. Defaults to True.
key_func : Callable, optional
A function employed for hashing in abl_cache. This is only operational when use_cache
is set to True. Defaults to to_hashable.
is set to True. Defaults to ``to_hashable``.
cache_size: int, optional
The cache size in abl_cache. This is only operational when use_cache is set to
True. Defaults to 4096.
@@ -52,10 +52,10 @@ class KBBase(ABC):
def __init__(
self,
pseudo_label_list: List[Any],
max_err: Optional[float] = 1e-10,
use_cache: Optional[bool] = True,
key_func: Optional[Callable] = to_hashable,
cache_size: Optional[int] = 4096,
max_err: float = 1e-10,
use_cache: bool = True,
key_func: Callable = to_hashable,
cache_size: int = 4096,
):
if not isinstance(pseudo_label_list, list):
raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}")
@@ -308,7 +308,7 @@ class GroundKB(KBBase):
self,
pseudo_label_list: List[Any],
GKB_len_list: List[int],
max_err: Optional[float] = 1e-10,
max_err: float = 1e-10,
):
super().__init__(pseudo_label_list, max_err)
if not isinstance(GKB_len_list, list):


+ 8
- 7
abl/utils/logger.py View File

@@ -53,12 +53,13 @@ class ABLFormatter(logging.Formatter):

Parameters
----------
color : bool
color : bool, optional
Whether to use colorful format. filehandler is not
allowed to use color format, otherwise it will be garbled.
blink : bool
Defaults to True.
blink : bool, optional
Whether to blink the ``INFO`` and ``DEBUG`` logging
level.
level. Defaults to False.
kwargs : dict
Keyword arguments passed to
:meth:``logging.Formatter.__init__``.
@@ -85,7 +86,7 @@ class ABLFormatter(logging.Formatter):
self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(" "message)s"
self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(" "message)s"

def _get_prefix(self, level: str, color: bool, blink: Optional[bool] = False) -> str:
def _get_prefix(self, level: str, color: bool, blink: bool = False) -> str:
"""
Get the prefix of the target log level.

@@ -96,7 +97,7 @@ class ABLFormatter(logging.Formatter):
color : bool
Whether to get a colorful prefix.
blink : bool, optional
Whether the prefix will blink.
Whether the prefix will blink. Defaults to False.

Returns
-------
@@ -192,8 +193,8 @@ class ABLLogger(Logger, ManagerMixin):
name: str,
logger_name="abl",
log_file: Optional[str] = None,
log_level: Optional[Union[int, str]] = "INFO",
file_mode: Optional[str] = "w",
log_level: Union[int, str] = "INFO",
file_mode: str = "w",
):
Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name)


Loading…
Cancel
Save