diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index 43f8be62..14766fba 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -21,8 +21,8 @@ PRETRAINED_BERT_MODEL_DIR = {
 
     'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
 
-    'en-base-multi-cased': 'bert-base-multilingual-cased.zip',
-    'en-base-multi-uncased': 'bert-base-multilingual-uncased.zip',
+    'multi-base-cased': 'bert-base-multilingual-cased.zip',
+    'multi-base-uncased': 'bert-base-multilingual-uncased.zip',
 
     'cn': 'bert-chinese-wwm.zip',
     'cn-base': 'bert-base-chinese.zip',
@@ -38,7 +38,7 @@ PRETRAINED_ELMO_MODEL_DIR = {
 }
 
 PRETRAIN_STATIC_FILES = {
-    'en': 'glove.840B.300d.tar.gz',
+    'en': 'glove.840B.300d.zip',
 
     'en-glove-6b-50d': 'glove.6B.50d.zip',
     'en-glove-6b-100d': 'glove.6B.100d.zip',
@@ -184,26 +184,26 @@ def _get_base_url(name):
     return URLS[name.lower()]
 
 
-def _get_embedding_url(type, name):
+def _get_embedding_url(embed_type, name):
     """
     给定embedding类似和名称,返回下载url
 
-    :param str type: 支持static, bert, elmo。即embedding的类型
+    :param str embed_type: 支持static, bert, elmo。即embedding的类型
     :param str name: embedding的名称, 例如en, cn, based等
     :return: str, 下载的url地址
     """
     PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
                     "bert": PRETRAINED_BERT_MODEL_DIR,
                     "static": PRETRAIN_STATIC_FILES}
-    map = PRETRAIN_MAP.get(type, None)
-    if map:
-        filename = map.get(name, None)
+    embed_map = PRETRAIN_MAP.get(embed_type, None)
+    if embed_map:
+        filename = embed_map.get(name, None)
         if filename:
             url = _get_base_url('embedding') + filename
             return url
-        raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
+        raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
     else:
-        raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
+        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
 
 
 def _get_dataset_url(name):