Browse Source

dataset增加set_pad, set_ignore

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
a9316b0209
10 changed files with 1213 additions and 366 deletions
  1. +1
    -1
      docs/source/_templates/versions.html
  2. +7
    -4
      fastNLP/core/dataset/dataset.py
  3. +14
    -3
      fastNLP/transformers/torch/file_utils.py
  4. +83
    -0
      fastNLP/transformers/torch/models/auto/__init__.py
  5. +565
    -0
      fastNLP/transformers/torch/models/auto/auto_factory.py
  6. +3
    -191
      fastNLP/transformers/torch/models/auto/configuration_auto.py
  7. +208
    -0
      fastNLP/transformers/torch/models/auto/dynamic.py
  8. +326
    -0
      fastNLP/transformers/torch/models/auto/modeling_auto.py
  9. +4
    -166
      fastNLP/transformers/torch/models/auto/tokenization_auto.py
  10. +2
    -1
      fastNLP/transformers/torch/models/cpt/__init__.py

+ 1
- 1
docs/source/_templates/versions.html View File

@@ -2,7 +2,7 @@
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: {{ current_version.name }}
{{ current_version.name }}
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">


+ 7
- 4
fastNLP/core/dataset/dataset.py View File

@@ -974,9 +974,10 @@ class DataSet:
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> Collator:
"""
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class: `~fastNLP.core.collators.Collator`
时该函数才有效。调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
:param field_name: 需要调整的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
@@ -999,12 +1000,14 @@ class DataSet:

def set_ignore(self, *field_names) -> Collator:
"""
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class: `~fastNLP.core.collators.Collator`
时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略。

Example::

collator.set_ignore('field1', 'field2')

:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
:param field_names: 需要忽略的 field 的名称。如果 DataSet 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身


+ 14
- 3
fastNLP/transformers/torch/file_utils.py View File

@@ -10,6 +10,8 @@ import sys
import tarfile
import tempfile
import operator
import types
import functools
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from dataclasses import fields
@@ -37,6 +39,8 @@ if _NEED_IMPORT_TORCH:
import torch
_torch_version = importlib_metadata.version("torch")

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
@@ -45,10 +49,9 @@ default_cache_path = os.path.join(hf_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES

WEIGHTS_NAME = "pytorch_model.bin"
@@ -1043,3 +1046,11 @@ class TensorType(ExplicitEnum):

PYTORCH = "pt"
NUMPY = "np"

def copy_func(f):
"""Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g

+ 83
- 0
fastNLP/transformers/torch/models/auto/__init__.py View File

@@ -0,0 +1,83 @@
__all__ = [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"CONFIG_MAPPING",
"MODEL_NAMES_MAPPING",
"AutoConfig",
"TOKENIZER_MAPPING_NAMES",
"get_values",
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
"MODEL_FOR_PRETRAINING_MAPPING",
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
"AutoModelForNextSentencePrediction",
"AutoModelForObjectDetection",
"AutoModelForPreTraining",
"AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelWithLMHead",
]

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, \
AutoConfig
from .tokenization_auto import TOKENIZER_MAPPING_NAMES
from .auto_factory import get_values
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForObjectDetection,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
)

+ 565
- 0
fastNLP/transformers/torch/models/auto/auto_factory.py View File

@@ -0,0 +1,565 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import importlib
from collections import OrderedDict

from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
from .dynamic import get_class_from_dynamic_module
from fastNLP.transformers.torch.configuration_utils import PretrainedConfig
from fastNLP.transformers.torch.file_utils import copy_func
from fastNLP.core.log import logger


CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
:meth:`~transformers.BaseAutoModelClass.from_config` class method.

This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

FROM_CONFIG_DOCSTRING = """
Instantiates one of the model classes of the library from a configuration.

Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model
weights.

Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:

List options

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained('checkpoint_placeholder')
>>> model = BaseAutoModelClass.from_config(config)
"""

FROM_PRETRAINED_TORCH_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are
deactivated). To train the model, you should first set it back in training mode with ``model.train()``

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict (`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file.

This option can be used if you want to create a model from a pretrained configuration but load your own
weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""

FROM_PRETRAINED_TF_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
"""

FROM_PRETRAINED_FLAX_DOCSTRING = """
Instantiate one of the model classes of the library from a pretrained model.

The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:

List options

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:

- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
afterwards.
model_args (additional positional arguments, `optional`):
Will be passed along to the underlying model ``__init__()`` method.
config (:class:`~transformers.PretrainedConfig`, `optional`):
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:

- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:

- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.

Examples::

>>> from transformers import AutoConfig, BaseAutoModelClass

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')

>>> # Update configuration during loading
>>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
>>> model.config.output_attentions
True

>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
>>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
"""


def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models

name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]

# If not architecture is set in the config or match the supported models, the first element of the tuple is the
# defaults.
return supported_models[0]


class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None

def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)

@classmethod
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warn(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)


def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(
"one of the model classes of the library ",
f"one of the model classes of the library (with a {head_doc} head) ",
)
return docstring.replace(
"one of the model classes of the library ", "one of the base model classes of the library "
)


def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
# Create a new class with the right name from the base class
model_mapping = cls._model_mapping
name = cls.__name__
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)

# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
# have a specific docstrings for them.
from_config = copy_func(_BaseAutoModelClass.from_config)
from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
cls.from_config = classmethod(from_config)

if name.startswith("TF"):
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
elif name.startswith("Flax"):
from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
else:
from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls


def get_values(model_mapping):
result = []
for model in model_mapping.values():
if isinstance(model, (list, tuple)):
result += list(model)
else:
result.append(model)

return result


def getattribute_from_module(module, attr):
if attr is None:
return None
if isinstance(attr, tuple):
return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr):
return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level.
transformers_module = importlib.import_module("transformers")
return getattribute_from_module(transformers_module, attr)


class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

Args:

- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""

def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._modules = {}

def __getitem__(self, key):
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)

def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
try:
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
except ImportError:
raise ImportError(f"fastNLP transformers does not support {module_name} now, please install and import `transformers` to use it.")
return getattribute_from_module(self._modules[module_name], attr)

def keys(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]

def get(self, key, default):
try:
return self.__getitem__(key)
except KeyError:
return default

def __bool__(self):
return bool(self.keys())

def values(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]

def items(self):
return [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]

def __iter__(self):
return iter(self._mapping.keys())

def __contains__(self, item):
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping

+ 3
- 191
fastNLP/transformers/torch/models/auto/configuration_auto.py View File

@@ -26,221 +26,33 @@ from fastNLP.core.log import logger
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("fnet", "FNetConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
("splinter", "SplinterConfig"),
("cpt", "CPTConfig"),
]
)

CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("cpt", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)

MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("fnet", "FNet"),
("gptj", "GPT-J"),
("beit", "BeiT"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"),
("canine", "Canine"),
("roformer", "RoFormer"),
("clip", "CLIP"),
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
("detr", "DETR"),
("gpt_neo", "GPT Neo"),
("big_bird", "BigBird"),
("speech_to_text_2", "Speech2Text2"),
("speech_to_text", "Speech2Text"),
("vit", "ViT"),
("wav2vec2", "Wav2Vec2"),
("m2m_100", "M2M100"),
("convbert", "ConvBERT"),
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),
("albert", "ALBERT"),
("bert-generation", "Bert Generation"),
("camembert", "CamemBERT"),
("xlm-roberta", "XLM-RoBERTa"),
("pegasus", "Pegasus"),
("blenderbot", "Blenderbot"),
("marian", "Marian"),
("mbart", "mBART"),
("megatron-bert", "MegatronBert"),
("bart", "BART"),
("reformer", "Reformer"),
("longformer", "Longformer"),
("roberta", "RoBERTa"),
("flaubert", "FlauBERT"),
("fsmt", "FairSeq Machine-Translation"),
("squeezebert", "SqueezeBERT"),
("bert", "BERT"),
("openai-gpt", "OpenAI GPT"),
("gpt2", "OpenAI GPT-2"),
("transfo-xl", "Transformer-XL"),
("xlnet", "XLNet"),
("xlm", "XLM"),
("ctrl", "CTRL"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
("deberta", "DeBERTa"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
("splinter", "Splinter"),
("cpt", "CPT")
]
)



+ 208
- 0
fastNLP/transformers/torch/models/auto/dynamic.py View File

@@ -0,0 +1,208 @@
import importlib
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union

from fastNLP.transformers.torch.file_utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
)
from fastNLP.core.log import logger

def init_hf_modules():
"""
Creates the cache directory for modules with an init, and adds it to the Python path.
"""
# This function has already been executed if HF_MODULES_CACHE already is in the Python path.
if HF_MODULES_CACHE in sys.path:
return

sys.path.append(HF_MODULES_CACHE)
os.makedirs(HF_MODULES_CACHE, exist_ok=True)
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
if not init_path.exists():
init_path.touch()

def create_dynamic_module(name: Union[str, os.PathLike]):
"""
Creates a dynamic module in the cache directory for modules.
"""
init_hf_modules()
dynamic_module_path = Path(HF_MODULES_CACHE) / name
# If the parent module does not exist yet, recursively create it.
if not dynamic_module_path.parent.exists():
create_dynamic_module(dynamic_module_path.parent)
os.makedirs(dynamic_module_path, exist_ok=True)
init_path = dynamic_module_path / "__init__.py"
if not init_path.exists():
init_path.touch()

def check_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()

# Imports of the form `import xxx`
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

# Unique-ify and test we got them all
imports = list(set(imports))
missing_packages = []
for imp in imports:
try:
importlib.import_module(imp)
except ImportError:
missing_packages.append(imp)

if len(missing_packages) > 0:
raise ImportError(
"This modeling file requires the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)


def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
return getattr(module, class_name)

def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
class_name: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.

.. warning::

Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
should therefore only be called on trusted repos.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:

- a string, the `model id` of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a configuration file saved using the
:func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., ``./my_model_directory/``.

module_file (:obj:`str`):
The name of the module file containing the class to look for.
class_name (:obj:`str`):
The name of the class to import in the module.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, will only try to load the tokenizer configuration from local files.

.. note::

Passing :obj:`use_auth_token=True` is required when you want to use a private model.


Returns:
:obj:`type`: The class, dynamically imported from the module.

Examples::

# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
submodule = "local"
else:
module_file_or_url = hf_bucket_url(
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

try:
# Load from URL or cache if already cached
resolved_module_file = cached_path(
module_file_or_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)

except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise

# Check we have all the requirements in our environment
check_imports(resolved_module_file)

# Now we move the module inside our cached dynamic modules.
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == "local":
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
module_name = module_file
shutil.copy(resolved_module_file, submodule_path / module_file)
else:
# The module file will end up being named module_file + the etag. This way we get the benefit of versioning.
resolved_module_file_name = Path(resolved_module_file).name
module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".")
module_name = "_".join(module_name_parts) + ".py"
if not (submodule_path / module_name).exists():
shutil.copy(resolved_module_file, submodule_path / module_name)

# And lastly we get the class inside our newly created module
final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
return get_class_in_module(class_name, final_module)

+ 326
- 0
fastNLP/transformers/torch/models/auto/modeling_auto.py View File

@@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """

import warnings
from collections import OrderedDict

from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
from fastNLP.core.log import logger


MODEL_MAPPING_NAMES = OrderedDict(
[
("bart", "BartModel"),
("roberta", "RobertaModel"),
("bert", "BertModel"),
("gpt2", "GPT2Model"),
("cpt", "CPTModel"),
]
)

MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
("bart", "BartForConditionalGeneration"),
("roberta", "RobertaForMaskedLM"),
("bert", "BertForPreTraining"),
("gpt2", "GPT2LMHeadModel"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("bart", "BartForConditionalGeneration"),
("roberta", "RobertaForMaskedLM"),
("bert", "BertForMaskedLM"),
("gpt2", "GPT2LMHeadModel"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("bart", "BartForCausalLM"),
]
)

MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("bart", "BartForConditionalGeneration"),
("roberta", "RobertaForMaskedLM"),
("bert", "BertForMaskedLM"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("bart", "BartForConditionalGeneration"),
("cpt", "CPTForConditionalGeneration"),
]
)

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("bart", "BartForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("cpt", "CPTForSequenceClassification"),
]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("bart", "BartForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("cpt", "CPTForQuestionAnswering"),
]
)

MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("roberta", "RobertaForTokenClassification"),
("bert", "BertForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
]
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("roberta", "RobertaForMultipleChoice"),
("bert", "BertForMultipleChoice"),
]
)

MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "BertForNextSentencePrediction"),
]
)

MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([])

MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict([])

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)


class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING


AutoModel = auto_class_update(AutoModel)


class AutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")


class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


class AutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


class AutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
AutoModelForSequenceClassification, head_doc="sequence classification"
)


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
AutoModelForTableQuestionAnswering,
head_doc="table question answering",
checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class AutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


class AutoModelForMultipleChoice(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class AutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


class AutoModelForObjectDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


class AutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


class AutoModelForCTC(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing"
)


class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_config(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

+ 4
- 166
fastNLP/transformers/torch/models/auto/tokenization_auto.py View File

@@ -29,171 +29,9 @@ if TYPE_CHECKING:
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"camembert",
(
"CamembertTokenizer" if is_sentencepiece_available() else None,
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
(
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart50",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"rembert",
(
"RemBertTokenizer" if is_sentencepiece_available() else None,
"RemBertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"clip",
(
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
("bart", ("BartTokenizer", None)),
("roberta", ("RobertaTokenizer", None)),
("bert", ("BertTokenizer", None)),
("gpt2", ("GPT2Tokenizer", None)),
]
)

+ 2
- 1
fastNLP/transformers/torch/models/cpt/__init__.py View File

@@ -1,5 +1,6 @@
__all__ = [
"CPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"CPTConfig",
"CPTForConditionalGeneration",
"CPTForSequenceClassification",
"CPTForMaskedLM",
@@ -9,4 +10,4 @@ __all__ = [
]

from .modeling_cpt import CPT_PRETRAINED_MODEL_ARCHIVE_LIST, CPTForConditionalGeneration, CPTForSequenceClassification, \
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel
CPTForMaskedLM, CPTForQuestionAnswering, CPTModel, CPTPretrainedModel, CPTConfig

Loading…
Cancel
Save