Browse Source

1.删除Trainer中的prefetch参数; 2.增加中文分词的下载; 3.增加DataBundle的delete_dataset, delete_vocab

tags/v0.4.10
yh 5 years ago
parent
commit
7a0903d9ba
5 changed files with 32 additions and 4 deletions
  1. +2
    -1
      fastNLP/core/metrics.py
  2. +1
    -0
      fastNLP/io/__init__.py
  3. +19
    -0
      fastNLP/io/data_bundle.py
  4. +6
    -1
      fastNLP/io/file_utils.py
  5. +4
    -2
      fastNLP/modules/decoder/crf.py

+ 2
- 1
fastNLP/core/metrics.py View File

@@ -142,9 +142,10 @@ class MetricBase(object):
设置metric的名称,默认是Metric的class name. 设置metric的名称,默认是Metric的class name.


:param str name: :param str name:
:return:
:return: self
""" """
self._metric_name = name self._metric_name = name
return self


def get_metric_name(self): def get_metric_name(self):
""" """


+ 1
- 0
fastNLP/io/__init__.py View File

@@ -50,6 +50,7 @@ __all__ = [
"SSTPipe", "SSTPipe",
"SST2Pipe", "SST2Pipe",
"IMDBPipe", "IMDBPipe",
"Conll2003Pipe",


"Conll2003NERPipe", "Conll2003NERPipe",
"OntoNotesNERPipe", "OntoNotesNERPipe",


+ 19
- 0
fastNLP/io/data_bundle.py View File

@@ -158,6 +158,16 @@ class DataBundle:
""" """
return self.datasets[name] return self.datasets[name]


def delete_dataset(self, name:str):
    """
    Remove the DataSet registered under ``name`` from this bundle.

    Nothing happens if no DataSet is stored under that name.

    :param str name: key of the DataSet to remove
    :return: self, so calls can be chained
    """
    registry = self.datasets
    if name in registry:
        del registry[name]
    return self

def get_vocab(self, field_name:str)->Vocabulary: def get_vocab(self, field_name:str)->Vocabulary:
""" """
获取field名为field_name对应的vocab 获取field名为field_name对应的vocab
@@ -167,6 +177,15 @@ class DataBundle:
""" """
return self.vocabs[field_name] return self.vocabs[field_name]


def delete_vocab(self, field_name:str):
    """
    Remove the vocabulary registered under ``field_name`` from this bundle.

    Nothing happens if no vocabulary is stored under that field name.

    :param str field_name: key of the vocabulary to remove
    :return: self, so calls can be chained
    """
    registry = self.vocabs
    if field_name in registry:
        del registry[field_name]
    return self

def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
""" """
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作::


+ 6
- 1
fastNLP/io/file_utils.py View File

@@ -75,7 +75,12 @@ DATASET_DIR = {
"rte": "RTE.zip", "rte": "RTE.zip",
"msra-ner": "MSRA_NER.zip", "msra-ner": "MSRA_NER.zip",
"peopledaily": "peopledaily.zip", "peopledaily": "peopledaily.zip",
"weibo-ner": "weibo_NER.zip"
"weibo-ner": "weibo_NER.zip",

"cws-pku": 'cws_pku.zip',
"cws-cityu": "cws_cityu.zip",
"cws-as": 'cws_as.zip',
"cws-msra": 'cws_msra.zip'
} }


PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,


+ 4
- 2
fastNLP/modules/decoder/crf.py View File

@@ -7,7 +7,7 @@ import torch
from torch import nn from torch import nn


from ..utils import initial_parameter from ..utils import initial_parameter
from ...core import Vocabulary


def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
""" """
@@ -15,7 +15,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)


给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。


:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
:param dict,Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头;
@@ -23,6 +23,8 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。
""" """
if isinstance(id2target, Vocabulary):
id2target = id2target.idx2word
num_tags = len(id2target) num_tags = len(id2target)
start_idx = num_tags start_idx = num_tags
end_idx = num_tags + 1 end_idx = num_tags + 1


Loading…
Cancel
Save