Browse Source

update io/file_utils.py

tags/v0.4.10
xuyige 6 years ago
parent
commit
aaabcd6bab
1 changed files with 10 additions and 10 deletions
  1. +10
    -10
      fastNLP/io/file_utils.py

+ 10
- 10
fastNLP/io/file_utils.py View File

@@ -21,8 +21,8 @@ PRETRAINED_BERT_MODEL_DIR = {

'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',

'en-base-multi-cased': 'bert-base-multilingual-cased.zip',
'en-base-multi-uncased': 'bert-base-multilingual-uncased.zip',
'multi-base-cased': 'bert-base-multilingual-cased.zip',
'multi-base-uncased': 'bert-base-multilingual-uncased.zip',

'cn': 'bert-chinese-wwm.zip',
'cn-base': 'bert-base-chinese.zip',
@@ -38,7 +38,7 @@ PRETRAINED_ELMO_MODEL_DIR = {
}

PRETRAIN_STATIC_FILES = {
'en': 'glove.840B.300d.tar.gz',
'en': 'glove.840B.300d.zip',

'en-glove-6b-50d': 'glove.6B.50d.zip',
'en-glove-6b-100d': 'glove.6B.100d.zip',
@@ -184,26 +184,26 @@ def _get_base_url(name):
return URLS[name.lower()]


def _get_embedding_url(type, name):
def _get_embedding_url(embed_type, name):
"""
给定embedding类似和名称,返回下载url

:param str type: 支持static, bert, elmo。即embedding的类型
:param str embed_type: 支持static, bert, elmo。即embedding的类型
:param str name: embedding的名称, 例如en, cn, based等
:return: str, 下载的url地址
"""
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
"bert": PRETRAINED_BERT_MODEL_DIR,
"static": PRETRAIN_STATIC_FILES}
map = PRETRAIN_MAP.get(type, None)
if map:
filename = map.get(name, None)
embed_map = PRETRAIN_MAP.get(embed_type, None)
if embed_map:
filename = embed_map.get(name, None)
if filename:
url = _get_base_url('embedding') + filename
return url
raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
else:
raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")


def _get_dataset_url(name):


Loading…
Cancel
Save