diff --git a/fastNLP/transformers/torch/file_utils.py b/fastNLP/transformers/torch/file_utils.py index 2b606b33..60f95fdd 100644 --- a/fastNLP/transformers/torch/file_utils.py +++ b/fastNLP/transformers/torch/file_utils.py @@ -17,7 +17,7 @@ from enum import Enum from functools import partial from hashlib import sha256 from pathlib import Path -from typing import Any, BinaryIO, Dict, Optional, Tuple, Union +from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, List from urllib.parse import urlparse from uuid import uuid4 from zipfile import ZipFile, is_zipfile @@ -750,6 +750,78 @@ def get_from_cache( return cache_path +def get_list_of_files( + path_or_repo: Union[str, os.PathLike], + revision: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, +) -> List[str]: + """ + Gets the list of files inside :obj:`path_or_repo`. + + Args: + path_or_repo (:obj:`str` or :obj:`os.PathLike`): + Can be either the id of a repo on huggingface.co or a path to a `directory`. + revision (:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. + + Returns: + :obj:`List[str]`: The list of files available in :obj:`path_or_repo`. + """ + path_or_repo = str(path_or_repo) + # If path_or_repo is a folder, we just return what is inside (subdirectories included). + if os.path.isdir(path_or_repo): + list_of_files = [] + for path, dir_names, file_names in os.walk(path_or_repo): + list_of_files.extend([os.path.join(path, f) for f in file_names]) + return list_of_files + + # Can't grab the files if we are on offline mode. + if is_offline_mode() or local_files_only: + return [] + + # Otherwise we grab the token and use the model_info method. + if isinstance(use_auth_token, str): + token = use_auth_token + elif use_auth_token is True: + # token = HfFolder.get_token() + path_token = os.path.expanduser("~/.huggingface/token") + try: + with open(path_token, "r") as f: + token = f.read() + except FileNotFoundError: + token = None + else: + token = None + # model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( + # path_or_repo, revision=revision, token=token + # ) + endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT + path = ( + f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}" + if revision is None + else f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}/revision/{revision}" + ) + headers = {"authorization": f"Bearer {token}"} if token is not None else None + status_query_param = None + r = requests.get( + path, headers=headers, timeout=None, params=status_query_param + ) + r.raise_for_status() + d = r.json() + siblings = d.get("siblings", None) + rfilenames = ( + [x["rfilename"] for x in siblings] if siblings is not None else None + ) + return rfilenames + def is_torch_fx_available(): return _TORCH_GREATER_EQUAL_1_8 and _compare_version("torch", operator.lt, "1.9.0") diff --git a/fastNLP/transformers/torch/tokenization_utils_base.py b/fastNLP/transformers/torch/tokenization_utils_base.py index aebf4bb6..ad62cd6e 100644 --- a/fastNLP/transformers/torch/tokenization_utils_base.py +++ b/fastNLP/transformers/torch/tokenization_utils_base.py @@ -44,6 +44,8 @@ from .file_utils import ( cached_path, is_offline_mode, is_remote_url, + get_list_of_files, + hf_bucket_url, is_tokenizers_available, to_py_obj, ) @@ -100,7 +102,7 @@ TOKENIZER_CONFIG_FILE = "tokenizer_config.json" # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file FULL_TOKENIZER_FILE = "tokenizer.json" - +_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json") class TruncationStrategy(ExplicitEnum): """ @@ -1607,8 +1609,41 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): file_id = list(cls.vocab_files_names.keys())[0] vocab_files[file_id] = pretrained_model_name_or_path else: - raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ", - "which is not supported in fastNLP now.") + # raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ", + # "which is not supported in fastNLP now.") + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + fast_tokenizer_file = get_fast_tokenizer_file( + pretrained_model_name_or_path, + revision=revision, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + ) + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + "tokenizer_file": fast_tokenizer_file, + } + # Look for the tokenizer files + for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items(): + if os.path.isdir(pretrained_model_name_or_path): + if subfolder is not None: + full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name) + else: + full_file_name = os.path.join(pretrained_model_name_or_path, file_name) + if not os.path.exists(full_file_name): + logger.info(f"Didn't find file {full_file_name}. We won't load it.") + full_file_name = None + else: + full_file_name = hf_bucket_url( + pretrained_model_name_or_path, + filename=file_name, + subfolder=subfolder, + revision=revision, + mirror=None, + ) + + vocab_files[file_id] = full_file_name # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} @@ -3349,3 +3384,52 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`. ) model_inputs["labels"] = labels["input_ids"] return model_inputs + +def get_fast_tokenizer_file( + path_or_repo: Union[str, os.PathLike], + revision: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, +) -> str: + """ + Get the tokenizer file to use for this version of transformers. + + Args: + path_or_repo (:obj:`str` or :obj:`os.PathLike`): + Can be either the id of a repo on huggingface.co or a path to a `directory`. + revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. + + Returns: + :obj:`str`: The tokenizer file to use. + """ + # Inspect all files from the repo/folder. + all_files = get_list_of_files( + path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only + ) + tokenizer_files_map = {} + for file_name in all_files: + search = _re_tokenizer_file.search(file_name) + if search is not None: + v = search.groups()[0] + tokenizer_files_map[v] = file_name + available_versions = sorted(tokenizer_files_map.keys()) + + # Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions. + tokenizer_file = FULL_TOKENIZER_FILE + transformers_version = version.parse(__version__) + for v in available_versions: + if version.parse(v) <= transformers_version: + tokenizer_file = tokenizer_files_map[v] + else: + # No point going further since the versions are sorted. + break + + return tokenizer_file \ No newline at end of file