diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f994bd31..77695852 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary - +from abc import abstractmethod class MetricBase(object): """ @@ -117,10 +117,12 @@ class MetricBase(object): def __init__(self): self.param_map = {} # key is param in function, value is input param. self._checked = False - + + @abstractmethod def evaluate(self, *args, **kwargs): raise NotImplementedError - + + @abstractmethod def get_metric(self, reset=True): raise NotImplemented diff --git a/reproduction/seqence_labelling/cws/model/module.py b/reproduction/seqence_labelling/cws/model/module.py index 6cd8b5e3..86149f39 100644 --- a/reproduction/seqence_labelling/cws/model/module.py +++ b/reproduction/seqence_labelling/cws/model/module.py @@ -1,11 +1,10 @@ from torch import nn import torch -from fastNLP.modules import Embedding import numpy as np class SemiCRFShiftRelay(nn.Module): """ - 该模块是一个decoder,但 + 该模块是一个decoder,但当前不支持含有tag的decode。 """ def __init__(self, L): diff --git a/reproduction/seqence_labelling/cws/train_shift_relay.py b/reproduction/seqence_labelling/cws/train_shift_relay.py index ed512252..c5d436fe 100644 --- a/reproduction/seqence_labelling/cws/train_shift_relay.py +++ b/reproduction/seqence_labelling/cws/train_shift_relay.py @@ -32,7 +32,7 @@ lr = 0.02 #########hyper device = 0 -# !!!!这里前往不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 +# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 # 你们的reproduction路径下,然后设置.gitignore file_dir = '/path/to/pku' char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'