Browse Source

从远程下载权重

tags/v0.4.10
yh 6 years ago
parent
commit
71c9e0c30e
3 changed files with 38 additions and 66 deletions
  1. +2
    -2
      fastNLP/embeddings/elmo_embedding.py
  2. +35
    -63
      fastNLP/io/file_utils.py
  3. +1
    -1
      fastNLP/modules/encoder/bert.py

+ 2
- 2
fastNLP/embeddings/elmo_embedding.py View File

@@ -182,8 +182,8 @@ class _ElmoModel(nn.Module):
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.")
elif config_count == 0 or weight_count == 0:
raise Exception(f"No config file or weight file found in {model_dir}")
config = json.load(open(os.path.join(model_dir, config_file), 'r'))
with open(os.path.join(model_dir, config_file), 'r') as config_f:
config = json.load(config_f)
self.weight_file = os.path.join(model_dir, weight_file)
self.config = config



+ 35
- 63
fastNLP/io/file_utils.py View File

@@ -11,7 +11,7 @@ import hashlib


PRETRAINED_BERT_MODEL_DIR = {
'en': 'bert-base-cased-f89bfe08.zip',
'en': 'bert-large-cased-wwm.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
@@ -24,14 +24,14 @@ PRETRAINED_BERT_MODEL_DIR = {
'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual': 'bert-base-multilingual-cased.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased.zip',
}

PRETRAINED_ELMO_MODEL_DIR = {
'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn-5e9b34e2.tar.gz'
'en-small': "elmo_en_Small.zip"
}

PRETRAIN_STATIC_FILES = {
@@ -39,7 +39,7 @@ PRETRAIN_STATIC_FILES = {
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
'en-fasttext': "cc.en.300.vec-d53187b2.gz",
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
'cn': "tencent_cn-dab24577.tar.gz",
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
}
@@ -47,11 +47,15 @@ PRETRAIN_STATIC_FILES = {

def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
"""
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
将文件放入到cache_dir中
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
将文件放入到cache_dir中.

:param url_or_filename: 文件的下载url或者文件路径
:param cache_dir: 文件的缓存文件夹
:return:
"""
if cache_dir is None:
dataset_cache = Path(get_defalt_path())
dataset_cache = Path(get_default_cache_path())
else:
dataset_cache = cache_dir

@@ -75,7 +79,7 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:

def get_filepath(filepath):
"""
如果filepath中只有一个文件,则直接返回对应的全路径
如果filepath中只有一个文件,则直接返回对应的全路径.
:param filepath:
:return:
"""
@@ -88,7 +92,7 @@ def get_filepath(filepath):
return filepath


def get_defalt_path():
def get_default_cache_path():
"""
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。

@@ -96,11 +100,10 @@ def get_defalt_path():
"""
if 'FASTNLP_CACHE_DIR' in os.environ:
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR')
if os.path.exists(fastnlp_cache_dir):
if os.path.isdir(fastnlp_cache_dir):
return fastnlp_cache_dir
raise RuntimeError("Some errors happen on the cache directory.")
else:
raise RuntimeError("This function is not available right now.")
else:
raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.")
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP"))
return fastnlp_cache_dir

@@ -109,13 +112,19 @@ def _get_base_url(name):
# 返回的URL结尾必须是/
if 'FASTNLP_BASE_URL' in os.environ:
fastnlp_base_url = os.environ['FASTNLP_BASE_URL']
return fastnlp_base_url
raise RuntimeError("This function is not available right now.")
if fastnlp_base_url.endswith('/'):
return fastnlp_base_url
else:
return fastnlp_base_url + '/'
else:
# TODO 替换
dbbrain_url = "http://dbcloud.irocn.cn:8989/api/public/dl/"
return dbbrain_url


def split_filename_suffix(filepath):
"""
给定filepath返回对应的name和suffix
给定filepath返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型
:param filepath:
:return: filename, suffix
"""
@@ -135,13 +144,6 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:

filename = re.sub(r".+/", "", url)
dir_name, suffix = split_filename_suffix(filename)
sep_index = dir_name[::-1].index('-')
if sep_index<0:
check_sum = None
else:
check_sum = dir_name[-sep_index+1:]
sep_index = len(dir_name) if sep_index==-1 else -sep_index-1
dir_name = dir_name[:sep_index]

# 寻找与它名字匹配的内容, 而不关心后缀
match_dir_name = match_file(dir_name, cache_dir)
@@ -154,11 +156,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
return get_filepath(cache_path)

# make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上
response = requests.head(url, headers={"User-Agent": "fastNLP"})
if response.status_code != 200:
raise IOError(
f"HEAD request failed for url {url} with status code {response.status_code}."
)
# response = requests.head(url, headers={"User-Agent": "fastNLP"})
# if response.status_code != 200:
# raise IOError(
# f"HEAD request failed for url {url} with status code {response.status_code}."
# )

# add ETag to filename if it exists
# etag = response.headers.get("ETag")
@@ -174,17 +176,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
sha256 = hashlib.sha256()
with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
sha256.update(chunk)
# check sum
digit = sha256.hexdigest()[:8]
if not check_sum:
assert digit == check_sum, "File corrupted when download."
progress.close()
print(f"Finish download from {url}.")

@@ -193,7 +189,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
if suffix in ('.zip', '.tar.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
delete_temp_dir = uncompress_temp_dir
print(f"Start to uncompress file to {uncompress_temp_dir}.")
print(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
else:
@@ -211,7 +207,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
success = False
try:
# 复制到指定的位置
print(f"Copy file to {cache_path}.")
print(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
@@ -252,7 +248,7 @@ def untar_gz_file(file:Path, to:Path):
tar.extractall(to)


def match_file(dir_name: str, cache_dir: str) -> str:
def match_file(dir_name: str, cache_dir: Path) -> str:
"""
匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串
@@ -273,27 +269,3 @@ def match_file(dir_name: str, cache_dir: str) -> str:
else:
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")


if __name__ == '__main__':
cache_dir = Path('caches')
cache_dir = None
# 需要对cache_dir进行测试
base_url = 'http://0.0.0.0:8888/file/download'
# if True:
# for filename in os.listdir(cache_dir):
# if os.path.isdir(os.path.join(cache_dir, filename)):
# shutil.rmtree(os.path.join(cache_dir, filename))
# else:
# os.remove(os.path.join(cache_dir, filename))
# 1. 测试.txt文件
print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir))
# 2. 测试.zip文件(只有一个文件)
print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir))
# 3. 测试.zip文件(有多个文件)
print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir))
# 4. 测试.tar.gz文件
print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir))
# 5. 测试.tar.gz多个文件
print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir))

# 6. 测试.pkl文件

+ 1
- 1
fastNLP/modules/encoder/bert.py View File

@@ -563,7 +563,7 @@ class WordpieceTokenizer(object):
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
if len(output_tokens)==0:
if len(output_tokens)==0: #防止里面全是空格或者回车符号
return [self.unk_token]
return output_tokens



Loading…
Cancel
Save