@@ -142,9 +142,10 @@ class MetricBase(object): | |||||
设置metric的名称,默认是Metric的class name. | 设置metric的名称,默认是Metric的class name. | ||||
:param str name: | :param str name: | ||||
:return: | |||||
:return: self | |||||
""" | """ | ||||
self._metric_name = name | self._metric_name = name | ||||
return self | |||||
def get_metric_name(self): | def get_metric_name(self): | ||||
""" | """ | ||||
@@ -50,6 +50,7 @@ __all__ = [ | |||||
"SSTPipe", | "SSTPipe", | ||||
"SST2Pipe", | "SST2Pipe", | ||||
"IMDBPipe", | "IMDBPipe", | ||||
"Conll2003Pipe", | |||||
"Conll2003NERPipe", | "Conll2003NERPipe", | ||||
"OntoNotesNERPipe", | "OntoNotesNERPipe", | ||||
@@ -158,6 +158,16 @@ class DataBundle: | |||||
""" | """ | ||||
return self.datasets[name] | return self.datasets[name] | ||||
def delete_dataset(self, name:str): | |||||
""" | |||||
删除名为name的DataSet | |||||
:param str name: | |||||
:return: self | |||||
""" | |||||
self.datasets.pop(name, None) | |||||
return self | |||||
def get_vocab(self, field_name:str)->Vocabulary: | def get_vocab(self, field_name:str)->Vocabulary: | ||||
""" | """ | ||||
获取field名为field_name对应的vocab | 获取field名为field_name对应的vocab | ||||
@@ -167,6 +177,15 @@ class DataBundle: | |||||
""" | """ | ||||
return self.vocabs[field_name] | return self.vocabs[field_name] | ||||
def delete_vocab(self, field_name:str): | |||||
""" | |||||
删除vocab | |||||
:param str field_name: | |||||
:return: self | |||||
""" | |||||
self.vocabs.pop(field_name, None) | |||||
return self | |||||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | ||||
""" | """ | ||||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | ||||
@@ -75,7 +75,12 @@ DATASET_DIR = { | |||||
"rte": "RTE.zip", | "rte": "RTE.zip", | ||||
"msra-ner": "MSRA_NER.zip", | "msra-ner": "MSRA_NER.zip", | ||||
"peopledaily": "peopledaily.zip", | "peopledaily": "peopledaily.zip", | ||||
"weibo-ner": "weibo_NER.zip" | |||||
"weibo-ner": "weibo_NER.zip", | |||||
"cws-pku": 'cws_pku.zip', | |||||
"cws-cityu": "cws_cityu.zip", | |||||
"cws-as": 'cws_as.zip', | |||||
"cws-msra": 'cws_msra.zip' | |||||
} | } | ||||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | ||||
@@ -7,7 +7,7 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
from ...core import Vocabulary | |||||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | ||||
""" | """ | ||||
@@ -15,7 +15,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) | |||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | ||||
:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
:param dict,Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | ||||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | ||||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | ||||
@@ -23,6 +23,8 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) | |||||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | ||||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | ||||
""" | """ | ||||
if isinstance(id2target, Vocabulary): | |||||
id2target = id2target.idx2word | |||||
num_tags = len(id2target) | num_tags = len(id2target) | ||||
start_idx = num_tags | start_idx = num_tags | ||||
end_idx = num_tags + 1 | end_idx = num_tags + 1 | ||||