@@ -13,8 +13,9 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||||
""" | """ | ||||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | from .batch import DataSetIter, BatchIter, TorchLoaderIter | ||||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||||
from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback | |||||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | |||||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ | |||||
TesterCallback, CallbackException, EarlyStopError | |||||
from .const import Const | from .const import Const | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
@@ -51,13 +51,19 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"Callback", | "Callback", | ||||
"GradientClipCallback", | "GradientClipCallback", | ||||
"EarlyStopCallback", | "EarlyStopCallback", | ||||
"TensorboardCallback", | |||||
"FitlogCallback", | "FitlogCallback", | ||||
"EvaluateCallback", | |||||
"LRScheduler", | "LRScheduler", | ||||
"ControlC", | "ControlC", | ||||
"EvaluateCallback", | |||||
"LRFinder", | |||||
"TensorboardCallback", | |||||
"WarmupCallback", | |||||
"SaveModelCallback", | |||||
"EchoCallback", | |||||
"TesterCallback", | |||||
"CallbackException", | "CallbackException", | ||||
"EarlyStopError" | "EarlyStopError" | ||||
@@ -718,7 +724,7 @@ class SmoothValue(object): | |||||
self.smooth = None | self.smooth = None | ||||
def add_value(self, val: float) -> None: | def add_value(self, val: float) -> None: | ||||
"Add `val` to calculate updated smoothed value." | |||||
"""Add `val` to calculate updated smoothed value.""" | |||||
self.n += 1 | self.n += 1 | ||||
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | ||||
self.smooth = self.mov_avg / (1 - self.beta ** self.n) | self.smooth = self.mov_avg / (1 - self.beta ** self.n) | ||||
@@ -68,7 +68,8 @@ class Loader: | |||||
""" | """ | ||||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | ||||
def _get_dataset_path(self, dataset_name): | |||||
@staticmethod | |||||
def _get_dataset_path(dataset_name): | |||||
""" | """ | ||||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | ||||
@@ -239,6 +239,7 @@ class QuoraPipe(MatchingPipe): | |||||
data_bundle = QuoraLoader().load(paths) | data_bundle = QuoraLoader().load(paths) | ||||
return self.process(data_bundle) | return self.process(data_bundle) | ||||
class QNLIPipe(MatchingPipe): | class QNLIPipe(MatchingPipe): | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
data_bundle = QNLILoader().load(paths) | data_bundle = QNLILoader().load(paths) | ||||