
Restore the ability to download by shorthand model identifier in transformers

tags/v1.0.0alpha
x54-729 3 years ago
commit 368211bfa7
2 changed files with 160 additions and 4 deletions:
  1. fastNLP/transformers/torch/file_utils.py (+73, -1)
  2. fastNLP/transformers/torch/tokenization_utils_base.py (+87, -3)
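In practice this commit lets `from_pretrained` resolve a hub shorthand again instead of only a local directory or file. A minimal usage sketch; the import path is hypothetical (fastNLP vendors transformers under fastNLP.transformers.torch, so adjust to the actual layout), and "bert-base-uncased" is only an example id:

# A minimal sketch, assuming the vendored port exposes BertTokenizer at this
# path (hypothetical) and that network access is available.
from fastNLP.transformers.torch.models.bert import BertTokenizer

# Before this commit the shorthand raised a RuntimeError; now it resolves the
# repo's files via get_list_of_files / hf_bucket_url and downloads them.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.tokenize("Hello world"))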

fastNLP/transformers/torch/file_utils.py (+73, -1)

@@ -17,7 +17,7 @@ from enum import Enum
 from functools import partial
 from hashlib import sha256
 from pathlib import Path
-from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, List
 from urllib.parse import urlparse
 from uuid import uuid4
 from zipfile import ZipFile, is_zipfile
@@ -750,6 +750,78 @@ def get_from_cache(
 
     return cache_path
 
+def get_list_of_files(
+    path_or_repo: Union[str, os.PathLike],
+    revision: Optional[str] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+) -> List[str]:
+    """
+    Gets the list of files inside :obj:`path_or_repo`.
+
+    Args:
+        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
+            Can be either the id of a repo on huggingface.co or a path to a `directory`.
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        use_auth_token (:obj:`str` or `bool`, `optional`):
+            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to only rely on local files and not to attempt to download any files.
+
+    Returns:
+        :obj:`List[str]`: The list of files available in :obj:`path_or_repo`.
+    """
+    path_or_repo = str(path_or_repo)
+    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
+    if os.path.isdir(path_or_repo):
+        list_of_files = []
+        for path, dir_names, file_names in os.walk(path_or_repo):
+            list_of_files.extend([os.path.join(path, f) for f in file_names])
+        return list_of_files
+
+    # Can't grab the files if we are in offline mode.
+    if is_offline_mode() or local_files_only:
+        return []
+
+    # Otherwise we grab the token and query the model-info endpoint directly.
+    if isinstance(use_auth_token, str):
+        token = use_auth_token
+    elif use_auth_token is True:
+        # token = HfFolder.get_token()
+        path_token = os.path.expanduser("~/.huggingface/token")
+        try:
+            with open(path_token, "r") as f:
+                token = f.read()
+        except FileNotFoundError:
+            token = None
+    else:
+        token = None
+    # model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
+    #     path_or_repo, revision=revision, token=token
+    # )
+    path = (
+        f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}"
+        if revision is None
+        else f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}/revision/{revision}"
+    )
+    headers = {"authorization": f"Bearer {token}"} if token is not None else None
+    r = requests.get(path, headers=headers, timeout=None)
+    r.raise_for_status()
+    siblings = r.json().get("siblings", None)
+    # An empty list (rather than None) keeps the declared List[str] return type.
+    return [x["rfilename"] for x in siblings] if siblings is not None else []
 
 def is_torch_fx_available():
     return _TORCH_GREATER_EQUAL_1_8 and _compare_version("torch", operator.lt, "1.9.0")
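A quick sketch of how the new helper behaves in each branch; the local path and repo id below are illustrative only:

from fastNLP.transformers.torch.file_utils import get_list_of_files

# Local directory: os.walk collects every file, subdirectories included.
print(get_list_of_files("./my_local_model"))

# Hub shorthand: GET {endpoint}/api/models/<repo> and read each sibling's
# "rfilename", e.g. ["config.json", "vocab.txt", "tokenizer.json", ...].
print(get_list_of_files("bert-base-uncased", revision="main"))

# Offline mode or local_files_only=True: no network call, empty result.
print(get_list_of_files("bert-base-uncased", local_files_only=True))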



fastNLP/transformers/torch/tokenization_utils_base.py (+87, -3)

@@ -44,6 +44,8 @@ from .file_utils import (
     cached_path,
     is_offline_mode,
     is_remote_url,
+    get_list_of_files,
+    hf_bucket_url,
     is_tokenizers_available,
     to_py_obj,
 )
@@ -100,7 +102,7 @@ TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
 
 # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
 FULL_TOKENIZER_FILE = "tokenizer.json"
+_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json")
 
 class TruncationStrategy(ExplicitEnum):
"""
@@ -1607,8 +1609,41 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             file_id = list(cls.vocab_files_names.keys())[0]
             vocab_files[file_id] = pretrained_model_name_or_path
         else:
-            raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ",
-                               "which is not supported in fastNLP now.")
+            # raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ",
+            #                    "which is not supported in fastNLP now.")
+            # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+            fast_tokenizer_file = get_fast_tokenizer_file(
+                pretrained_model_name_or_path,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                local_files_only=local_files_only,
+            )
+            additional_files_names = {
+                "added_tokens_file": ADDED_TOKENS_FILE,
+                "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
+                "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+                "tokenizer_file": fast_tokenizer_file,
+            }
+            # Look for the tokenizer files
+            for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
+                if os.path.isdir(pretrained_model_name_or_path):
+                    if subfolder is not None:
+                        full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
+                    else:
+                        full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
+                    if not os.path.exists(full_file_name):
+                        logger.info(f"Didn't find file {full_file_name}. We won't load it.")
+                        full_file_name = None
+                else:
+                    full_file_name = hf_bucket_url(
+                        pretrained_model_name_or_path,
+                        filename=file_name,
+                        subfolder=subfolder,
+                        revision=revision,
+                        mirror=None,
+                    )
+
+                vocab_files[file_id] = full_file_name
 
         # Get files from url, cache, or disk depending on the case
         resolved_vocab_files = {}
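In the non-directory branch above, hf_bucket_url turns each candidate file name into a resolvable URL that cached_path later downloads. A small sketch, assuming the default resolve endpoint (the model id is illustrative):

from fastNLP.transformers.torch.file_utils import hf_bucket_url

url = hf_bucket_url("bert-base-uncased", filename="tokenizer_config.json", revision=None, mirror=None)
# revision=None falls back to "main":
# https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json
print(url)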
@@ -3349,3 +3384,52 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`.
         )
         model_inputs["labels"] = labels["input_ids"]
         return model_inputs
+
+def get_fast_tokenizer_file(
+    path_or_repo: Union[str, os.PathLike],
+    revision: Optional[str] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    local_files_only: bool = False,
+) -> str:
+    """
+    Get the tokenizer file to use for this version of transformers.
+
+    Args:
+        path_or_repo (:obj:`str` or :obj:`os.PathLike`):
+            Can be either the id of a repo on huggingface.co or a path to a `directory`.
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        use_auth_token (:obj:`str` or `bool`, `optional`):
+            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to only rely on local files and not to attempt to download any files.
+
+    Returns:
+        :obj:`str`: The tokenizer file to use.
+    """
+    # Inspect all files from the repo/folder.
+    all_files = get_list_of_files(
+        path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
+    )
+    tokenizer_files_map = {}
+    for file_name in all_files:
+        search = _re_tokenizer_file.search(file_name)
+        if search is not None:
+            v = search.groups()[0]
+            tokenizer_files_map[v] = file_name
+    available_versions = sorted(tokenizer_files_map.keys())
+
+    # Default to FULL_TOKENIZER_FILE, then look for any versioned file that is
+    # not newer than this transformers version.
+    tokenizer_file = FULL_TOKENIZER_FILE
+    transformers_version = version.parse(__version__)
+    for v in available_versions:
+        if version.parse(v) <= transformers_version:
+            tokenizer_file = tokenizer_files_map[v]
+        else:
+            # No point going further since the versions are sorted.
+            break
+
+    return tokenizer_file
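A worked example of the version selection (pure logic, no network; the file names and the pinned version are hypothetical):

import re
from packaging import version

_re = re.compile(r"tokenizer\.(.*)\.json")
all_files = ["vocab.txt", "tokenizer.json", "tokenizer.4.0.0.json", "tokenizer.9.0.0.json"]

files_map = {}
for f in all_files:
    m = _re.search(f)
    if m is not None:
        files_map[m.groups()[0]] = f   # {"4.0.0": ..., "9.0.0": ...}

chosen = "tokenizer.json"  # default FULL_TOKENIZER_FILE
for v in sorted(files_map):
    if version.parse(v) <= version.parse("4.5.0"):  # pretend __version__ == "4.5.0"
        chosen = files_map[v]
    else:
        break  # versions are sorted, nothing applicable can follow

print(chosen)  # tokenizer.4.0.0.json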
