| @@ -13,8 +13,9 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||||
| """ | """ | ||||
| from .batch import DataSetIter, BatchIter, TorchLoaderIter | from .batch import DataSetIter, BatchIter, TorchLoaderIter | ||||
| from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||||
| from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback | |||||
| from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | |||||
| LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ | |||||
| TesterCallback, CallbackException, EarlyStopError | |||||
| from .const import Const | from .const import Const | ||||
| from .dataset import DataSet | from .dataset import DataSet | ||||
| from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
| @@ -51,13 +51,19 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: | |||||
| """ | """ | ||||
| __all__ = [ | __all__ = [ | ||||
| "Callback", | "Callback", | ||||
| "GradientClipCallback", | "GradientClipCallback", | ||||
| "EarlyStopCallback", | "EarlyStopCallback", | ||||
| "TensorboardCallback", | |||||
| "FitlogCallback", | "FitlogCallback", | ||||
| "EvaluateCallback", | |||||
| "LRScheduler", | "LRScheduler", | ||||
| "ControlC", | "ControlC", | ||||
| "EvaluateCallback", | |||||
| "LRFinder", | |||||
| "TensorboardCallback", | |||||
| "WarmupCallback", | |||||
| "SaveModelCallback", | |||||
| "EchoCallback", | |||||
| "TesterCallback", | |||||
| "CallbackException", | "CallbackException", | ||||
| "EarlyStopError" | "EarlyStopError" | ||||
| @@ -718,7 +724,7 @@ class SmoothValue(object): | |||||
| self.smooth = None | self.smooth = None | ||||
| def add_value(self, val: float) -> None: | def add_value(self, val: float) -> None: | ||||
| "Add `val` to calculate updated smoothed value." | |||||
| """Add `val` to calculate updated smoothed value.""" | |||||
| self.n += 1 | self.n += 1 | ||||
| self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | ||||
| self.smooth = self.mov_avg / (1 - self.beta ** self.n) | self.smooth = self.mov_avg / (1 - self.beta ** self.n) | ||||
| @@ -68,7 +68,8 @@ class Loader: | |||||
| """ | """ | ||||
| raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | ||||
| def _get_dataset_path(self, dataset_name): | |||||
| @staticmethod | |||||
| def _get_dataset_path(dataset_name): | |||||
| """ | """ | ||||
| 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | ||||
| @@ -239,6 +239,7 @@ class QuoraPipe(MatchingPipe): | |||||
| data_bundle = QuoraLoader().load(paths) | data_bundle = QuoraLoader().load(paths) | ||||
| return self.process(data_bundle) | return self.process(data_bundle) | ||||
| class QNLIPipe(MatchingPipe): | class QNLIPipe(MatchingPipe): | ||||
| def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
| data_bundle = QNLILoader().load(paths) | data_bundle = QNLILoader().load(paths) | ||||