Browse Source

解决读取bert权重潜在的bug问题

tags/v0.4.10
yh 6 years ago
parent
commit
db8c6a0b8a
2 changed files with 4 additions and 4 deletions
  1. +1
    -1
      fastNLP/embeddings/bert_embedding.py
  2. +3
    -3
      fastNLP/modules/utils.py

+ 1
- 1
fastNLP/embeddings/bert_embedding.py View File

@@ -63,7 +63,7 @@ class BertEmbedding(ContextualEmbedding):
model_dir = cached_path(model_url)
# 检查是否存在
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_dir = model_dir_or_name
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name))
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")



+ 3
- 3
fastNLP/modules/utils.py View File

@@ -128,9 +128,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix):
:param postfix: 形如".bin", ".json"等
:return: str,文件的路径
"""
files = glob.glob(os.path.join(dir_path, '*' + postfix))
files = list(filter(lambda filename:filename.endswith(postfix), os.listdir(os.path.join(dir_path))))
if len(files) == 0:
raise FileNotFoundError(f"There is no file endswith *.{postfix} file in {dir_path}")
raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}")
elif len(files) > 1:
raise FileExistsError(f"There are multiple *.{postfix} files in {dir_path}")
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}")
return os.path.join(dir_path, files[0])

Loading…
Cancel
Save