@@ -14,7 +14,6 @@ __all__ = [ | |||||
"Instance", | "Instance", | ||||
"FieldArray", | "FieldArray", | ||||
"DataSetIter", | "DataSetIter", | ||||
"BatchIter", | "BatchIter", | ||||
"TorchLoaderIter", | "TorchLoaderIter", | ||||
@@ -29,10 +28,18 @@ __all__ = [ | |||||
"Callback", | "Callback", | ||||
"GradientClipCallback", | "GradientClipCallback", | ||||
"EarlyStopCallback", | "EarlyStopCallback", | ||||
"TensorboardCallback", | |||||
"FitlogCallback", | |||||
"EvaluateCallback", | |||||
"LRScheduler", | "LRScheduler", | ||||
"ControlC", | "ControlC", | ||||
"LRFinder", | "LRFinder", | ||||
"TensorboardCallback", | |||||
"WarmupCallback", | |||||
'SaveModelCallback', | |||||
"EchoCallback", | |||||
"TesterCallback", | |||||
"CallbackException", | |||||
"EarlyStopError", | |||||
"Padder", | "Padder", | ||||
"AutoPadder", | "AutoPadder", | ||||
@@ -46,7 +53,7 @@ __all__ = [ | |||||
"SGD", | "SGD", | ||||
"Adam", | "Adam", | ||||
"AdamW", | "AdamW", | ||||
"Sampler", | "Sampler", | ||||
"SequentialSampler", | "SequentialSampler", | ||||
"BucketSampler", | "BucketSampler", | ||||
@@ -60,17 +67,18 @@ __all__ = [ | |||||
"LossInForward", | "LossInForward", | ||||
"cache_results", | "cache_results", | ||||
'logger' | 'logger' | ||||
] | ] | ||||
__version__ = '0.4.5' | __version__ = '0.4.5' | ||||
import sys | |||||
from . import embeddings | from . import embeddings | ||||
from . import models | from . import models | ||||
from . import modules | from . import modules | ||||
from .core import * | from .core import * | ||||
from .doc_utils import doc_process | |||||
from .io import loader, pipe | from .io import loader, pipe | ||||
import sys | |||||
from .doc_utils import doc_process | |||||
doc_process(sys.modules[__name__]) | |||||
doc_process(sys.modules[__name__]) |
@@ -25,3 +25,15 @@ def doc_process(m): | |||||
if module_name == m.__name__: | if module_name == m.__name__: | ||||
# print(name, ": not found defined doc.") | # print(name, ": not found defined doc.") | ||||
break | break | ||||
if inspect.isclass(obj): | |||||
for base in obj.__bases__: | |||||
if base.__module__.startswith("fastNLP"): | |||||
parts = base.__module__.split(".") + [] | |||||
module_name, i = "fastNLP", 1 | |||||
for i in range(len(parts) - 1): | |||||
defined_m = sys.modules[module_name] | |||||
if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: | |||||
obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__ | |||||
break | |||||
module_name += "." + parts[i + 1] |