diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index eeabda35..acf0efc4 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -13,8 +13,9 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa """ from .batch import DataSetIter, BatchIter, TorchLoaderIter -from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC -from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback +from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ + LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ + TesterCallback, CallbackException, EarlyStopError from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 447186ca..17ded171 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -51,13 +51,19 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: """ __all__ = [ "Callback", + "GradientClipCallback", "EarlyStopCallback", - "TensorboardCallback", "FitlogCallback", + "EvaluateCallback", "LRScheduler", "ControlC", - "EvaluateCallback", + "LRFinder", + "TensorboardCallback", + "WarmupCallback", + "SaveModelCallback", + "EchoCallback", + "TesterCallback", "CallbackException", "EarlyStopError" @@ -718,7 +724,7 @@ class SmoothValue(object): self.smooth = None def add_value(self, val: float) -> None: - "Add `val` to calculate updated smoothed value." + """Add `val` to calculate updated smoothed value.""" self.n += 1 self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val self.smooth = self.mov_avg / (1 - self.beta ** self.n) diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py index 02f24097..e7b419ac 100644 --- a/fastNLP/io/loader/loader.py +++ b/fastNLP/io/loader/loader.py @@ -68,7 +68,8 @@ class Loader: """ raise NotImplementedError(f"{self.__class__} cannot download data automatically.") - def _get_dataset_path(self, dataset_name): + @staticmethod + def _get_dataset_path(dataset_name): """ 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 474865c6..2eaeef58 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -239,6 +239,7 @@ class QuoraPipe(MatchingPipe): data_bundle = QuoraLoader().load(paths) return self.process(data_bundle) + class QNLIPipe(MatchingPipe): def process_from_file(self, paths=None): data_bundle = QNLILoader().load(paths)