@@ -45,7 +45,8 @@ __all__ = [ | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"ExtractiveQAMetric", | |||||
"CMRC2018Metric", | |||||
"ClassifyFPreRecMetric", | |||||
"Optimizer", | "Optimizer", | ||||
"SGD", | "SGD", | ||||
@@ -62,6 +62,7 @@ __all__ = [ | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"CMRC2018Metric", | "CMRC2018Metric", | ||||
"ClassifyFPreRecMetric", | |||||
"Optimizer", | "Optimizer", | ||||
"SGD", | "SGD", | ||||
@@ -84,7 +85,7 @@ from .dataset import DataSet | |||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss | ||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||||
from .optimizer import Optimizer, SGD, Adam, AdamW | from .optimizer import Optimizer, SGD, Adam, AdamW | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -6,7 +6,8 @@ __all__ = [ | |||||
"MetricBase", | "MetricBase", | ||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"CMRC2018Metric" | |||||
"CMRC2018Metric", | |||||
"ClassifyFPreRecMetric" | |||||
] | ] | ||||
import inspect | import inspect | ||||
@@ -72,7 +72,7 @@ class Tester(object): | |||||
""" | """ | ||||
:param ~fastNLP.DataSet data: 需要测试的数据集 | :param ~fastNLP.DataSet data: 需要测试的数据集 | ||||
:param torch.nn.module model: 使用的模型 | |||||
:param torch.nn.Module model: 使用的模型 | |||||
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | ||||
:param int batch_size: evaluation时使用的batch_size有多大。 | :param int batch_size: evaluation时使用的batch_size有多大。 | ||||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | ||||