From 1a1b1e8c9eeda95d4d88fd655fd064cc384da0ed Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sat, 21 May 2022 07:50:53 +0000
Subject: [PATCH] Migrate transformers' modeling_auto
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/transformers/torch/file_utils.py      |  17 +-
 .../torch/models/auto/__init__.py             |  83 +++
 .../torch/models/auto/auto_factory.py         | 562 +++++++++++++++
 .../transformers/torch/models/auto/dynamic.py | 208 ++++++
 .../torch/models/auto/modeling_auto.py        | 663 ++++++++++++++++++
 5 files changed, 1530 insertions(+), 3 deletions(-)
 create mode 100644 fastNLP/transformers/torch/models/auto/__init__.py
 create mode 100644 fastNLP/transformers/torch/models/auto/auto_factory.py
 create mode 100644 fastNLP/transformers/torch/models/auto/dynamic.py
 create mode 100644 fastNLP/transformers/torch/models/auto/modeling_auto.py

diff --git a/fastNLP/transformers/torch/file_utils.py b/fastNLP/transformers/torch/file_utils.py
index 4c7ee7a4..f0bfce26 100644
--- a/fastNLP/transformers/torch/file_utils.py
+++ b/fastNLP/transformers/torch/file_utils.py
@@ -10,6 +10,8 @@ import sys
 import tarfile
 import tempfile
 import operator
+import types
+import functools
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
@@ -37,6 +39,8 @@ if _NEED_IMPORT_TORCH:
     import torch
     _torch_version = importlib_metadata.version("torch")
 
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+
 hf_cache_home = os.path.expanduser(
     os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
 )
@@ -45,10 +49,9 @@ default_cache_path = os.path.join(hf_cache_home, "transformers")
 PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
 PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
 TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
 SESSION_ID = uuid4().hex
-
-ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-
 DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
 
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -1043,3 +1046,11 @@ class TensorType(ExplicitEnum):
     PYTORCH = "pt"
     NUMPY = "np"
 
+
+def copy_func(f):
+    """Returns a copy of a function f."""
+    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__
+    return g
\ No newline at end of file
diff --git a/fastNLP/transformers/torch/models/auto/__init__.py b/fastNLP/transformers/torch/models/auto/__init__.py
new file mode 100644
index 00000000..ac2967d2
--- /dev/null
+++ b/fastNLP/transformers/torch/models/auto/__init__.py
@@ -0,0 +1,83 @@
+__all__ = [
+    "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
+    "CONFIG_MAPPING",
+    "MODEL_NAMES_MAPPING",
+    "AutoConfig",
+    "TOKENIZER_MAPPING",
+    "get_values",
+    "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+    "MODEL_FOR_CAUSAL_LM_MAPPING",
+    "MODEL_FOR_CTC_MAPPING",
+    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+    "MODEL_FOR_MASKED_LM_MAPPING",
+    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+    "MODEL_FOR_OBJECT_DETECTION_MAPPING",
"MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "MODEL_MAPPING", + "MODEL_WITH_LM_HEAD_MAPPING", + "AutoModel", + "AutoModelForAudioClassification", + "AutoModelForCausalLM", + "AutoModelForCTC", + "AutoModelForImageClassification", + "AutoModelForMaskedLM", + "AutoModelForMultipleChoice", + "AutoModelForNextSentencePrediction", + "AutoModelForObjectDetection", + "AutoModelForPreTraining", + "AutoModelForQuestionAnswering", + "AutoModelForSeq2SeqLM", + "AutoModelForSequenceClassification", + "AutoModelForSpeechSeq2Seq", + "AutoModelForTableQuestionAnswering", + "AutoModelForTokenClassification", + "AutoModelWithLMHead", +] + +from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ + AutoConfig +from .tokenization_auto import TOKENIZER_MAPPING +from .auto_factory import get_values +from .modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_CTC_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + MODEL_FOR_OBJECT_DETECTION_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, + AutoModel, + AutoModelForAudioClassification, + AutoModelForCausalLM, + AutoModelForCTC, + AutoModelForImageClassification, + AutoModelForMaskedLM, + AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, + AutoModelForObjectDetection, + AutoModelForPreTraining, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForSpeechSeq2Seq, + AutoModelForTableQuestionAnswering, + AutoModelForTokenClassification, + AutoModelWithLMHead, +) \ No newline at end of file diff --git a/fastNLP/transformers/torch/models/auto/auto_factory.py b/fastNLP/transformers/torch/models/auto/auto_factory.py new file mode 100644 index 00000000..015f5642 --- /dev/null +++ b/fastNLP/transformers/torch/models/auto/auto_factory.py @@ -0,0 +1,562 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Factory function to build auto-model classes.""" +import importlib +from collections import OrderedDict + +from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings +from .dynamic import get_class_from_dynamic_module +from fastNLP.transformers.torch.configuration_utils import PretrainedConfig +from fastNLP.transformers.torch.file_utils import copy_func +from fastNLP.core.log import logger + + +CLASS_DOCSTRING = """ + This is a generic model class that will be instantiated as one of the model classes of the library when created + with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the + :meth:`~transformers.BaseAutoModelClass.from_config` class method. + + This class cannot be instantiated directly using ``__init__()`` (throws an error). +""" + +FROM_CONFIG_DOCSTRING = """ + Instantiates one of the model classes of the library from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. It only affects the + model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model + weights. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + List options + + Examples:: + + >>> from transformers import AutoConfig, BaseAutoModelClass + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained('checkpoint_placeholder') + >>> model = BaseAutoModelClass.from_config(config) +""" + +FROM_PRETRAINED_TORCH_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either + passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, + by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: + + List options + + The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are + deactivated). To train the model, you should first set it back in training mode with ``model.train()`` + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In + this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in + a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args (additional positional arguments, `optional`): + Will be passed along to the underlying model ``__init__()`` method. + config (:class:`~transformers.PretrainedConfig`, `optional`): + Configuration for the model to use instead of an automatically loaded configuration. 
Configuration can
+            be automatically loaded when:
+
+                - The model is a model provided by the library (loaded with the `model id` string of a pretrained
+                  model).
+                - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
+                  by supplying the save directory.
+                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+                  configuration JSON file named `config.json` is found in the directory.
+        state_dict (`Dict[str, torch.Tensor]`, `optional`):
+            A state dictionary to use instead of a state dictionary loaded from saved weights file.
+
+            This option can be used if you want to create a model from a pretrained configuration but load your own
+            weights. In this case though, you should check if using
+            :func:`~transformers.PreTrainedModel.save_pretrained` and
+            :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
+        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the
+            standard cache should not be used.
+        from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Load the model weights from a TensorFlow checkpoint save file (see docstring of
+            ``pretrained_model_name_or_path`` argument).
+        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+            cached versions if they exist.
+        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+            file exists.
+        proxies (:obj:`Dict[str, str]`, `optional`):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to only look at local files (e.g., not try downloading the model).
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+            should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+            will execute code present on the Hub on your local machine.
+        kwargs (additional keyword arguments, `optional`):
+            Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+            :obj:`output_attentions=True`).
Behaves differently depending on whether a ``config`` is provided or + automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the + underlying model's ``__init__`` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class + initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of + ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute + with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration + attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) + >>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json') + >>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config) +""" + +FROM_PRETRAINED_TF_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either + passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, + by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In + this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the PyTorch model in a + TensorFlow model using the provided conversion scripts and loading the TensorFlow model + afterwards. + model_args (additional positional arguments, `optional`): + Will be passed along to the underlying model ``__init__()`` method. + config (:class:`~transformers.PretrainedConfig`, `optional`): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the `model id` string of a pretrained + model). + - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded + by supplying the save directory. 
+                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+                  configuration JSON file named `config.json` is found in the directory.
+        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the
+            standard cache should not be used.
+        from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Load the model weights from a PyTorch checkpoint save file (see docstring of
+            ``pretrained_model_name_or_path`` argument).
+        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+            cached versions if they exist.
+        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+            file exists.
+        proxies (:obj:`Dict[str, str]`, `optional`):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to only look at local files (e.g., not try downloading the model).
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+            should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+            will execute code present on the Hub on your local machine.
+        kwargs (additional keyword arguments, `optional`):
+            Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+            :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+            automatically loaded:
+
+                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+                  already been done)
+                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
+                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+                  attribute will be passed to the underlying model's ``__init__`` function.
+
+    Examples::
+
+        >>> from transformers import AutoConfig, BaseAutoModelClass
+
+        >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) + >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json') + >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config) +""" + +FROM_PRETRAINED_FLAX_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either + passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, + by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In + this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the PyTorch model in a + TensorFlow model using the provided conversion scripts and loading the TensorFlow model + afterwards. + model_args (additional positional arguments, `optional`): + Will be passed along to the underlying model ``__init__()`` method. + config (:class:`~transformers.PretrainedConfig`, `optional`): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the `model id` string of a pretrained + model). + - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded + by supplying the save directory. + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a + configuration JSON file named `config.json` is found in the directory. + cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + ``pretrained_model_name_or_path`` argument). + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. 
+        proxies (:obj:`Dict[str, str]`, `optional`):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to only look at local files (e.g., not try downloading the model).
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+            should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
+            will execute code present on the Hub on your local machine.
+        kwargs (additional keyword arguments, `optional`):
+            Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+            :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+            automatically loaded:
+
+                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+                  already been done)
+                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
+                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+                  attribute will be passed to the underlying model's ``__init__`` function.
+
+    Examples::
+
+        >>> from transformers import AutoConfig, BaseAutoModelClass
+
+        >>> # Download model and configuration from huggingface.co and cache.
+        >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
+
+        >>> # Update configuration during loading
+        >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
+        >>> model.config.output_attentions
+        True
+
+        >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
+        >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
+        >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
+"""
+
+
+def _get_model_class(config, model_mapping):
+    supported_models = model_mapping[type(config)]
+    if not isinstance(supported_models, (list, tuple)):
+        return supported_models
+
+    name_to_model = {model.__name__: model for model in supported_models}
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in name_to_model:
+            return name_to_model[arch]
+        elif f"TF{arch}" in name_to_model:
+            return name_to_model[f"TF{arch}"]
+        elif f"Flax{arch}" in name_to_model:
+            return name_to_model[f"Flax{arch}"]
+
+    # If no architecture is set in the config, or none matches the supported models, the first element of the tuple
+    # is the default.
+    return supported_models[0]
+
+
+class _BaseAutoModelClass:
+    # Base class for auto models.
+    _model_mapping = None
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_config(config)` methods."
+        )
+
+    @classmethod
+    def from_config(cls, config, **kwargs):
+        if type(config) in cls._model_mapping.keys():
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class._from_config(config, **kwargs)
+
+        raise ValueError(
+            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        config = kwargs.pop("config", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
+        kwargs["_from_auto"] = True
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
+        if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
+            if not trust_remote_code:
+                raise ValueError(
+                    f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
+                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
+                    "the option `trust_remote_code=True` to remove this error."
+                )
+            if kwargs.get("revision", None) is None:
+                logger.warning(
+                    "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
+                    "no malicious code has been contributed in a newer revision."
+                )
+            class_ref = config.auto_map[cls.__name__]
+            module_file, class_name = class_ref.split(".")
+            model_class = get_class_from_dynamic_module(
+                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+            )
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        elif type(config) in cls._model_mapping.keys():
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        raise ValueError(
+            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+        )
+
+
+def insert_head_doc(docstring, head_doc=""):
+    if len(head_doc) > 0:
+        return docstring.replace(
+            "one of the model classes of the library ",
+            f"one of the model classes of the library (with a {head_doc} head) ",
+        )
+    return docstring.replace(
+        "one of the model classes of the library ", "one of the base model classes of the library "
+    )
+
+
+def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
+    # Create a new class with the right name from the base class
+    model_mapping = cls._model_mapping
+    name = cls.__name__
+    class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
+    cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
+
+    # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods, otherwise we can't
+    # have specific docstrings for them.
+    from_config = copy_func(_BaseAutoModelClass.from_config)
+    from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
+    from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
+    from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+    from_config.__doc__ = from_config_docstring
+    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
+    cls.from_config = classmethod(from_config)
+
+    if name.startswith("TF"):
+        from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
+    elif name.startswith("Flax"):
+        from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
+    else:
+        from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
+    from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
+    from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
+    from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
+    from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
+    from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
+    from_pretrained.__doc__ = from_pretrained_docstring
+    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
+    cls.from_pretrained = classmethod(from_pretrained)
+    return cls
+
+
+def get_values(model_mapping):
+    result = []
+    for model in model_mapping.values():
+        if isinstance(model, (list, tuple)):
+            result += list(model)
+        else:
+            result.append(model)
+
+    return result
+
+
+def getattribute_from_module(module, attr):
+    if attr is None:
+        return None
+    if isinstance(attr, tuple):
+        return tuple(getattribute_from_module(module, a) for a in attr)
+    if hasattr(module, attr):
+        return getattr(module, attr)
+    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
+    # object at the top level.
+    transformers_module = importlib.import_module("fastNLP.transformers.torch")
+    return getattribute_from_module(transformers_module, attr)
+
+
+class _LazyAutoMapping(OrderedDict):
+    """
+    A mapping from config classes to objects (model or tokenizer, for instance) that loads its keys and values
+    lazily, when they are accessed.
+
+    Args:
+
+        - config_mapping: The map from model type to config class name
+        - model_mapping: The map from model type to model (or tokenizer) class name
+    """
+
+    def __init__(self, config_mapping, model_mapping):
+        self._config_mapping = config_mapping
+        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+        self._model_mapping = model_mapping
+        self._modules = {}
+
+    def __getitem__(self, key):
+        model_type = self._reverse_config_mapping[key.__name__]
+        if model_type not in self._model_mapping:
+            raise KeyError(key)
+        model_name = self._model_mapping[model_type]
+        return self._load_attr_from_module(model_type, model_name)
+
+    def _load_attr_from_module(self, model_type, attr):
+        module_name = model_type_to_module_name(model_type)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
+        return getattribute_from_module(self._modules[module_name], attr)
+
+    def keys(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._config_mapping.items()
+            if key in self._model_mapping.keys()
+        ]
+
+    def get(self, key, default):
+        try:
+            return self.__getitem__(key)
+        except KeyError:
+            return default
+
+    def __bool__(self):
+        return bool(self.keys())
+
+    def values(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._model_mapping.items()
+            if key in self._config_mapping.keys()
+        ]
+
+    def items(self):
+        return [
+            (
+                self._load_attr_from_module(key, self._config_mapping[key]),
+                self._load_attr_from_module(key, self._model_mapping[key]),
+            )
+            for key in self._model_mapping.keys()
+            if key in self._config_mapping.keys()
+        ]
+
+    def __iter__(self):
+        return iter(self.keys())
+
+    def __contains__(self, item):
+        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+            return False
+        model_type = self._reverse_config_mapping[item.__name__]
+        return model_type in self._model_mapping
diff --git a/fastNLP/transformers/torch/models/auto/dynamic.py b/fastNLP/transformers/torch/models/auto/dynamic.py
new file mode 100644
index 00000000..33a0a793
--- /dev/null
+++ b/fastNLP/transformers/torch/models/auto/dynamic.py
@@ -0,0 +1,208 @@
+import importlib
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from fastNLP.transformers.torch.file_utils import (
+    HF_MODULES_CACHE,
+    TRANSFORMERS_DYNAMIC_MODULE_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+)
+from fastNLP.core.log import logger
+
+def init_hf_modules():
+    """
+    Creates the cache directory for modules with an init, and adds it to the Python path.
+    """
+    # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
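+    # With default settings this appends ~/.cache/huggingface/modules to
+    # sys.path and creates an empty __init__.py there (the exact location
+    # depends on HF_HOME / HF_MODULES_CACHE, as defined in file_utils.py).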
+    if HF_MODULES_CACHE in sys.path:
+        return
+
+    sys.path.append(HF_MODULES_CACHE)
+    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+    init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+    if not init_path.exists():
+        init_path.touch()
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+    """
+    Creates a dynamic module in the cache directory for modules.
+    """
+    init_hf_modules()
+    dynamic_module_path = Path(HF_MODULES_CACHE) / name
+    # If the parent module does not exist yet, recursively create it.
+    if not dynamic_module_path.parent.exists():
+        create_dynamic_module(dynamic_module_path.parent)
+    os.makedirs(dynamic_module_path, exist_ok=True)
+    init_path = dynamic_module_path / "__init__.py"
+    if not init_path.exists():
+        init_path.touch()
+
+def check_imports(filename):
+    """
+    Check if the current Python environment contains all the libraries that are imported in a file.
+    """
+    with open(filename, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # Imports of the form `import xxx`
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    # Imports of the form `from xxx import yyy`
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+    # Only keep the top-level module
+    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+    # Unique-ify and test we got them all
+    imports = list(set(imports))
+    missing_packages = []
+    for imp in imports:
+        try:
+            importlib.import_module(imp)
+        except ImportError:
+            missing_packages.append(imp)
+
+    if len(missing_packages) > 0:
+        raise ImportError(
+            "This modeling file requires the following packages that were not found in your environment: "
+            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+        )
+
+
+def get_class_in_module(class_name, module_path):
+    """
+    Import a module from the cache directory for modules and extract a class from it.
+    """
+    module_path = module_path.replace(os.path.sep, ".")
+    module = importlib.import_module(module_path)
+    return getattr(module, class_name)
+
+def get_class_from_dynamic_module(
+    pretrained_model_name_or_path: Union[str, os.PathLike],
+    module_file: str,
+    class_name: str,
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    **kwargs,
+):
+    """
+    Extracts a class from a module file, present in the local folder or repository of a model.
+
+    .. warning::
+
+        Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
+        should therefore only be called on trusted repos.
+
+    Args:
+        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+            This can be either:
+
+            - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
+              huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
+              namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
+            - a path to a `directory` containing a configuration file saved using the
+              :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.
+
+        module_file (:obj:`str`):
+            The name of the module file containing the class to look for.
+        class_name (:obj:`str`):
+            The name of the class to import in the module.
+        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+            cache should not be used.
+        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force the (re-)download of the configuration files, overriding the cached versions if
+            they exist.
+        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
+            exists.
+        proxies (:obj:`Dict[str, str]`, `optional`):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        use_auth_token (:obj:`str` or `bool`, `optional`):
+            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
+        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
+        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If :obj:`True`, will only try to load the module file from local files.
+
+    .. note::
+
+        Passing :obj:`use_auth_token=True` is required when you want to use a private model.
+
+
+    Returns:
+        :obj:`type`: The class, dynamically imported from the module.
+
+    Examples::
+
+        # Download module `modeling.py` from huggingface.co, cache it, then extract the class `MyBertModel` from
+        # this module.
+        cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+    """
+    if is_offline_mode() and not local_files_only:
+        logger.info("Offline mode: forcing local_files_only=True")
+        local_files_only = True
+
+    # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    if os.path.isdir(pretrained_model_name_or_path):
+        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+        submodule = "local"
+    else:
+        module_file_or_url = hf_bucket_url(
+            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
+        )
+        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
+
+    try:
+        # Load from URL or cache if already cached
+        resolved_module_file = cached_path(
+            module_file_or_url,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+        )
+
+    except EnvironmentError:
+        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+        raise
+
+    # Check we have all the requirements in our environment
+    check_imports(resolved_module_file)
+
+    # Now we move the module inside our cached dynamic modules.
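+    # For example, for the repo "sgugger/my-bert-model" and module_file "modeling.py", the file
+    # is copied below HF_MODULES_CACHE/transformers_modules/sgugger/my-bert-model/ (an
+    # illustrative path; the exact module name is derived right below).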
+ full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == "local": + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + module_name = module_file + shutil.copy(resolved_module_file, submodule_path / module_file) + else: + # The module file will end up being named module_file + the etag. This way we get the benefit of versioning. + resolved_module_file_name = Path(resolved_module_file).name + module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".") + module_name = "_".join(module_name_parts) + ".py" + if not (submodule_path / module_name).exists(): + shutil.copy(resolved_module_file, submodule_path / module_name) + + # And lastly we get the class inside our newly created module + final_module = os.path.join(full_submodule, module_name.replace(".py", "")) + return get_class_in_module(class_name, final_module) \ No newline at end of file diff --git a/fastNLP/transformers/torch/models/auto/modeling_auto.py b/fastNLP/transformers/torch/models/auto/modeling_auto.py new file mode 100644 index 00000000..6406da14 --- /dev/null +++ b/fastNLP/transformers/torch/models/auto/modeling_auto.py @@ -0,0 +1,663 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Auto Model class. 
""" + +import warnings +from collections import OrderedDict + +from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES +from fastNLP.core.log import logger + + +MODEL_MAPPING_NAMES = OrderedDict( + [ + # Base model mapping + ("fnet", "FNetModel"), + ("gptj", "GPTJModel"), + ("layoutlmv2", "LayoutLMv2Model"), + ("beit", "BeitModel"), + ("rembert", "RemBertModel"), + ("visual_bert", "VisualBertModel"), + ("canine", "CanineModel"), + ("roformer", "RoFormerModel"), + ("clip", "CLIPModel"), + ("bigbird_pegasus", "BigBirdPegasusModel"), + ("deit", "DeiTModel"), + ("luke", "LukeModel"), + ("detr", "DetrModel"), + ("gpt_neo", "GPTNeoModel"), + ("big_bird", "BigBirdModel"), + ("speech_to_text", "Speech2TextModel"), + ("vit", "ViTModel"), + ("wav2vec2", "Wav2Vec2Model"), + ("hubert", "HubertModel"), + ("m2m_100", "M2M100Model"), + ("convbert", "ConvBertModel"), + ("led", "LEDModel"), + ("blenderbot-small", "BlenderbotSmallModel"), + ("retribert", "RetriBertModel"), + ("mt5", "MT5Model"), + ("t5", "T5Model"), + ("pegasus", "PegasusModel"), + ("marian", "MarianModel"), + ("mbart", "MBartModel"), + ("blenderbot", "BlenderbotModel"), + ("distilbert", "DistilBertModel"), + ("albert", "AlbertModel"), + ("camembert", "CamembertModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("bart", "BartModel"), + ("longformer", "LongformerModel"), + ("roberta", "RobertaModel"), + ("layoutlm", "LayoutLMModel"), + ("squeezebert", "SqueezeBertModel"), + ("bert", "BertModel"), + ("openai-gpt", "OpenAIGPTModel"), + ("gpt2", "GPT2Model"), + ("megatron-bert", "MegatronBertModel"), + ("mobilebert", "MobileBertModel"), + ("transfo-xl", "TransfoXLModel"), + ("xlnet", "XLNetModel"), + ("flaubert", "FlaubertModel"), + ("fsmt", "FSMTModel"), + ("xlm", "XLMModel"), + ("ctrl", "CTRLModel"), + ("electra", "ElectraModel"), + ("reformer", "ReformerModel"), + ("funnel", ("FunnelModel", "FunnelBaseModel")), + ("lxmert", "LxmertModel"), + ("bert-generation", "BertGenerationEncoder"), + ("deberta", "DebertaModel"), + ("deberta-v2", "DebertaV2Model"), + ("dpr", "DPRQuestionEncoder"), + ("xlm-prophetnet", "XLMProphetNetModel"), + ("prophetnet", "ProphetNetModel"), + ("mpnet", "MPNetModel"), + ("tapas", "TapasModel"), + ("ibert", "IBertModel"), + ("splinter", "SplinterModel"), + ] +) + +MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( + [ + # Model for pre-training mapping + ("fnet", "FNetForPreTraining"), + ("visual_bert", "VisualBertForPreTraining"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("retribert", "RetriBertModel"), + ("t5", "T5ForConditionalGeneration"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForPreTraining"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("fsmt", "FSMTForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForPreTraining"), + ("big_bird", "BigBirdForPreTraining"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("megatron-bert", "MegatronBertForPreTraining"), + ("mobilebert", "MobileBertForPreTraining"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("electra", "ElectraForPreTraining"), + ("lxmert", "LxmertForPreTraining"), + ("funnel", 
"FunnelForPreTraining"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("wav2vec2", "Wav2Vec2ForPreTraining"), + ] +) + +MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( + [ + # Model with LM heads mapping + ("fnet", "FNetForMaskedLM"), + ("gptj", "GPTJForCausalLM"), + ("rembert", "RemBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("convbert", "ConvBertForMaskedLM"), + ("led", "LEDForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("t5", "T5ForConditionalGeneration"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForMaskedLM"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("marian", "MarianMTModel"), + ("fsmt", "FSMTForConditionalGeneration"), + ("bart", "BartForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForMaskedLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("electra", "ElectraForMaskedLM"), + ("encoder-decoder", "EncoderDecoderModel"), + ("reformer", "ReformerModelWithLMHead"), + ("funnel", "FunnelForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ] +) + +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Causal LM mapping + ("gptj", "GPTJForCausalLM"), + ("rembert", "RemBertForCausalLM"), + ("roformer", "RoFormerForCausalLM"), + ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("big_bird", "BigBirdForCausalLM"), + ("camembert", "CamembertForCausalLM"), + ("xlm-roberta", "XLMRobertaForCausalLM"), + ("roberta", "RobertaForCausalLM"), + ("bert", "BertLMHeadModel"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("reformer", "ReformerModelWithLMHead"), + ("bert-generation", "BertGenerationDecoder"), + ("xlm-prophetnet", "XLMProphetNetForCausalLM"), + ("prophetnet", "ProphetNetForCausalLM"), + ("bart", "BartForCausalLM"), + ("mbart", "MBartForCausalLM"), + ("pegasus", "PegasusForCausalLM"), + ("marian", "MarianForCausalLM"), + ("blenderbot", "BlenderbotForCausalLM"), + ("blenderbot-small", "BlenderbotSmallForCausalLM"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("speech_to_text_2", "Speech2Text2ForCausalLM"), + ] +) + +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image Classification mapping + ("vit", "ViTForImageClassification"), + ("deit", ("DeiTForImageClassification", 
"DeiTForImageClassificationWithTeacher")), + ("beit", "BeitForImageClassification"), + ] +) + +MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Masked LM mapping + ("fnet", "FNetForMaskedLM"), + ("rembert", "RemBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("mbart", "MBartForConditionalGeneration"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForMaskedLM"), + ("megatron-bert", "MegatronBertForMaskedLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("electra", "ElectraForMaskedLM"), + ("reformer", "ReformerForMaskedLM"), + ("funnel", "FunnelForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ] +) + +MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + # Model for Object Detection mapping + ("detr", "DetrForObjectDetection"), + ] +) + +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("led", "LEDForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("mt5", "MT5ForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("pegasus", "PegasusForConditionalGeneration"), + ("marian", "MarianMTModel"), + ("mbart", "MBartForConditionalGeneration"), + ("blenderbot", "BlenderbotForConditionalGeneration"), + ("bart", "BartForConditionalGeneration"), + ("fsmt", "FSMTForConditionalGeneration"), + ("encoder-decoder", "EncoderDecoderModel"), + ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), + ("prophetnet", "ProphetNetForConditionalGeneration"), + ] +) + +MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ] +) + +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Sequence Classification mapping + ("fnet", "FNetForSequenceClassification"), + ("gptj", "GPTJForSequenceClassification"), + ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), + ("rembert", "RemBertForSequenceClassification"), + ("canine", "CanineForSequenceClassification"), + ("roformer", "RoFormerForSequenceClassification"), + ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), + ("big_bird", "BigBirdForSequenceClassification"), + ("convbert", "ConvBertForSequenceClassification"), + ("led", "LEDForSequenceClassification"), + ("distilbert", "DistilBertForSequenceClassification"), + ("albert", "AlbertForSequenceClassification"), + ("camembert", "CamembertForSequenceClassification"), + ("xlm-roberta", "XLMRobertaForSequenceClassification"), + ("mbart", "MBartForSequenceClassification"), + ("bart", "BartForSequenceClassification"), + ("longformer", "LongformerForSequenceClassification"), + ("roberta", "RobertaForSequenceClassification"), 
+ ("squeezebert", "SqueezeBertForSequenceClassification"), + ("layoutlm", "LayoutLMForSequenceClassification"), + ("bert", "BertForSequenceClassification"), + ("xlnet", "XLNetForSequenceClassification"), + ("megatron-bert", "MegatronBertForSequenceClassification"), + ("mobilebert", "MobileBertForSequenceClassification"), + ("flaubert", "FlaubertForSequenceClassification"), + ("xlm", "XLMForSequenceClassification"), + ("electra", "ElectraForSequenceClassification"), + ("funnel", "FunnelForSequenceClassification"), + ("deberta", "DebertaForSequenceClassification"), + ("deberta-v2", "DebertaV2ForSequenceClassification"), + ("gpt2", "GPT2ForSequenceClassification"), + ("gpt_neo", "GPTNeoForSequenceClassification"), + ("openai-gpt", "OpenAIGPTForSequenceClassification"), + ("reformer", "ReformerForSequenceClassification"), + ("ctrl", "CTRLForSequenceClassification"), + ("transfo-xl", "TransfoXLForSequenceClassification"), + ("mpnet", "MPNetForSequenceClassification"), + ("tapas", "TapasForSequenceClassification"), + ("ibert", "IBertForSequenceClassification"), + ] +) + +MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Question Answering mapping + ("fnet", "FNetForQuestionAnswering"), + ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("rembert", "RemBertForQuestionAnswering"), + ("canine", "CanineForQuestionAnswering"), + ("roformer", "RoFormerForQuestionAnswering"), + ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), + ("big_bird", "BigBirdForQuestionAnswering"), + ("convbert", "ConvBertForQuestionAnswering"), + ("led", "LEDForQuestionAnswering"), + ("distilbert", "DistilBertForQuestionAnswering"), + ("albert", "AlbertForQuestionAnswering"), + ("camembert", "CamembertForQuestionAnswering"), + ("bart", "BartForQuestionAnswering"), + ("mbart", "MBartForQuestionAnswering"), + ("longformer", "LongformerForQuestionAnswering"), + ("xlm-roberta", "XLMRobertaForQuestionAnswering"), + ("roberta", "RobertaForQuestionAnswering"), + ("squeezebert", "SqueezeBertForQuestionAnswering"), + ("bert", "BertForQuestionAnswering"), + ("xlnet", "XLNetForQuestionAnsweringSimple"), + ("flaubert", "FlaubertForQuestionAnsweringSimple"), + ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("mobilebert", "MobileBertForQuestionAnswering"), + ("xlm", "XLMForQuestionAnsweringSimple"), + ("electra", "ElectraForQuestionAnswering"), + ("reformer", "ReformerForQuestionAnswering"), + ("funnel", "FunnelForQuestionAnswering"), + ("lxmert", "LxmertForQuestionAnswering"), + ("mpnet", "MPNetForQuestionAnswering"), + ("deberta", "DebertaForQuestionAnswering"), + ("deberta-v2", "DebertaV2ForQuestionAnswering"), + ("ibert", "IBertForQuestionAnswering"), + ("splinter", "SplinterForQuestionAnswering"), + ] +) + +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Table Question Answering mapping + ("tapas", "TapasForQuestionAnswering"), + ] +) + +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Token Classification mapping + ("fnet", "FNetForTokenClassification"), + ("layoutlmv2", "LayoutLMv2ForTokenClassification"), + ("rembert", "RemBertForTokenClassification"), + ("canine", "CanineForTokenClassification"), + ("roformer", "RoFormerForTokenClassification"), + ("big_bird", "BigBirdForTokenClassification"), + ("convbert", "ConvBertForTokenClassification"), + ("layoutlm", "LayoutLMForTokenClassification"), + ("distilbert", "DistilBertForTokenClassification"), + ("camembert", "CamembertForTokenClassification"), + ("flaubert", 
"FlaubertForTokenClassification"), + ("xlm", "XLMForTokenClassification"), + ("xlm-roberta", "XLMRobertaForTokenClassification"), + ("longformer", "LongformerForTokenClassification"), + ("roberta", "RobertaForTokenClassification"), + ("squeezebert", "SqueezeBertForTokenClassification"), + ("bert", "BertForTokenClassification"), + ("megatron-bert", "MegatronBertForTokenClassification"), + ("mobilebert", "MobileBertForTokenClassification"), + ("xlnet", "XLNetForTokenClassification"), + ("albert", "AlbertForTokenClassification"), + ("electra", "ElectraForTokenClassification"), + ("funnel", "FunnelForTokenClassification"), + ("mpnet", "MPNetForTokenClassification"), + ("deberta", "DebertaForTokenClassification"), + ("deberta-v2", "DebertaV2ForTokenClassification"), + ("gpt2", "GPT2ForTokenClassification"), + ("ibert", "IBertForTokenClassification"), + ] +) + +MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( + [ + # Model for Multiple Choice mapping + ("fnet", "FNetForMultipleChoice"), + ("rembert", "RemBertForMultipleChoice"), + ("canine", "CanineForMultipleChoice"), + ("roformer", "RoFormerForMultipleChoice"), + ("big_bird", "BigBirdForMultipleChoice"), + ("convbert", "ConvBertForMultipleChoice"), + ("camembert", "CamembertForMultipleChoice"), + ("electra", "ElectraForMultipleChoice"), + ("xlm-roberta", "XLMRobertaForMultipleChoice"), + ("longformer", "LongformerForMultipleChoice"), + ("roberta", "RobertaForMultipleChoice"), + ("squeezebert", "SqueezeBertForMultipleChoice"), + ("bert", "BertForMultipleChoice"), + ("distilbert", "DistilBertForMultipleChoice"), + ("megatron-bert", "MegatronBertForMultipleChoice"), + ("mobilebert", "MobileBertForMultipleChoice"), + ("xlnet", "XLNetForMultipleChoice"), + ("albert", "AlbertForMultipleChoice"), + ("xlm", "XLMForMultipleChoice"), + ("flaubert", "FlaubertForMultipleChoice"), + ("funnel", "FunnelForMultipleChoice"), + ("mpnet", "MPNetForMultipleChoice"), + ("ibert", "IBertForMultipleChoice"), + ] +) + +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("bert", "BertForNextSentencePrediction"), + ("fnet", "FNetForNextSentencePrediction"), + ("megatron-bert", "MegatronBertForNextSentencePrediction"), + ("mobilebert", "MobileBertForNextSentencePrediction"), + ] +) + +MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("wav2vec2", "Wav2Vec2ForSequenceClassification"), + ("hubert", "HubertForSequenceClassification"), + ] +) + +MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( + [ + # Model for Connectionist temporal classification (CTC) mapping + ("wav2vec2", "Wav2Vec2ForCTC"), + ("hubert", "HubertForCTC"), + ] +) + +MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) +MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) +MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) +MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) 
+
+
+class AutoModel(_BaseAutoModelClass):
+    _model_mapping = MODEL_MAPPING
+
+
+AutoModel = auto_class_update(AutoModel)
+
+
+class AutoModelForPreTraining(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING
+
+
+AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
+
+
+# Private on purpose, the public class will add the deprecation warnings.
+class _AutoModelWithLMHead(_BaseAutoModelClass):
+    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
+
+
+_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
+
+
+class AutoModelForCausalLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
+
+
+class AutoModelForMaskedLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+
+
+AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+AutoModelForSeq2SeqLM = auto_class_update(
+    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
+)
+
+
+class AutoModelForSequenceClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+AutoModelForSequenceClassification = auto_class_update(
+    AutoModelForSequenceClassification, head_doc="sequence classification"
+)
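
auto_class_update only attaches docstrings and an example checkpoint to each generated class, so the classes are used exactly like their upstream transformers counterparts. A usage sketch follows; the import path assumes the class is re-exported by the auto package's __init__.py, and bert-base-uncased is only an example checkpoint:

    from fastNLP.transformers.torch.models.auto import AutoModelForSequenceClassification

    # The checkpoint's config type ("bert") is looked up in
    # MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING to select the concrete class.
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    print(type(model).__name__)  # BertForSequenceClassification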
+
+
+class AutoModelForQuestionAnswering(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
+
+
+class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
+
+
+AutoModelForTableQuestionAnswering = auto_class_update(
+    AutoModelForTableQuestionAnswering,
+    head_doc="table question answering",
+    checkpoint_for_example="google/tapas-base-finetuned-wtq",
+)
+
+
+class AutoModelForTokenClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
+
+
+class AutoModelForMultipleChoice(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+AutoModelForNextSentencePrediction = auto_class_update(
+    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+)
+
+
+class AutoModelForImageClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
+
+
+class AutoModelForObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+
+AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
+
+
+class AutoModelForAudioClassification(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+
+
+AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
+
+
+class AutoModelForCTC(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CTC_MAPPING
+
+
+AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
+
+
+class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+AutoModelForSpeechSeq2Seq = auto_class_update(
+    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+)
+
+
+class AutoModelWithLMHead(_AutoModelWithLMHead):
+    @classmethod
+    def from_config(cls, config):
+        warnings.warn(
+            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
+            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
+            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
+            FutureWarning,
+        )
+        return super().from_config(config)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        warnings.warn(
+            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
+            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
+            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
+            FutureWarning,
+        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
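
AutoModelWithLMHead is kept only for backward compatibility: both overrides emit a FutureWarning and then defer to _AutoModelWithLMHead. A migration sketch, using gpt2 as an illustrative causal-LM checkpoint:

    model = AutoModelWithLMHead.from_pretrained("gpt2")   # still works, but warns
    model = AutoModelForCausalLM.from_pretrained("gpt2")  # preferred replacement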