diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index a60a257..0632781 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -35,7 +35,7 @@ class KBBase(ABC): operations. Defaults to True. key_func : Callable, optional A function employed for hashing in abl_cache. This is only operational when use_cache - is set to True. Defaults to to_hashable. + is set to True. Defaults to ``to_hashable``. cache_size: int, optional The cache size in abl_cache. This is only operational when use_cache is set to True. Defaults to 4096. @@ -52,10 +52,10 @@ class KBBase(ABC): def __init__( self, pseudo_label_list: List[Any], - max_err: Optional[float] = 1e-10, - use_cache: Optional[bool] = True, - key_func: Optional[Callable] = to_hashable, - cache_size: Optional[int] = 4096, + max_err: float = 1e-10, + use_cache: bool = True, + key_func: Callable = to_hashable, + cache_size: int = 4096, ): if not isinstance(pseudo_label_list, list): raise TypeError(f"pseudo_label_list should be list, got {type(pseudo_label_list)}") @@ -308,7 +308,7 @@ class GroundKB(KBBase): self, pseudo_label_list: List[Any], GKB_len_list: List[int], - max_err: Optional[float] = 1e-10, + max_err: float = 1e-10, ): super().__init__(pseudo_label_list, max_err) if not isinstance(GKB_len_list, list): diff --git a/abl/utils/logger.py b/abl/utils/logger.py index 1876ebc..4e9d8b6 100644 --- a/abl/utils/logger.py +++ b/abl/utils/logger.py @@ -53,12 +53,13 @@ class ABLFormatter(logging.Formatter): Parameters ---------- - color : bool + color : bool, optional Whether to use colorful format. filehandler is not allowed to use color format, otherwise it will be garbled. - blink : bool + Defaults to True. + blink : bool, optional Whether to blink the ``INFO`` and ``DEBUG`` logging - level. + level. Defaults to False. kwargs : dict Keyword arguments passed to :meth:``logging.Formatter.__init__``. @@ -85,7 +86,7 @@ class ABLFormatter(logging.Formatter): self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(" "message)s" self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(" "message)s" - def _get_prefix(self, level: str, color: bool, blink: Optional[bool] = False) -> str: + def _get_prefix(self, level: str, color: bool, blink: bool = False) -> str: """ Get the prefix of the target log level. @@ -96,7 +97,7 @@ class ABLFormatter(logging.Formatter): color : bool Whether to get a colorful prefix. blink : bool, optional - Whether the prefix will blink. + Whether the prefix will blink. Defaults to False. Returns ------- @@ -192,8 +193,8 @@ class ABLLogger(Logger, ManagerMixin): name: str, logger_name="abl", log_file: Optional[str] = None, - log_level: Optional[Union[int, str]] = "INFO", - file_mode: Optional[str] = "w", + log_level: Union[int, str] = "INFO", + file_mode: str = "w", ): Logger.__init__(self, logger_name) ManagerMixin.__init__(self, name)