@@ -2,7 +2,7 @@ | |||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions"> | |||
<span class="rst-current-version" data-toggle="rst-current-version"> | |||
<span class="fa fa-book"> Other Versions</span> | |||
v: {{ current_version.name }} | |||
{{ current_version.name }} | |||
<span class="fa fa-caret-down"></span> | |||
</span> | |||
<div class="rst-other-versions"> | |||
@@ -974,9 +974,10 @@ class DataSet: | |||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||
pad_fn: Callable = None) -> Collator: | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class: `~fastNLP.core.collators.Collator` | |||
时该函数才有效。调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整。 | |||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
:param field_name: 需要调整的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
@@ -999,12 +1000,14 @@ class DataSet: | |||
def set_ignore(self, *field_names) -> Collator: | |||
""" | |||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class: `~fastNLP.core.collators.Collator` | |||
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略。 | |||
Example:: | |||
collator.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
:param field_names: 需要忽略的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
@@ -10,6 +10,8 @@ import sys | |||
import tarfile | |||
import tempfile | |||
import operator | |||
import types | |||
import functools | |||
from collections import OrderedDict, UserDict | |||
from contextlib import contextmanager | |||
from dataclasses import fields | |||
@@ -37,6 +39,8 @@ if _NEED_IMPORT_TORCH: | |||
import torch | |||
_torch_version = importlib_metadata.version("torch") | |||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} | |||
hf_cache_home = os.path.expanduser( | |||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) | |||
) | |||
@@ -45,10 +49,9 @@ default_cache_path = os.path.join(hf_cache_home, "transformers") | |||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) | |||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) | |||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) | |||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) | |||
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules" | |||
SESSION_ID = uuid4().hex | |||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} | |||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES | |||
WEIGHTS_NAME = "pytorch_model.bin" | |||
@@ -1043,3 +1046,11 @@ class TensorType(ExplicitEnum): | |||
PYTORCH = "pt" | |||
NUMPY = "np" | |||
def copy_func(f): | |||
"""Returns a copy of a function f.""" | |||
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) | |||
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) | |||
g = functools.update_wrapper(g, f) | |||
g.__kwdefaults__ = f.__kwdefaults__ | |||
return g |
@@ -0,0 +1,83 @@ | |||
__all__ = [ | |||
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", | |||
"CONFIG_MAPPING", | |||
"MODEL_NAMES_MAPPING", | |||
"AutoConfig", | |||
"TOKENIZER_MAPPING_NAMES", | |||
"get_values", | |||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", | |||
"MODEL_FOR_CAUSAL_LM_MAPPING", | |||
"MODEL_FOR_CTC_MAPPING", | |||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", | |||
"MODEL_FOR_MASKED_LM_MAPPING", | |||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING", | |||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", | |||
"MODEL_FOR_OBJECT_DETECTION_MAPPING", | |||
"MODEL_FOR_PRETRAINING_MAPPING", | |||
"MODEL_FOR_QUESTION_ANSWERING_MAPPING", | |||
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", | |||
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", | |||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", | |||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", | |||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", | |||
"MODEL_MAPPING", | |||
"MODEL_WITH_LM_HEAD_MAPPING", | |||
"AutoModel", | |||
"AutoModelForAudioClassification", | |||
"AutoModelForCausalLM", | |||
"AutoModelForCTC", | |||
"AutoModelForImageClassification", | |||
"AutoModelForMaskedLM", | |||
"AutoModelForMultipleChoice", | |||
"AutoModelForNextSentencePrediction", | |||
"AutoModelForObjectDetection", | |||
"AutoModelForPreTraining", | |||
"AutoModelForQuestionAnswering", | |||
"AutoModelForSeq2SeqLM", | |||
"AutoModelForSequenceClassification", | |||
"AutoModelForSpeechSeq2Seq", | |||
"AutoModelForTableQuestionAnswering", | |||
"AutoModelForTokenClassification", | |||
"AutoModelWithLMHead", | |||
] | |||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \ | |||
AutoConfig | |||
from .tokenization_auto import TOKENIZER_MAPPING_NAMES | |||
from .auto_factory import get_values | |||
from .modeling_auto import ( | |||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, | |||
MODEL_FOR_CAUSAL_LM_MAPPING, | |||
MODEL_FOR_CTC_MAPPING, | |||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, | |||
MODEL_FOR_MASKED_LM_MAPPING, | |||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING, | |||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, | |||
MODEL_FOR_OBJECT_DETECTION_MAPPING, | |||
MODEL_FOR_PRETRAINING_MAPPING, | |||
MODEL_FOR_QUESTION_ANSWERING_MAPPING, | |||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | |||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, | |||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, | |||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, | |||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, | |||
MODEL_MAPPING, | |||
MODEL_WITH_LM_HEAD_MAPPING, | |||
AutoModel, | |||
AutoModelForAudioClassification, | |||
AutoModelForCausalLM, | |||
AutoModelForCTC, | |||
AutoModelForImageClassification, | |||
AutoModelForMaskedLM, | |||
AutoModelForMultipleChoice, | |||
AutoModelForNextSentencePrediction, | |||
AutoModelForObjectDetection, | |||
AutoModelForPreTraining, | |||
AutoModelForQuestionAnswering, | |||
AutoModelForSeq2SeqLM, | |||
AutoModelForSequenceClassification, | |||
AutoModelForSpeechSeq2Seq, | |||
AutoModelForTableQuestionAnswering, | |||
AutoModelForTokenClassification, | |||
AutoModelWithLMHead, | |||
) |
@@ -0,0 +1,565 @@ | |||
# coding=utf-8 | |||
# Copyright 2021 The HuggingFace Inc. team. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Factory function to build auto-model classes.""" | |||
import importlib | |||
from collections import OrderedDict | |||
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings | |||
from .dynamic import get_class_from_dynamic_module | |||
from fastNLP.transformers.torch.configuration_utils import PretrainedConfig | |||
from fastNLP.transformers.torch.file_utils import copy_func | |||
from fastNLP.core.log import logger | |||
CLASS_DOCSTRING = """ | |||
This is a generic model class that will be instantiated as one of the model classes of the library when created | |||
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the | |||
:meth:`~transformers.BaseAutoModelClass.from_config` class method. | |||
This class cannot be instantiated directly using ``__init__()`` (throws an error). | |||
""" | |||
FROM_CONFIG_DOCSTRING = """ | |||
Instantiates one of the model classes of the library from a configuration. | |||
Note: | |||
Loading a model from its configuration file does **not** load the model weights. It only affects the | |||
model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model | |||
weights. | |||
Args: | |||
config (:class:`~transformers.PretrainedConfig`): | |||
The model class to instantiate is selected based on the configuration class: | |||
List options | |||
Examples:: | |||
>>> from transformers import AutoConfig, BaseAutoModelClass | |||
>>> # Download configuration from huggingface.co and cache. | |||
>>> config = AutoConfig.from_pretrained('checkpoint_placeholder') | |||
>>> model = BaseAutoModelClass.from_config(config) | |||
""" | |||
FROM_PRETRAINED_TORCH_DOCSTRING = """ | |||
Instantiate one of the model classes of the library from a pretrained model. | |||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||
List options | |||
The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are | |||
deactivated). To train the model, you should first set it back in training mode with ``model.train()`` | |||
Args: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
Can be either: | |||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||
a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- A path to a `directory` containing model weights saved using | |||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In | |||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided | |||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in | |||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. | |||
model_args (additional positional arguments, `optional`): | |||
Will be passed along to the underlying model ``__init__()`` method. | |||
config (:class:`~transformers.PretrainedConfig`, `optional`): | |||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||
be automatically loaded when: | |||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||
model). | |||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||
by supplying the save directory. | |||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||
configuration JSON file named `config.json` is found in the directory. | |||
state_dict (`Dict[str, torch.Tensor]`, `optional`): | |||
A state dictionary to use instead of a state dictionary loaded from saved weights file. | |||
This option can be used if you want to create a model from a pretrained configuration but load your own | |||
weights. In this case though, you should check if using | |||
:func:`~transformers.PreTrainedModel.save_pretrained` and | |||
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||
standard cache should not be used. | |||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Load the model weights from a TensorFlow checkpoint save file (see docstring of | |||
``pretrained_model_name_or_path`` argument). | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||
cached versions if they exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||
file exists. | |||
proxies (:obj:`Dict[str, str], `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to only look at local files (e.g., not try downloading the model). | |||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||
will execute code present on the Hub on your local machine. | |||
kwargs (additional keyword arguments, `optional`): | |||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||
automatically loaded: | |||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||
already been done) | |||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||
attribute will be passed to the underlying model's ``__init__`` function. | |||
Examples:: | |||
>>> from transformers import AutoConfig, BaseAutoModelClass | |||
>>> # Download model and configuration from huggingface.co and cache. | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||
>>> # Update configuration during loading | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||
>>> model.config.output_attentions | |||
True | |||
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) | |||
>>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json') | |||
>>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config) | |||
""" | |||
FROM_PRETRAINED_TF_DOCSTRING = """ | |||
Instantiate one of the model classes of the library from a pretrained model. | |||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||
List options | |||
Args: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
Can be either: | |||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||
a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- A path to a `directory` containing model weights saved using | |||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In | |||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided | |||
as ``config`` argument. This loading path is slower than converting the PyTorch model in a | |||
TensorFlow model using the provided conversion scripts and loading the TensorFlow model | |||
afterwards. | |||
model_args (additional positional arguments, `optional`): | |||
Will be passed along to the underlying model ``__init__()`` method. | |||
config (:class:`~transformers.PretrainedConfig`, `optional`): | |||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||
be automatically loaded when: | |||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||
model). | |||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||
by supplying the save directory. | |||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||
configuration JSON file named `config.json` is found in the directory. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||
standard cache should not be used. | |||
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Load the model weights from a PyTorch checkpoint save file (see docstring of | |||
``pretrained_model_name_or_path`` argument). | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||
cached versions if they exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||
file exists. | |||
proxies (:obj:`Dict[str, str], `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to only look at local files (e.g., not try downloading the model). | |||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||
will execute code present on the Hub on your local machine. | |||
kwargs (additional keyword arguments, `optional`): | |||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||
automatically loaded: | |||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||
already been done) | |||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||
attribute will be passed to the underlying model's ``__init__`` function. | |||
Examples:: | |||
>>> from transformers import AutoConfig, BaseAutoModelClass | |||
>>> # Download model and configuration from huggingface.co and cache. | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||
>>> # Update configuration during loading | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||
>>> model.config.output_attentions | |||
True | |||
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) | |||
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json') | |||
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config) | |||
""" | |||
FROM_PRETRAINED_FLAX_DOCSTRING = """ | |||
Instantiate one of the model classes of the library from a pretrained model. | |||
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either | |||
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, | |||
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: | |||
List options | |||
Args: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
Can be either: | |||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. | |||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under | |||
a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- A path to a `directory` containing model weights saved using | |||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | |||
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In | |||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided | |||
as ``config`` argument. This loading path is slower than converting the PyTorch model in a | |||
TensorFlow model using the provided conversion scripts and loading the TensorFlow model | |||
afterwards. | |||
model_args (additional positional arguments, `optional`): | |||
Will be passed along to the underlying model ``__init__()`` method. | |||
config (:class:`~transformers.PretrainedConfig`, `optional`): | |||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |||
be automatically loaded when: | |||
- The model is a model provided by the library (loaded with the `model id` string of a pretrained | |||
model). | |||
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | |||
by supplying the save directory. | |||
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | |||
configuration JSON file named `config.json` is found in the directory. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |||
standard cache should not be used. | |||
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Load the model weights from a PyTorch checkpoint save file (see docstring of | |||
``pretrained_model_name_or_path`` argument). | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |||
cached versions if they exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |||
file exists. | |||
proxies (:obj:`Dict[str, str], `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |||
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. | |||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to only look at local files (e.g., not try downloading the model). | |||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option | |||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it | |||
will execute code present on the Hub on your local machine. | |||
kwargs (additional keyword arguments, `optional`): | |||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | |||
automatically loaded: | |||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | |||
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | |||
already been done) | |||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | |||
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | |||
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | |||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | |||
attribute will be passed to the underlying model's ``__init__`` function. | |||
Examples:: | |||
>>> from transformers import AutoConfig, BaseAutoModelClass | |||
>>> # Download model and configuration from huggingface.co and cache. | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder') | |||
>>> # Update configuration during loading | |||
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True) | |||
>>> model.config.output_attentions | |||
True | |||
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) | |||
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json') | |||
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config) | |||
""" | |||
def _get_model_class(config, model_mapping): | |||
supported_models = model_mapping[type(config)] | |||
if not isinstance(supported_models, (list, tuple)): | |||
return supported_models | |||
name_to_model = {model.__name__: model for model in supported_models} | |||
architectures = getattr(config, "architectures", []) | |||
for arch in architectures: | |||
if arch in name_to_model: | |||
return name_to_model[arch] | |||
elif f"TF{arch}" in name_to_model: | |||
return name_to_model[f"TF{arch}"] | |||
elif f"Flax{arch}" in name_to_model: | |||
return name_to_model[f"Flax{arch}"] | |||
# If not architecture is set in the config or match the supported models, the first element of the tuple is the | |||
# defaults. | |||
return supported_models[0] | |||
class _BaseAutoModelClass: | |||
# Base class for auto models. | |||
_model_mapping = None | |||
def __init__(self, *args, **kwargs): | |||
raise EnvironmentError( | |||
f"{self.__class__.__name__} is designed to be instantiated " | |||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " | |||
f"`{self.__class__.__name__}.from_config(config)` methods." | |||
) | |||
@classmethod | |||
def from_config(cls, config, **kwargs): | |||
if type(config) in cls._model_mapping.keys(): | |||
model_class = _get_model_class(config, cls._model_mapping) | |||
return model_class._from_config(config, **kwargs) | |||
raise ValueError( | |||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" | |||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." | |||
) | |||
@classmethod | |||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |||
config = kwargs.pop("config", None) | |||
trust_remote_code = kwargs.pop("trust_remote_code", False) | |||
kwargs["_from_auto"] = True | |||
if not isinstance(config, PretrainedConfig): | |||
config, kwargs = AutoConfig.from_pretrained( | |||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs | |||
) | |||
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: | |||
if not trust_remote_code: | |||
raise ValueError( | |||
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo " | |||
"on your local machine. Make sure you have read the code there to avoid malicious use, then set " | |||
"the option `trust_remote_code=True` to remove this error." | |||
) | |||
if kwargs.get("revision", None) is None: | |||
logger.warn( | |||
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " | |||
"no malicious code has been contributed in a newer revision." | |||
) | |||
class_ref = config.auto_map[cls.__name__] | |||
module_file, class_name = class_ref.split(".") | |||
model_class = get_class_from_dynamic_module( | |||
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs | |||
) | |||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) | |||
elif type(config) in cls._model_mapping.keys(): | |||
model_class = _get_model_class(config, cls._model_mapping) | |||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) | |||
raise ValueError( | |||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" | |||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." | |||
) | |||
def insert_head_doc(docstring, head_doc=""): | |||
if len(head_doc) > 0: | |||
return docstring.replace( | |||
"one of the model classes of the library ", | |||
f"one of the model classes of the library (with a {head_doc} head) ", | |||
) | |||
return docstring.replace( | |||
"one of the model classes of the library ", "one of the base model classes of the library " | |||
) | |||
def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""): | |||
# Create a new class with the right name from the base class | |||
model_mapping = cls._model_mapping | |||
name = cls.__name__ | |||
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) | |||
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name) | |||
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't | |||
# have a specific docstrings for them. | |||
from_config = copy_func(_BaseAutoModelClass.from_config) | |||
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc) | |||
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) | |||
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) | |||
from_config.__doc__ = from_config_docstring | |||
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) | |||
cls.from_config = classmethod(from_config) | |||
if name.startswith("TF"): | |||
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING | |||
elif name.startswith("Flax"): | |||
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING | |||
else: | |||
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING | |||
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) | |||
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) | |||
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) | |||
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example) | |||
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] | |||
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) | |||
from_pretrained.__doc__ = from_pretrained_docstring | |||
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) | |||
cls.from_pretrained = classmethod(from_pretrained) | |||
return cls | |||
def get_values(model_mapping): | |||
result = [] | |||
for model in model_mapping.values(): | |||
if isinstance(model, (list, tuple)): | |||
result += list(model) | |||
else: | |||
result.append(model) | |||
return result | |||
def getattribute_from_module(module, attr): | |||
if attr is None: | |||
return None | |||
if isinstance(attr, tuple): | |||
return tuple(getattribute_from_module(module, a) for a in attr) | |||
if hasattr(module, attr): | |||
return getattr(module, attr) | |||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the | |||
# object at the top level. | |||
transformers_module = importlib.import_module("transformers") | |||
return getattribute_from_module(transformers_module, attr) | |||
class _LazyAutoMapping(OrderedDict): | |||
""" | |||
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. | |||
Args: | |||
- config_mapping: The map model type to config class | |||
- model_mapping: The map model type to model (or tokenizer) class | |||
""" | |||
def __init__(self, config_mapping, model_mapping): | |||
self._config_mapping = config_mapping | |||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} | |||
self._model_mapping = model_mapping | |||
self._modules = {} | |||
def __getitem__(self, key): | |||
model_type = self._reverse_config_mapping[key.__name__] | |||
if model_type not in self._model_mapping: | |||
raise KeyError(key) | |||
model_name = self._model_mapping[model_type] | |||
return self._load_attr_from_module(model_type, model_name) | |||
def _load_attr_from_module(self, model_type, attr): | |||
module_name = model_type_to_module_name(model_type) | |||
if module_name not in self._modules: | |||
try: | |||
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||
except ImportError: | |||
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.") | |||
return getattribute_from_module(self._modules[module_name], attr) | |||
def keys(self): | |||
return [ | |||
self._load_attr_from_module(key, name) | |||
for key, name in self._config_mapping.items() | |||
if key in self._model_mapping.keys() | |||
] | |||
def get(self, key, default): | |||
try: | |||
return self.__getitem__(key) | |||
except KeyError: | |||
return default | |||
def __bool__(self): | |||
return bool(self.keys()) | |||
def values(self): | |||
return [ | |||
self._load_attr_from_module(key, name) | |||
for key, name in self._model_mapping.items() | |||
if key in self._config_mapping.keys() | |||
] | |||
def items(self): | |||
return [ | |||
( | |||
self._load_attr_from_module(key, self._config_mapping[key]), | |||
self._load_attr_from_module(key, self._model_mapping[key]), | |||
) | |||
for key in self._model_mapping.keys() | |||
if key in self._config_mapping.keys() | |||
] | |||
def __iter__(self): | |||
return iter(self._mapping.keys()) | |||
def __contains__(self, item): | |||
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: | |||
return False | |||
model_type = self._reverse_config_mapping[item.__name__] | |||
return model_type in self._model_mapping |
@@ -26,221 +26,33 @@ from fastNLP.core.log import logger | |||
CONFIG_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Add configs here | |||
("fnet", "FNetConfig"), | |||
("gptj", "GPTJConfig"), | |||
("layoutlmv2", "LayoutLMv2Config"), | |||
("beit", "BeitConfig"), | |||
("rembert", "RemBertConfig"), | |||
("visual_bert", "VisualBertConfig"), | |||
("canine", "CanineConfig"), | |||
("roformer", "RoFormerConfig"), | |||
("clip", "CLIPConfig"), | |||
("bigbird_pegasus", "BigBirdPegasusConfig"), | |||
("deit", "DeiTConfig"), | |||
("luke", "LukeConfig"), | |||
("detr", "DetrConfig"), | |||
("gpt_neo", "GPTNeoConfig"), | |||
("big_bird", "BigBirdConfig"), | |||
("speech_to_text_2", "Speech2Text2Config"), | |||
("speech_to_text", "Speech2TextConfig"), | |||
("vit", "ViTConfig"), | |||
("wav2vec2", "Wav2Vec2Config"), | |||
("m2m_100", "M2M100Config"), | |||
("convbert", "ConvBertConfig"), | |||
("led", "LEDConfig"), | |||
("blenderbot-small", "BlenderbotSmallConfig"), | |||
("retribert", "RetriBertConfig"), | |||
("ibert", "IBertConfig"), | |||
("mt5", "MT5Config"), | |||
("t5", "T5Config"), | |||
("mobilebert", "MobileBertConfig"), | |||
("distilbert", "DistilBertConfig"), | |||
("albert", "AlbertConfig"), | |||
("bert-generation", "BertGenerationConfig"), | |||
("camembert", "CamembertConfig"), | |||
("xlm-roberta", "XLMRobertaConfig"), | |||
("pegasus", "PegasusConfig"), | |||
("marian", "MarianConfig"), | |||
("mbart", "MBartConfig"), | |||
("megatron-bert", "MegatronBertConfig"), | |||
("mpnet", "MPNetConfig"), | |||
("bart", "BartConfig"), | |||
("blenderbot", "BlenderbotConfig"), | |||
("reformer", "ReformerConfig"), | |||
("longformer", "LongformerConfig"), | |||
("roberta", "RobertaConfig"), | |||
("deberta-v2", "DebertaV2Config"), | |||
("deberta", "DebertaConfig"), | |||
("flaubert", "FlaubertConfig"), | |||
("fsmt", "FSMTConfig"), | |||
("squeezebert", "SqueezeBertConfig"), | |||
("hubert", "HubertConfig"), | |||
("bert", "BertConfig"), | |||
("openai-gpt", "OpenAIGPTConfig"), | |||
("gpt2", "GPT2Config"), | |||
("transfo-xl", "TransfoXLConfig"), | |||
("xlnet", "XLNetConfig"), | |||
("xlm-prophetnet", "XLMProphetNetConfig"), | |||
("prophetnet", "ProphetNetConfig"), | |||
("xlm", "XLMConfig"), | |||
("ctrl", "CTRLConfig"), | |||
("electra", "ElectraConfig"), | |||
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), | |||
("encoder-decoder", "EncoderDecoderConfig"), | |||
("funnel", "FunnelConfig"), | |||
("lxmert", "LxmertConfig"), | |||
("dpr", "DPRConfig"), | |||
("layoutlm", "LayoutLMConfig"), | |||
("rag", "RagConfig"), | |||
("tapas", "TapasConfig"), | |||
("splinter", "SplinterConfig"), | |||
("cpt", "CPTConfig"), | |||
] | |||
) | |||
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Add archive maps here | |||
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
("cpt", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), | |||
] | |||
) | |||
MODEL_NAMES_MAPPING = OrderedDict( | |||
[ | |||
# Add full (and cased) model names here | |||
("fnet", "FNet"), | |||
("gptj", "GPT-J"), | |||
("beit", "BeiT"), | |||
("rembert", "RemBERT"), | |||
("layoutlmv2", "LayoutLMv2"), | |||
("visual_bert", "VisualBert"), | |||
("canine", "Canine"), | |||
("roformer", "RoFormer"), | |||
("clip", "CLIP"), | |||
("bigbird_pegasus", "BigBirdPegasus"), | |||
("deit", "DeiT"), | |||
("luke", "LUKE"), | |||
("detr", "DETR"), | |||
("gpt_neo", "GPT Neo"), | |||
("big_bird", "BigBird"), | |||
("speech_to_text_2", "Speech2Text2"), | |||
("speech_to_text", "Speech2Text"), | |||
("vit", "ViT"), | |||
("wav2vec2", "Wav2Vec2"), | |||
("m2m_100", "M2M100"), | |||
("convbert", "ConvBERT"), | |||
("led", "LED"), | |||
("blenderbot-small", "BlenderbotSmall"), | |||
("retribert", "RetriBERT"), | |||
("ibert", "I-BERT"), | |||
("t5", "T5"), | |||
("mobilebert", "MobileBERT"), | |||
("distilbert", "DistilBERT"), | |||
("albert", "ALBERT"), | |||
("bert-generation", "Bert Generation"), | |||
("camembert", "CamemBERT"), | |||
("xlm-roberta", "XLM-RoBERTa"), | |||
("pegasus", "Pegasus"), | |||
("blenderbot", "Blenderbot"), | |||
("marian", "Marian"), | |||
("mbart", "mBART"), | |||
("megatron-bert", "MegatronBert"), | |||
("bart", "BART"), | |||
("reformer", "Reformer"), | |||
("longformer", "Longformer"), | |||
("roberta", "RoBERTa"), | |||
("flaubert", "FlauBERT"), | |||
("fsmt", "FairSeq Machine-Translation"), | |||
("squeezebert", "SqueezeBERT"), | |||
("bert", "BERT"), | |||
("openai-gpt", "OpenAI GPT"), | |||
("gpt2", "OpenAI GPT-2"), | |||
("transfo-xl", "Transformer-XL"), | |||
("xlnet", "XLNet"), | |||
("xlm", "XLM"), | |||
("ctrl", "CTRL"), | |||
("electra", "ELECTRA"), | |||
("encoder-decoder", "Encoder decoder"), | |||
("speech-encoder-decoder", "Speech Encoder decoder"), | |||
("funnel", "Funnel Transformer"), | |||
("lxmert", "LXMERT"), | |||
("deberta-v2", "DeBERTa-v2"), | |||
("deberta", "DeBERTa"), | |||
("layoutlm", "LayoutLM"), | |||
("dpr", "DPR"), | |||
("rag", "RAG"), | |||
("xlm-prophetnet", "XLMProphetNet"), | |||
("prophetnet", "ProphetNet"), | |||
("mt5", "mT5"), | |||
("mpnet", "MPNet"), | |||
("tapas", "TAPAS"), | |||
("hubert", "Hubert"), | |||
("barthez", "BARThez"), | |||
("phobert", "PhoBERT"), | |||
("cpm", "CPM"), | |||
("bertweet", "Bertweet"), | |||
("bert-japanese", "BertJapanese"), | |||
("byt5", "ByT5"), | |||
("mbart50", "mBART-50"), | |||
("splinter", "Splinter"), | |||
("cpt", "CPT") | |||
] | |||
) | |||
@@ -0,0 +1,208 @@ | |||
import importlib | |||
import os | |||
import re | |||
import shutil | |||
import sys | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from fastNLP.transformers.torch.file_utils import ( | |||
HF_MODULES_CACHE, | |||
TRANSFORMERS_DYNAMIC_MODULE_NAME, | |||
cached_path, | |||
hf_bucket_url, | |||
is_offline_mode, | |||
) | |||
from fastNLP.core.log import logger | |||
def init_hf_modules(): | |||
""" | |||
Creates the cache directory for modules with an init, and adds it to the Python path. | |||
""" | |||
# This function has already been executed if HF_MODULES_CACHE already is in the Python path. | |||
if HF_MODULES_CACHE in sys.path: | |||
return | |||
sys.path.append(HF_MODULES_CACHE) | |||
os.makedirs(HF_MODULES_CACHE, exist_ok=True) | |||
init_path = Path(HF_MODULES_CACHE) / "__init__.py" | |||
if not init_path.exists(): | |||
init_path.touch() | |||
def create_dynamic_module(name: Union[str, os.PathLike]): | |||
""" | |||
Creates a dynamic module in the cache directory for modules. | |||
""" | |||
init_hf_modules() | |||
dynamic_module_path = Path(HF_MODULES_CACHE) / name | |||
# If the parent module does not exist yet, recursively create it. | |||
if not dynamic_module_path.parent.exists(): | |||
create_dynamic_module(dynamic_module_path.parent) | |||
os.makedirs(dynamic_module_path, exist_ok=True) | |||
init_path = dynamic_module_path / "__init__.py" | |||
if not init_path.exists(): | |||
init_path.touch() | |||
def check_imports(filename): | |||
""" | |||
Check if the current Python environment contains all the libraries that are imported in a file. | |||
""" | |||
with open(filename, "r", encoding="utf-8") as f: | |||
content = f.read() | |||
# Imports of the form `import xxx` | |||
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) | |||
# Imports of the form `from xxx import yyy` | |||
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) | |||
# Only keep the top-level module | |||
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] | |||
# Unique-ify and test we got them all | |||
imports = list(set(imports)) | |||
missing_packages = [] | |||
for imp in imports: | |||
try: | |||
importlib.import_module(imp) | |||
except ImportError: | |||
missing_packages.append(imp) | |||
if len(missing_packages) > 0: | |||
raise ImportError( | |||
"This modeling file requires the following packages that were not found in your environment: " | |||
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" | |||
) | |||
def get_class_in_module(class_name, module_path): | |||
""" | |||
Import a module on the cache directory for modules and extract a class from it. | |||
""" | |||
module_path = module_path.replace(os.path.sep, ".") | |||
module = importlib.import_module(module_path) | |||
return getattr(module, class_name) | |||
def get_class_from_dynamic_module( | |||
pretrained_model_name_or_path: Union[str, os.PathLike], | |||
module_file: str, | |||
class_name: str, | |||
cache_dir: Optional[Union[str, os.PathLike]] = None, | |||
force_download: bool = False, | |||
resume_download: bool = False, | |||
proxies: Optional[Dict[str, str]] = None, | |||
use_auth_token: Optional[Union[bool, str]] = None, | |||
revision: Optional[str] = None, | |||
local_files_only: bool = False, | |||
**kwargs, | |||
): | |||
""" | |||
Extracts a class from a module file, present in the local folder or repository of a model. | |||
.. warning:: | |||
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It | |||
should therefore only be called on trusted repos. | |||
Args: | |||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): | |||
This can be either: | |||
- a string, the `model id` of a pretrained model configuration hosted inside a model repo on | |||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or | |||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. | |||
- a path to a `directory` containing a configuration file saved using the | |||
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``. | |||
module_file (:obj:`str`): | |||
The name of the module file containing the class to look for. | |||
class_name (:obj:`str`): | |||
The name of the class to import in the module. | |||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): | |||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard | |||
cache should not be used. | |||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to force to (re-)download the configuration files and override the cached versions if they | |||
exist. | |||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. | |||
proxies (:obj:`Dict[str, str]`, `optional`): | |||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', | |||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. | |||
use_auth_token (:obj:`str` or `bool`, `optional`): | |||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||
identifier allowed by git. | |||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
If :obj:`True`, will only try to load the tokenizer configuration from local files. | |||
.. note:: | |||
Passing :obj:`use_auth_token=True` is required when you want to use a private model. | |||
Returns: | |||
:obj:`type`: The class, dynamically imported from the module. | |||
Examples:: | |||
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this | |||
# module. | |||
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") | |||
""" | |||
if is_offline_mode() and not local_files_only: | |||
logger.info("Offline mode: forcing local_files_only=True") | |||
local_files_only = True | |||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. | |||
pretrained_model_name_or_path = str(pretrained_model_name_or_path) | |||
if os.path.isdir(pretrained_model_name_or_path): | |||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) | |||
submodule = "local" | |||
else: | |||
module_file_or_url = hf_bucket_url( | |||
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None | |||
) | |||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep) | |||
try: | |||
# Load from URL or cache if already cached | |||
resolved_module_file = cached_path( | |||
module_file_or_url, | |||
cache_dir=cache_dir, | |||
force_download=force_download, | |||
proxies=proxies, | |||
resume_download=resume_download, | |||
local_files_only=local_files_only, | |||
use_auth_token=use_auth_token, | |||
) | |||
except EnvironmentError: | |||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") | |||
raise | |||
# Check we have all the requirements in our environment | |||
check_imports(resolved_module_file) | |||
# Now we move the module inside our cached dynamic modules. | |||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule | |||
create_dynamic_module(full_submodule) | |||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule | |||
if submodule == "local": | |||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of | |||
# that hash, to only copy when there is a modification but it seems overkill for now). | |||
# The only reason we do the copy is to avoid putting too many folders in sys.path. | |||
module_name = module_file | |||
shutil.copy(resolved_module_file, submodule_path / module_file) | |||
else: | |||
# The module file will end up being named module_file + the etag. This way we get the benefit of versioning. | |||
resolved_module_file_name = Path(resolved_module_file).name | |||
module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".") | |||
module_name = "_".join(module_name_parts) + ".py" | |||
if not (submodule_path / module_name).exists(): | |||
shutil.copy(resolved_module_file, submodule_path / module_name) | |||
# And lastly we get the class inside our newly created module | |||
final_module = os.path.join(full_submodule, module_name.replace(".py", "")) | |||
return get_class_in_module(class_name, final_module) |
@@ -0,0 +1,326 @@ | |||
# coding=utf-8 | |||
# Copyright 2018 The HuggingFace Inc. team. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" Auto Model class. """ | |||
import warnings | |||
from collections import OrderedDict | |||
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update | |||
from .configuration_auto import CONFIG_MAPPING_NAMES | |||
from fastNLP.core.log import logger | |||
MODEL_MAPPING_NAMES = OrderedDict( | |||
[ | |||
("bart", "BartModel"), | |||
("roberta", "RobertaModel"), | |||
("bert", "BertModel"), | |||
("gpt2", "GPT2Model"), | |||
("cpt", "CPTModel"), | |||
] | |||
) | |||
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( | |||
[ | |||
("bart", "BartForConditionalGeneration"), | |||
("roberta", "RobertaForMaskedLM"), | |||
("bert", "BertForPreTraining"), | |||
("gpt2", "GPT2LMHeadModel"), | |||
("cpt", "CPTForConditionalGeneration"), | |||
] | |||
) | |||
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model with LM heads mapping | |||
("bart", "BartForConditionalGeneration"), | |||
("roberta", "RobertaForMaskedLM"), | |||
("bert", "BertForMaskedLM"), | |||
("gpt2", "GPT2LMHeadModel"), | |||
("cpt", "CPTForConditionalGeneration"), | |||
] | |||
) | |||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Causal LM mapping | |||
("roberta", "RobertaForCausalLM"), | |||
("bert", "BertLMHeadModel"), | |||
("gpt2", "GPT2LMHeadModel"), | |||
("bart", "BartForCausalLM"), | |||
] | |||
) | |||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Masked LM mapping | |||
("bart", "BartForConditionalGeneration"), | |||
("roberta", "RobertaForMaskedLM"), | |||
("bert", "BertForMaskedLM"), | |||
("cpt", "CPTForConditionalGeneration"), | |||
] | |||
) | |||
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Seq2Seq Causal LM mapping | |||
("bart", "BartForConditionalGeneration"), | |||
("cpt", "CPTForConditionalGeneration"), | |||
] | |||
) | |||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Sequence Classification mapping | |||
("bart", "BartForSequenceClassification"), | |||
("roberta", "RobertaForSequenceClassification"), | |||
("bert", "BertForSequenceClassification"), | |||
("gpt2", "GPT2ForSequenceClassification"), | |||
("cpt", "CPTForSequenceClassification"), | |||
] | |||
) | |||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Question Answering mapping | |||
("bart", "BartForQuestionAnswering"), | |||
("roberta", "RobertaForQuestionAnswering"), | |||
("bert", "BertForQuestionAnswering"), | |||
("cpt", "CPTForQuestionAnswering"), | |||
] | |||
) | |||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Token Classification mapping | |||
("roberta", "RobertaForTokenClassification"), | |||
("bert", "BertForTokenClassification"), | |||
("gpt2", "GPT2ForTokenClassification"), | |||
] | |||
) | |||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( | |||
[ | |||
# Model for Multiple Choice mapping | |||
("roberta", "RobertaForMultipleChoice"), | |||
("bert", "BertForMultipleChoice"), | |||
] | |||
) | |||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( | |||
[ | |||
("bert", "BertForNextSentencePrediction"), | |||
] | |||
) | |||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict([]) | |||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) | |||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) | |||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) | |||
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) | |||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES | |||
) | |||
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) | |||
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) | |||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES | |||
) | |||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES | |||
) | |||
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES | |||
) | |||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES | |||
) | |||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES | |||
) | |||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) | |||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES | |||
) | |||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( | |||
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES | |||
) | |||
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) | |||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) | |||
class AutoModel(_BaseAutoModelClass): | |||
_model_mapping = MODEL_MAPPING | |||
AutoModel = auto_class_update(AutoModel) | |||
class AutoModelForPreTraining(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING | |||
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") | |||
# Private on purpose, the public class will add the deprecation warnings. | |||
class _AutoModelWithLMHead(_BaseAutoModelClass): | |||
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING | |||
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") | |||
class AutoModelForCausalLM(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING | |||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") | |||
class AutoModelForMaskedLM(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING | |||
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") | |||
class AutoModelForSeq2SeqLM(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING | |||
AutoModelForSeq2SeqLM = auto_class_update( | |||
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base" | |||
) | |||
class AutoModelForSequenceClassification(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING | |||
AutoModelForSequenceClassification = auto_class_update( | |||
AutoModelForSequenceClassification, head_doc="sequence classification" | |||
) | |||
class AutoModelForQuestionAnswering(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING | |||
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") | |||
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING | |||
AutoModelForTableQuestionAnswering = auto_class_update( | |||
AutoModelForTableQuestionAnswering, | |||
head_doc="table question answering", | |||
checkpoint_for_example="google/tapas-base-finetuned-wtq", | |||
) | |||
class AutoModelForTokenClassification(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |||
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") | |||
class AutoModelForMultipleChoice(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING | |||
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") | |||
class AutoModelForNextSentencePrediction(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING | |||
AutoModelForNextSentencePrediction = auto_class_update( | |||
AutoModelForNextSentencePrediction, head_doc="next sentence prediction" | |||
) | |||
class AutoModelForImageClassification(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING | |||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") | |||
class AutoModelForObjectDetection(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING | |||
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") | |||
class AutoModelForAudioClassification(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING | |||
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") | |||
class AutoModelForCTC(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_CTC_MAPPING | |||
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") | |||
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): | |||
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING | |||
AutoModelForSpeechSeq2Seq = auto_class_update( | |||
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing" | |||
) | |||
class AutoModelWithLMHead(_AutoModelWithLMHead): | |||
@classmethod | |||
def from_config(cls, config): | |||
warnings.warn( | |||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | |||
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | |||
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", | |||
FutureWarning, | |||
) | |||
return super().from_config(config) | |||
@classmethod | |||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |||
warnings.warn( | |||
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " | |||
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " | |||
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", | |||
FutureWarning, | |||
) | |||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
@@ -29,171 +29,9 @@ if TYPE_CHECKING: | |||
else: | |||
TOKENIZER_MAPPING_NAMES = OrderedDict( | |||
[ | |||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), | |||
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), | |||
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), | |||
( | |||
"t5", | |||
( | |||
"T5Tokenizer" if is_sentencepiece_available() else None, | |||
"T5TokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"mt5", | |||
( | |||
"MT5Tokenizer" if is_sentencepiece_available() else None, | |||
"MT5TokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), | |||
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), | |||
( | |||
"albert", | |||
( | |||
"AlbertTokenizer" if is_sentencepiece_available() else None, | |||
"AlbertTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"camembert", | |||
( | |||
"CamembertTokenizer" if is_sentencepiece_available() else None, | |||
"CamembertTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"pegasus", | |||
( | |||
"PegasusTokenizer" if is_sentencepiece_available() else None, | |||
"PegasusTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"mbart", | |||
( | |||
"MBartTokenizer" if is_sentencepiece_available() else None, | |||
"MBartTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"xlm-roberta", | |||
( | |||
"XLMRobertaTokenizer" if is_sentencepiece_available() else None, | |||
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), | |||
("blenderbot-small", ("BlenderbotSmallTokenizer", None)), | |||
("blenderbot", ("BlenderbotTokenizer", None)), | |||
("bart", ("BartTokenizer", "BartTokenizerFast")), | |||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), | |||
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), | |||
( | |||
"reformer", | |||
( | |||
"ReformerTokenizer" if is_sentencepiece_available() else None, | |||
"ReformerTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), | |||
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), | |||
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), | |||
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), | |||
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), | |||
( | |||
"dpr", | |||
( | |||
"DPRQuestionEncoderTokenizer", | |||
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"squeezebert", | |||
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), | |||
), | |||
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), | |||
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), | |||
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), | |||
("transfo-xl", ("TransfoXLTokenizer", None)), | |||
( | |||
"xlnet", | |||
( | |||
"XLNetTokenizer" if is_sentencepiece_available() else None, | |||
"XLNetTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("flaubert", ("FlaubertTokenizer", None)), | |||
("xlm", ("XLMTokenizer", None)), | |||
("ctrl", ("CTRLTokenizer", None)), | |||
("fsmt", ("FSMTTokenizer", None)), | |||
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)), | |||
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)), | |||
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)), | |||
("rag", ("RagTokenizer", None)), | |||
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), | |||
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), | |||
("speech_to_text_2", ("Speech2Text2Tokenizer", None)), | |||
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), | |||
("prophetnet", ("ProphetNetTokenizer", None)), | |||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), | |||
("tapas", ("TapasTokenizer", None)), | |||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), | |||
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), | |||
( | |||
"big_bird", | |||
( | |||
"BigBirdTokenizer" if is_sentencepiece_available() else None, | |||
"BigBirdTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), | |||
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), | |||
("hubert", ("Wav2Vec2CTCTokenizer", None)), | |||
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), | |||
("luke", ("LukeTokenizer", None)), | |||
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), | |||
("canine", ("CanineTokenizer", None)), | |||
("bertweet", ("BertweetTokenizer", None)), | |||
("bert-japanese", ("BertJapaneseTokenizer", None)), | |||
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")), | |||
("byt5", ("ByT5Tokenizer", None)), | |||
( | |||
"cpm", | |||
( | |||
"CpmTokenizer" if is_sentencepiece_available() else None, | |||
"CpmTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), | |||
("phobert", ("PhobertTokenizer", None)), | |||
( | |||
"barthez", | |||
( | |||
"BarthezTokenizer" if is_sentencepiece_available() else None, | |||
"BarthezTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"mbart50", | |||
( | |||
"MBart50Tokenizer" if is_sentencepiece_available() else None, | |||
"MBart50TokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"rembert", | |||
( | |||
"RemBertTokenizer" if is_sentencepiece_available() else None, | |||
"RemBertTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
( | |||
"clip", | |||
( | |||
"CLIPTokenizer", | |||
"CLIPTokenizerFast" if is_tokenizers_available() else None, | |||
), | |||
), | |||
("bart", ("BartTokenizer", None)), | |||
("roberta", ("RobertaTokenizer", None)), | |||
("bert", ("BertTokenizer", None)), | |||
("gpt2", ("GPT2Tokenizer", None)), | |||
] | |||
) |
@@ -1,5 +1,6 @@ | |||
__all__ = [ | |||
"CPT_PRETRAINED_MODEL_ARCHIVE_LIST", | |||
"CPTConfig", | |||
"CPTForConditionalGeneration", | |||
"CPTForSequenceClassification", | |||
"CPTForMaskedLM", | |||
@@ -9,4 +10,4 @@ __all__ = [ | |||
] | |||
from .modeling_cpt import CPT_PRETRAINED_MODEL_ARCHIVE_LIST, CPTForConditionalGeneration, CPTForSequenceClassification, \ | |||
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel | |||
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel, CPTConfig |