Features:
1. Refactor the directory structure of the NLP models. All model files are now placed in either the model folder or the task_model folder.
2. Refactor all comments to Google style.
3. Add detailed comments to important tasks and NLP models, describing each model and its preprocessor and trainer.
4. Model exporting now supports a direct call to TorchModelExporter (no need to derive from it).
5. Refactor the model's save_pretrained method to support direct calls (independent from the trainer).
6. Remove the Model type check in the pipeline base class, so models registered outside the library can run in our pipelines.
7. The NLP trainer now has an NLPTrainingArguments class; users can pass arguments into the dataclass and use it as a normal cfg_modify_fn, which simplifies modifying the cfg.
8. Merge BACKBONES into MODELS, so users can get a backbone with a Model.from_pretrained call.
9. Model.from_pretrained now supports a task argument, so users can load a backbone with a specific task class (see the usage sketch below).
10. Support the Preprocessor.from_pretrained method.
11. Add standard return classes to important NLP tasks, decoupling some pipelines from their models: models always return tensors, and the pipelines take care of the conversion to numpy and the subsequent post-processing.
12. Split the NLP preprocessor file to make the directory structure clearer.

Bug Fixes:
1. Fix a bug where the lr_scheduler could be called earlier than the optimizer's step.
2. Fix a bug where calling a Pipeline directly (not created via pipeline(xxx)) throws an error.
3. Fix a bug where the trainer does not call the correct TaskDataset class.
4. Fix a bug where the internal loading of a dataset throws an error in the trainer class.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10490585master
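The new user-facing calls from features 7, 9 and 10 might look roughly like the sketch below. This is an illustration rather than a snippet from this change set: the model id, the task string and the NLPTrainingArguments usage in the comments are assumptions.

```python
from modelscope.models import Model
from modelscope.preprocessors import Preprocessor

# Feature 9: override the task declared in configuration.json at load time,
# e.g. load a backbone checkpoint with a specific task class.
model = Model.from_pretrained(
    'damo/nlp_structbert_backbone_base_std',   # illustrative model id
    task='text-classification')

# Feature 10: build the matching preprocessor from the same model repo.
preprocessor = Preprocessor.from_pretrained(
    'damo/nlp_structbert_backbone_base_std')   # illustrative model id

# Feature 7 (import path and field names are not shown in this change set):
#   args = NLPTrainingArguments(max_epochs=3, lr=2e-5)
#   trainer = build_trainer(default_args=dict(model=model, cfg_modify_fn=args))
```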
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280 | |||
size 119940 | |||
oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96 | |||
size 121855 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705 | |||
size 119619 | |||
oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71 | |||
size 121563 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a | |||
size 119619 | |||
oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536 | |||
size 121563 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 | |||
size 151572 | |||
oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2 | |||
size 152926 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b | |||
size 61741 | |||
oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399 | |||
size 63161 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 | |||
size 61745 | |||
oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2 | |||
size 63165 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:33ecc221513559a042ff975a38cc16aa47674545bc349362722c774c83f8d90c | |||
size 61239 | |||
oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f | |||
size 63597 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:803c2e3ff7688abf0f83702b3904830a9f6f71e41e252de3c559354a9effefd1 | |||
size 61115 | |||
oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030 | |||
size 63349 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85 | |||
size 61589 | |||
oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5 | |||
size 62519 |
@@ -19,10 +19,13 @@ class Exporter(ABC): | |||
def from_model(cls, model: Model, **kwargs): | |||
"""Build the Exporter instance. | |||
@param model: A model instance. it will be used to output the generated file, | |||
Args: | |||
model: A Model instance. it will be used to generate the intermediate format file, | |||
and the configuration.json in its model_dir field will be used to create the exporter instance. | |||
@param kwargs: Extra kwargs used to create the Exporter instance. | |||
@return: The Exporter instance | |||
kwargs: Extra kwargs used to create the Exporter instance. | |||
Returns: | |||
The Exporter instance | |||
""" | |||
cfg = Config.from_file( | |||
os.path.join(model.model_dir, ModelFile.CONFIGURATION)) | |||
@@ -44,10 +47,13 @@ class Exporter(ABC): | |||
In some cases, several files may be generated, | |||
So please return a dict which contains the generated name with the file path. | |||
@param opset: The version of the ONNX operator set to use. | |||
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
@return: A dict contains the model name with the model file path. | |||
Args: | |||
opset: The version of the ONNX operator set to use. | |||
outputs: The output dir. | |||
kwargs: In this default implementation, | |||
kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
Returns: | |||
A dict contains the model name with the model file path. | |||
""" | |||
pass |
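The reworked docstrings above describe the from_model/export_onnx contract. A minimal sketch of that contract, assuming Exporter is re-exported from modelscope.exporters and that the model's configuration.json declares a usable exporter; the model id is illustrative:

```python
from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')  # illustrative id

# from_model reads configuration.json from model.model_dir to build the exporter.
exporter = Exporter.from_model(model)

# export_onnx returns a dict mapping generated-file names to paths; the output
# directory is passed positionally to stay agnostic to the parameter name
# (`outputs` in the base class, `output_dir` in TorchModelExporter).
files = exporter.export_onnx('/tmp/onnx_out', opset=13)
print(files)  # e.g. {'model': '/tmp/onnx_out/model.onnx'}
```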
@@ -27,11 +27,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
**kwargs) -> Dict[str, Any]: | |||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
@param shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | |||
@param pair: Generate sentence pairs or single sentences for dummy inputs. | |||
@return: Dummy inputs. | |||
Args: | |||
shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor. | |||
pair(bool, `optional`): Whether to generate sentence pairs or single sentences. | |||
Returns: | |||
Dummy inputs. | |||
""" | |||
cfg = Config.from_file( | |||
@@ -13,8 +13,8 @@ from modelscope.models import TorchModel | |||
from modelscope.pipelines.base import collate_fn | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.regress_test_utils import compare_arguments_nested | |||
from modelscope.utils.tensor_utils import torch_nested_numpify | |||
from modelscope.utils.regress_test_utils import (compare_arguments_nested, | |||
numpify_tensor_nested) | |||
from .base import Exporter | |||
logger = get_logger(__name__) | |||
@@ -28,49 +28,61 @@ class TorchModelExporter(Exporter): | |||
and to provide implementations for generate_dummy_inputs/inputs/outputs methods. | |||
""" | |||
def export_onnx(self, outputs: str, opset=11, **kwargs): | |||
def export_onnx(self, output_dir: str, opset=13, **kwargs): | |||
"""Export the model as onnx format files. | |||
In some cases, several files may be generated, | |||
So please return a dict which contains the generated name with the file path. | |||
@param opset: The version of the ONNX operator set to use. | |||
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
you can pass the arguments needed by _torch_export_onnx, other unrecognized args | |||
will be carried to generate_dummy_inputs as extra arguments (such as input shape). | |||
@return: A dict containing the model key - model file path pairs. | |||
Args: | |||
opset: The version of the ONNX operator set to use. | |||
output_dir: The output dir. | |||
kwargs: | |||
model: A model instance which will replace the exporting of self.model. | |||
In this default implementation, | |||
you can pass the arguments needed by _torch_export_onnx, other unrecognized args | |||
will be carried to generate_dummy_inputs as extra arguments (such as input shape). | |||
Returns: | |||
A dict containing the model key - model file path pairs. | |||
""" | |||
model = self.model | |||
model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
model = model.model | |||
onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE) | |||
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) | |||
self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) | |||
return {'model': onnx_file} | |||
def export_torch_script(self, outputs: str, **kwargs): | |||
def export_torch_script(self, output_dir: str, **kwargs): | |||
"""Export the model as torch script files. | |||
In some cases, several files may be generated, | |||
So please return a dict which contains the generated name with the file path. | |||
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
Args: | |||
output_dir: The output dir. | |||
kwargs: | |||
model: A model instance which will replace the exporting of self.model. | |||
In this default implementation, | |||
you can pass the arguments needed by _torch_export_torch_script, other unrecognized args | |||
will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
@return: A dict contains the model name with the model file path. | |||
Returns: | |||
A dict contains the model name with the model file path. | |||
""" | |||
model = self.model | |||
model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
model = model.model | |||
ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE) | |||
ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE) | |||
# generate ts by tracing | |||
self._torch_export_torch_script(model, ts_file, **kwargs) | |||
return {'model': ts_file} | |||
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: | |||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
@return: Dummy inputs. | |||
Returns: | |||
Dummy inputs. | |||
""" | |||
return None | |||
@@ -93,7 +105,7 @@ class TorchModelExporter(Exporter): | |||
def _torch_export_onnx(self, | |||
model: nn.Module, | |||
output: str, | |||
opset: int = 11, | |||
opset: int = 13, | |||
device: str = 'cpu', | |||
validation: bool = True, | |||
rtol: float = None, | |||
@@ -101,18 +113,27 @@ class TorchModelExporter(Exporter): | |||
**kwargs): | |||
"""Export the model to an onnx format file. | |||
@param model: A torch.nn.Module instance to export. | |||
@param output: The output file. | |||
@param opset: The version of the ONNX operator set to use. | |||
@param device: The device used to forward. | |||
@param validation: Whether validate the export file. | |||
@param rtol: The rtol used to regress the outputs. | |||
@param atol: The atol used to regress the outputs. | |||
Args: | |||
model: A torch.nn.Module instance to export. | |||
output: The output file. | |||
opset: The version of the ONNX operator set to use. | |||
device: The device used to forward. | |||
validation: Whether validate the export file. | |||
rtol: The rtol used to regress the outputs. | |||
atol: The atol used to regress the outputs. | |||
kwargs: | |||
dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). | |||
inputs: An inputs structure which will replace the calling of self.inputs. | |||
outputs: An outputs structure which will replace the calling of self.outputs. | |||
""" | |||
dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
inputs = self.inputs | |||
outputs = self.outputs | |||
dummy_inputs = self.generate_dummy_inputs( | |||
**kwargs) if 'dummy_inputs' not in kwargs else kwargs.pop( | |||
'dummy_inputs') | |||
inputs = self.inputs if 'inputs' not in kwargs else kwargs.pop( | |||
'inputs') | |||
outputs = self.outputs if 'outputs' not in kwargs else kwargs.pop( | |||
'outputs') | |||
if dummy_inputs is None or inputs is None or outputs is None: | |||
raise NotImplementedError( | |||
'Model property dummy_inputs,inputs,outputs must be set.') | |||
@@ -125,7 +146,7 @@ class TorchModelExporter(Exporter): | |||
if isinstance(dummy_inputs, Mapping): | |||
dummy_inputs = dict(dummy_inputs) | |||
onnx_outputs = list(self.outputs.keys()) | |||
onnx_outputs = list(outputs.keys()) | |||
with replace_call(): | |||
onnx_export( | |||
@@ -160,11 +181,13 @@ class TorchModelExporter(Exporter): | |||
outputs_origin = model.forward( | |||
*_decide_input_format(model, dummy_inputs)) | |||
if isinstance(outputs_origin, Mapping): | |||
outputs_origin = torch_nested_numpify( | |||
outputs_origin = numpify_tensor_nested( | |||
list(outputs_origin.values())) | |||
elif isinstance(outputs_origin, (tuple, list)): | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
outputs = ort_session.run( | |||
onnx_outputs, | |||
torch_nested_numpify(dummy_inputs), | |||
numpify_tensor_nested(dummy_inputs), | |||
) | |||
tols = {} | |||
@@ -184,19 +207,26 @@ class TorchModelExporter(Exporter): | |||
validation: bool = True, | |||
rtol: float = None, | |||
atol: float = None, | |||
strict: bool = True, | |||
**kwargs): | |||
"""Export the model to a torch script file. | |||
@param model: A torch.nn.Module instance to export. | |||
@param output: The output file. | |||
@param device: The device used to forward. | |||
@param validation: Whether validate the export file. | |||
@param rtol: The rtol used to regress the outputs. | |||
@param atol: The atol used to regress the outputs. | |||
Args: | |||
model: A torch.nn.Module instance to export. | |||
output: The output file. | |||
device: The device used to forward. | |||
validation: Whether validate the export file. | |||
rtol: The rtol used to regress the outputs. | |||
atol: The atol used to regress the outputs. | |||
strict: strict mode in torch script tracing. | |||
kwargs: | |||
dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). | |||
""" | |||
model.eval() | |||
dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
dummy_param = 'dummy_inputs' not in kwargs | |||
dummy_inputs = self.generate_dummy_inputs( | |||
**kwargs) if dummy_param else kwargs.pop('dummy_inputs') | |||
if dummy_inputs is None: | |||
raise NotImplementedError( | |||
'Model property dummy_inputs must be set.') | |||
@@ -207,7 +237,7 @@ class TorchModelExporter(Exporter): | |||
model.eval() | |||
with replace_call(): | |||
traced_model = torch.jit.trace( | |||
model, dummy_inputs, strict=False) | |||
model, dummy_inputs, strict=strict) | |||
torch.jit.save(traced_model, output) | |||
if validation: | |||
@@ -216,9 +246,9 @@ class TorchModelExporter(Exporter): | |||
model.eval() | |||
ts_model.eval() | |||
outputs = ts_model.forward(*dummy_inputs) | |||
outputs = torch_nested_numpify(outputs) | |||
outputs = numpify_tensor_nested(outputs) | |||
outputs_origin = model.forward(*dummy_inputs) | |||
outputs_origin = torch_nested_numpify(outputs_origin) | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
tols = {} | |||
if rtol is not None: | |||
tols['rtol'] = rtol | |||
@@ -240,7 +270,6 @@ def replace_call(): | |||
problems. Here we recover the call method to the default implementation of torch.nn.Module, and change it | |||
back after the tracing was done. | |||
""" | |||
TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl | |||
yield | |||
TorchModel.__call__ = TorchModel.call_origin | |||
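Together these hunks enable feature 4: TorchModelExporter can be used directly, with the model, dummy inputs and input/output structures supplied as keyword arguments instead of subclass overrides. The sketch below assumes TorchModelExporter is importable from modelscope.exporters and that a bare constructor is allowed (not shown in this diff); the TinyClassifier module, tensor shapes and axis names are hypothetical.

```python
from collections import OrderedDict

import torch
from torch import nn

from modelscope.exporters import TorchModelExporter


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a real ModelScope model's inner nn.Module."""

    def __init__(self, vocab_size=30522, hidden=32, num_labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, num_labels)

    def forward(self, input_ids):
        return {'logits': self.head(self.embed(input_ids).mean(dim=1))}


exporter = TorchModelExporter()  # assumes a bare constructor is allowed

# export_onnx / _torch_export_onnx pop `model`, `dummy_inputs`, `inputs` and
# `outputs` from kwargs, so none of them has to come from a subclass any more.
files = exporter.export_onnx(
    output_dir='/tmp/onnx_out',
    opset=13,
    model=TinyClassifier(),
    dummy_inputs={'input_ids': torch.ones(1, 128, dtype=torch.long)},
    # dynamic-axes style mappings, assumed to be what `inputs`/`outputs` expect:
    inputs=OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'})]),
    outputs=OrderedDict([('logits', {0: 'batch'})]),
)
print(files)  # e.g. {'model': '/tmp/onnx_out/model.onnx'}
```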
@@ -69,7 +69,6 @@ class Models(object): | |||
space_modeling = 'space-modeling' | |||
space_T_en = 'space-T-en' | |||
space_T_cn = 'space-T-cn' | |||
tcrf = 'transformer-crf' | |||
transformer_softmax = 'transformer-softmax' | |||
lcrf = 'lstm-crf' | |||
@@ -10,9 +10,6 @@ class Metric(ABC): | |||
complex metrics for a specific task with or without other Metric subclasses. | |||
""" | |||
def __init__(self, trainer=None, *args, **kwargs): | |||
self.trainer = trainer | |||
@abstractmethod | |||
def add(self, outputs: Dict, inputs: Dict): | |||
""" Append logits and labels within an eval loop. | |||
@@ -34,17 +34,24 @@ class TokenClassificationMetric(Metric): | |||
self.labels.append( | |||
torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
def __init__(self, return_entity_level_metrics=False, *args, **kwargs): | |||
def __init__(self, | |||
return_entity_level_metrics=False, | |||
label2id=None, | |||
*args, | |||
**kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.return_entity_level_metrics = return_entity_level_metrics | |||
self.preds = [] | |||
self.labels = [] | |||
self.label2id = label2id | |||
def evaluate(self): | |||
self.id2label = { | |||
id: label | |||
for label, id in self.trainer.label2id.items() | |||
} | |||
label2id = self.label2id | |||
if label2id is None: | |||
assert hasattr(self, 'trainer') | |||
label2id = self.trainer.label2id | |||
self.id2label = {id: label for label, id in label2id.items()} | |||
self.preds = np.concatenate(self.preds, axis=0) | |||
self.labels = np.concatenate(self.labels, axis=0) | |||
predictions = np.argmax(self.preds, axis=-1) | |||
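A small sketch of how the metric above can now be constructed without a trainer; the import path and the label mapping are assumptions:

```python
from modelscope.metrics.token_classification_metric import TokenClassificationMetric

# label2id can now be supplied at construction time, so evaluate() no longer
# requires a trainer with a label2id attribute to be attached to the metric.
metric = TokenClassificationMetric(
    return_entity_level_metrics=True,
    label2id={'O': 0, 'B-LOC': 1, 'I-LOC': 2},  # illustrative mapping
)
# metric.add(outputs, inputs) is called inside the eval loop as before, and
# metric.evaluate() builds id2label from the mapping passed in here.
```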
@@ -5,11 +5,11 @@ from abc import ABC, abstractmethod | |||
from typing import Any, Callable, Dict, List, Optional, Union | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.builder import build_model | |||
from modelscope.utils.checkpoint import save_pretrained | |||
from modelscope.models.builder import MODELS, build_model | |||
from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.device import device_placement, verify_device | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks | |||
from modelscope.utils.device import verify_device | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -66,7 +66,6 @@ class Model(ABC): | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
cfg_dict: Config = None, | |||
device: str = None, | |||
*model_args, | |||
**kwargs): | |||
""" Instantiate a model from local directory or remote model repo. Note | |||
that when loading from remote, the model revision can be specified. | |||
@@ -90,11 +89,11 @@ class Model(ABC): | |||
cfg = Config.from_file( | |||
osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||
task_name = cfg.task | |||
if 'task' in kwargs: | |||
task_name = kwargs.pop('task') | |||
model_cfg = cfg.model | |||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
model_cfg.type = model_cfg.model_type | |||
model_cfg.model_dir = local_model_dir | |||
for k, v in kwargs.items(): | |||
model_cfg[k] = v | |||
@@ -109,15 +108,19 @@ class Model(ABC): | |||
# dynamically add pipeline info to model for pipeline inference | |||
if hasattr(cfg, 'pipeline'): | |||
model.pipeline = cfg.pipeline | |||
if not hasattr(model, 'cfg'): | |||
model.cfg = cfg | |||
return model | |||
def save_pretrained(self, | |||
target_folder: Union[str, os.PathLike], | |||
save_checkpoint_names: Union[str, List[str]] = None, | |||
save_function: Callable = None, | |||
save_function: Callable = save_checkpoint, | |||
config: Optional[dict] = None, | |||
**kwargs): | |||
"""save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||
"""save the pretrained model, its configuration and other related files to a directory, | |||
so that it can be re-loaded | |||
Args: | |||
target_folder (Union[str, os.PathLike]): | |||
@@ -133,5 +136,10 @@ class Model(ABC): | |||
The config for the configuration.json, might not be identical with model.config | |||
""" | |||
if config is None and hasattr(self, 'cfg'): | |||
config = self.cfg | |||
assert config is not None, 'Cannot save the model because the model config is empty.' | |||
if isinstance(config, Config): | |||
config = config.to_dict() | |||
save_pretrained(self, target_folder, save_checkpoint_names, | |||
save_function, config, **kwargs) |
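With the defaults added above, save_pretrained works outside a trainer: save_function falls back to save_checkpoint and config falls back to the cfg that from_pretrained now attaches to the model (converted from Config to dict). A minimal sketch; the model id is illustrative:

```python
from modelscope.models import Model

model = Model.from_pretrained('damo/nlp_structbert_backbone_base_std')  # illustrative id

# No trainer involved: save_function defaults to save_checkpoint and the
# configuration is taken from model.cfg.
model.save_pretrained('./local_copy')

# An explicit config dict can still be passed to override model.cfg:
# model.save_pretrained('./local_copy', config=some_config_dict)
```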
@@ -1,10 +1,12 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.config import ConfigDict | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | |||
MODELS = Registry('models') | |||
BACKBONES = Registry('backbones') | |||
BACKBONES._modules = MODELS._modules | |||
HEADS = Registry('heads') | |||
@@ -23,30 +25,27 @@ def build_model(cfg: ConfigDict, | |||
cfg, MODELS, group_key=task_name, default_args=default_args) | |||
def build_backbone(cfg: ConfigDict, | |||
field: str = None, | |||
default_args: dict = None): | |||
def build_backbone(cfg: ConfigDict, default_args: dict = None): | |||
""" build backbone given backbone config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for backbone object. | |||
field (str, optional): field, such as CV, NLP's backbone | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, BACKBONES, group_key=field, default_args=default_args) | |||
cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) | |||
def build_head(cfg: ConfigDict, | |||
group_key: str = None, | |||
task_name: str = None, | |||
default_args: dict = None): | |||
""" build head given config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for head object. | |||
task_name (str, optional): task name, refer to | |||
:obj:`Tasks` for more details | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
if group_key is None: | |||
group_key = cfg[TYPE_NAME] | |||
return build_from_cfg( | |||
cfg, HEADS, group_key=group_key, default_args=default_args) | |||
cfg, HEADS, group_key=task_name, default_args=default_args) |
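Since BACKBONES now shares its module table with MODELS and build_backbone always looks up the Tasks.backbone group, registering and building a backbone might look like the sketch below; the module name and config are hypothetical.

```python
from modelscope.models.builder import BACKBONES, MODELS, build_backbone

# BACKBONES shares its module table with MODELS, so anything registered under
# the Tasks.backbone group is reachable from either registry (feature 8).
assert BACKBONES._modules is MODELS._modules

# A hypothetical backbone would be registered like:
#   @MODELS.register_module(group_key=Tasks.backbone, module_name='my-backbone')
#   class MyBackbone(TorchModel): ...

# build_backbone no longer takes a `field` argument; the group key is fixed
# to Tasks.backbone.
backbone = build_backbone(dict(type='my-backbone'))  # hypothetical module_name
```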
@@ -1,13 +1,17 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .t5_for_text_generation import T5ForConditionalGeneration | |||
from .backbone import T5Model | |||
from .text2text_generation import T5ForConditionalGeneration | |||
else: | |||
_import_structure = { | |||
't5_for_text_generation': ['T5ForConditionalGeneration'], | |||
'backbone': ['T5Model'], | |||
'text2text_generation': ['T5ForConditionalGeneration'], | |||
} | |||
import sys | |||
@@ -1,3 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2020, The T5 Authors and HuggingFace Inc. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); |
@@ -1,56 +0,0 @@ | |||
from typing import Optional, Tuple | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Tensor, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
from .modeling_t5 import T5Config | |||
from .modeling_t5 import T5ForConditionalGeneration as T5ForGeneration | |||
@MODELS.register_module( | |||
group_key=Tasks.text2text_generation, | |||
module_name=Models.T5, | |||
) | |||
class T5ForConditionalGeneration(TorchModel): | |||
def __init__(self, model_dir=None, *args, **kwargs): | |||
"""initialize the text generation model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
model_cls (Optional[Any], optional): model loader, if None, use the | |||
default loader to load model weights, by default None. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.model = T5ForGeneration.from_pretrained(model_dir) | |||
self.generate = self.model.generate | |||
self.config = self.model.config | |||
def forward(self, | |||
input_ids: Optional[torch.LongTensor] = None, | |||
attention_mask: Optional[torch.FloatTensor] = None, | |||
decoder_input_ids: Optional[torch.LongTensor] = None, | |||
decoder_attention_mask: Optional[torch.BoolTensor] = None, | |||
head_mask: Optional[torch.FloatTensor] = None, | |||
decoder_head_mask: Optional[torch.FloatTensor] = None, | |||
cross_attn_head_mask: Optional[torch.Tensor] = None, | |||
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
inputs_embeds: Optional[torch.FloatTensor] = None, | |||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None, | |||
labels: Optional[torch.LongTensor] = None, | |||
use_cache: Optional[bool] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
**kwargs): | |||
return self.model.forward( | |||
self, input_ids, attention_mask, decoder_input_ids, | |||
decoder_attention_mask, head_mask, decoder_head_mask, | |||
cross_attn_head_mask, encoder_outputs, past_key_values, | |||
inputs_embeds, decoder_inputs_embeds, labels, use_cache, | |||
output_attentions, output_hidden_states, return_dict, **kwargs) |
@@ -0,0 +1,455 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import copy | |||
import warnings | |||
from typing import Optional, Tuple, Union | |||
import torch | |||
from torch import nn | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.utils.model_parallel_utils import (assert_device_map, | |||
get_device_map) | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import BaseModelOutput, Seq2SeqLMOutput | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .backbone import T5PreTrainedModel, T5Stack | |||
from .configuration import T5Config | |||
logger = get_logger(__name__) | |||
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask | |||
__HEAD_MASK_WARNING_MSG = """ | |||
The input argument `head_mask` was split into two arguments `head_mask` and | |||
`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, | |||
but this feature is deprecated and will be removed in future versions. If you do | |||
not want to use any `decoder_head_mask` now, please set `decoder_head_mask = | |||
torch.ones(num_layers, num_heads)`. | |||
""" | |||
@MODELS.register_module( | |||
group_key=Tasks.text2text_generation, | |||
module_name=Models.T5, | |||
) | |||
class T5ForConditionalGeneration(T5PreTrainedModel): | |||
_keys_to_ignore_on_load_missing = [ | |||
r'encoder\.embed_tokens\.weight', | |||
r'decoder\.embed_tokens\.weight', | |||
r'lm_head\.weight', | |||
] | |||
_keys_to_ignore_on_load_unexpected = [ | |||
r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', | |||
] | |||
def __init__(self, config: T5Config): | |||
super().__init__(config) | |||
self.model_dim = config.d_model | |||
self.shared = nn.Embedding(config.vocab_size, config.d_model) | |||
encoder_config = copy.deepcopy(config) | |||
encoder_config.is_decoder = False | |||
encoder_config.use_cache = False | |||
encoder_config.is_encoder_decoder = False | |||
self.encoder = T5Stack(encoder_config, self.shared) | |||
decoder_config = copy.deepcopy(config) | |||
decoder_config.is_decoder = True | |||
decoder_config.is_encoder_decoder = False | |||
decoder_config.num_layers = config.num_decoder_layers | |||
self.decoder = T5Stack(decoder_config, self.shared) | |||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
# Model parallel | |||
self.model_parallel = False | |||
self.device_map = None | |||
def parallelize(self, device_map=None): | |||
self.device_map = ( | |||
get_device_map( | |||
len(self.encoder.block), range(torch.cuda.device_count())) | |||
if device_map is None else device_map) | |||
assert_device_map(self.device_map, len(self.encoder.block)) | |||
self.encoder.parallelize(self.device_map) | |||
self.decoder.parallelize(self.device_map) | |||
self.lm_head = self.lm_head.to(self.decoder.first_device) | |||
self.model_parallel = True | |||
def deparallelize(self): | |||
self.encoder.deparallelize() | |||
self.decoder.deparallelize() | |||
self.encoder = self.encoder.to('cpu') | |||
self.decoder = self.decoder.to('cpu') | |||
self.lm_head = self.lm_head.to('cpu') | |||
self.model_parallel = False | |||
self.device_map = None | |||
torch.cuda.empty_cache() | |||
def get_input_embeddings(self): | |||
return self.shared | |||
def set_input_embeddings(self, new_embeddings): | |||
self.shared = new_embeddings | |||
self.encoder.set_input_embeddings(new_embeddings) | |||
self.decoder.set_input_embeddings(new_embeddings) | |||
def set_output_embeddings(self, new_embeddings): | |||
self.lm_head = new_embeddings | |||
def get_output_embeddings(self): | |||
return self.lm_head | |||
def get_encoder(self): | |||
return self.encoder | |||
def get_decoder(self): | |||
return self.decoder | |||
def forward(self, | |||
input_ids: Optional[torch.LongTensor] = None, | |||
attention_mask: Optional[torch.FloatTensor] = None, | |||
decoder_input_ids: Optional[torch.LongTensor] = None, | |||
decoder_attention_mask: Optional[torch.BoolTensor] = None, | |||
head_mask: Optional[torch.FloatTensor] = None, | |||
decoder_head_mask: Optional[torch.FloatTensor] = None, | |||
cross_attn_head_mask: Optional[torch.Tensor] = None, | |||
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
inputs_embeds: Optional[torch.FloatTensor] = None, | |||
decoder_inputs_embeds: Optional[torch.FloatTensor] = None, | |||
labels: Optional[torch.LongTensor] = None, | |||
use_cache: Optional[bool] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
**kwargs) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: | |||
r""" | |||
Args: | |||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. T5 is a model | |||
with relative position embeddings so you should be able to pad the | |||
inputs on both the right and the left. | |||
Indices can be obtained using [`T5Tokenizer`]. See | |||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] | |||
for detail. | |||
[What are input IDs?](../glossary#input-ids) | |||
To know more on how to prepare `input_ids` for pretraining take a | |||
look a [T5 Training](./t5#training). | |||
attention_mask (`torch.FloatTensor` of shape `(batch_size, | |||
sequence_length)`, *optional*): | |||
Mask to avoid performing attention on padding token indices. Mask | |||
values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
[What are attention masks?](../glossary#attention-mask) | |||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, | |||
target_sequence_length)`, *optional*): | |||
Indices of decoder input sequence tokens in the vocabulary. | |||
Indices can be obtained using [`T5Tokenizer`]. See | |||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] | |||
for details. | |||
[What are decoder input IDs?](../glossary#decoder-input-ids) | |||
T5 uses the `pad_token_id` as the starting token for | |||
`decoder_input_ids` generation. If `past_key_values` is used, | |||
optionally only the last `decoder_input_ids` have to be input (see | |||
`past_key_values`). | |||
To know more on how to prepare `decoder_input_ids` for pretraining | |||
take a look at [T5 Training](./t5#training). | |||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, | |||
target_sequence_length)`, *optional*): | |||
Default behavior: generate a tensor that ignores pad tokens in | |||
`decoder_input_ids`. Causal mask will also be used by default. | |||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, | |||
num_heads)`, *optional*): | |||
Mask to nullify selected heads of the self-attention modules in the | |||
encoder. Mask values selected in `[0, 1]`: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or | |||
`(num_layers, num_heads)`, *optional*): | |||
Mask to nullify selected heads of the self-attention modules in the | |||
decoder. Mask values selected in `[0, 1]`: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or | |||
`(num_layers, num_heads)`, *optional*): | |||
Mask to nullify selected heads of the cross-attention modules in | |||
the decoder. Mask values selected in `[0, 1]`: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): | |||
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, | |||
`optional`: *attentions*) `last_hidden_state` of shape `(batch_size, | |||
sequence_length, hidden_size)` is a sequence of hidden states at the | |||
output of the last layer of the encoder. Used in the cross-attention | |||
of the decoder. | |||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length | |||
`config.n_layers` with each tuple having 4 tensors of shape | |||
`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | |||
Contains precomputed key and value hidden states of the attention | |||
blocks. Can be used to speed up decoding. | |||
If `past_key_values` are used, the user can optionally input only | |||
the last `decoder_input_ids` (those that don't have their past key | |||
value states given to this model) of shape `(batch_size, 1)` instead | |||
of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. | |||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, | |||
sequence_length, hidden_size)`, *optional*): | |||
Optionally, instead of passing `input_ids` you can choose to | |||
directly pass an embedded representation. This is useful if you want | |||
more control over how to convert `input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, | |||
target_sequence_length, hidden_size)`, *optional*): | |||
Optionally, instead of passing `decoder_input_ids` you can choose to | |||
directly pass an embedded representation. If `past_key_values` is | |||
used, optionally only the last `decoder_inputs_embeds` have to be | |||
input (see `past_key_values`). This is useful if you want more | |||
control over how to convert `decoder_input_ids` indices into | |||
associated vectors than the model's internal embedding lookup | |||
matrix. | |||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, | |||
`decoder_inputs_embeds` takes the value of `inputs_embeds`. | |||
use_cache (`bool`, *optional*): | |||
If set to `True`, `past_key_values` key value states are returned | |||
and can be used to speed up decoding (see `past_key_values`). | |||
output_attentions (`bool`, *optional*): | |||
Whether or not to return the attentions tensors of all attention | |||
layers. See `attentions` under returned tensors for more detail. | |||
output_hidden_states (`bool`, *optional*): | |||
Whether or not to return the hidden states of all layers. See | |||
`hidden_states` under returned tensors for more detail. | |||
return_dict (`bool`, *optional*): | |||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain | |||
tuple. | |||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |||
Labels for computing the sequence classification/regression loss. | |||
Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All | |||
labels set to `-100` are ignored (masked), the loss is only computed | |||
for labels in `[0, ..., config.vocab_size]` | |||
Returns: | |||
Examples: | |||
```python | |||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration | |||
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small") | |||
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small") | |||
>>> # training | |||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids | |||
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids | |||
>>> outputs = model(input_ids=input_ids, labels=labels) | |||
>>> loss = outputs.loss | |||
>>> logits = outputs.logits | |||
>>> # inference | |||
>>> input_ids = tokenizer( | |||
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" | |||
>>> ).input_ids # Batch size 1 | |||
>>> outputs = model.generate(input_ids) | |||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) | |||
>>> # studies have shown that owning a dog is good for you. | |||
```""" | |||
use_cache = use_cache if use_cache is not None else self.config.use_cache | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask | |||
if head_mask is not None and decoder_head_mask is None: | |||
if self.config.num_layers == self.config.num_decoder_layers: | |||
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) | |||
decoder_head_mask = head_mask | |||
# Encode if needed (training, first prediction pass) | |||
if encoder_outputs is None: | |||
# Convert encoder inputs in embeddings if needed | |||
encoder_outputs = self.encoder( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
inputs_embeds=inputs_embeds, | |||
head_mask=head_mask, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): | |||
encoder_outputs = BaseModelOutput( | |||
last_hidden_state=encoder_outputs[0], | |||
hidden_states=encoder_outputs[1] | |||
if len(encoder_outputs) > 1 else None, | |||
attentions=encoder_outputs[2] | |||
if len(encoder_outputs) > 2 else None, | |||
) | |||
hidden_states = encoder_outputs[0] | |||
if self.model_parallel: | |||
torch.cuda.set_device(self.decoder.first_device) | |||
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: | |||
# get decoder inputs from shifting lm labels to the right | |||
decoder_input_ids = self._shift_right(labels) | |||
# Set device for model parallelism | |||
if self.model_parallel: | |||
torch.cuda.set_device(self.decoder.first_device) | |||
hidden_states = hidden_states.to(self.decoder.first_device) | |||
if decoder_input_ids is not None: | |||
decoder_input_ids = decoder_input_ids.to( | |||
self.decoder.first_device) | |||
if attention_mask is not None: | |||
attention_mask = attention_mask.to(self.decoder.first_device) | |||
if decoder_attention_mask is not None: | |||
decoder_attention_mask = decoder_attention_mask.to( | |||
self.decoder.first_device) | |||
# Decode | |||
decoder_outputs = self.decoder( | |||
input_ids=decoder_input_ids, | |||
attention_mask=decoder_attention_mask, | |||
inputs_embeds=decoder_inputs_embeds, | |||
past_key_values=past_key_values, | |||
encoder_hidden_states=hidden_states, | |||
encoder_attention_mask=attention_mask, | |||
head_mask=decoder_head_mask, | |||
cross_attn_head_mask=cross_attn_head_mask, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = decoder_outputs[0] | |||
# Set device for model parallelism | |||
if self.model_parallel: | |||
torch.cuda.set_device(self.encoder.first_device) | |||
self.lm_head = self.lm_head.to(self.encoder.first_device) | |||
sequence_output = sequence_output.to(self.lm_head.weight.device) | |||
if self.config.tie_word_embeddings: | |||
# Rescale output before projecting on vocab See | |||
# https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 | |||
sequence_output = sequence_output * (self.model_dim**-0.5) | |||
lm_logits = self.lm_head(sequence_output) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss(ignore_index=-100) | |||
loss = loss_fct( | |||
lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) | |||
# TODO(thom): Add z_loss | |||
# https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 | |||
if not return_dict: | |||
output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs | |||
return ((loss, ) + output) if loss is not None else output | |||
return Seq2SeqLMOutput( | |||
loss=loss, | |||
logits=lm_logits, | |||
past_key_values=decoder_outputs.past_key_values, | |||
decoder_hidden_states=decoder_outputs.hidden_states, | |||
decoder_attentions=decoder_outputs.attentions, | |||
cross_attentions=decoder_outputs.cross_attentions, | |||
encoder_last_hidden_state=encoder_outputs.last_hidden_state, | |||
encoder_hidden_states=encoder_outputs.hidden_states, | |||
encoder_attentions=encoder_outputs.attentions, | |||
) | |||
def prepare_inputs_for_generation(self, | |||
input_ids, | |||
past=None, | |||
attention_mask=None, | |||
head_mask=None, | |||
decoder_head_mask=None, | |||
cross_attn_head_mask=None, | |||
use_cache=None, | |||
encoder_outputs=None, | |||
**kwargs): | |||
# cut decoder_input_ids if past is used | |||
if past is not None: | |||
input_ids = input_ids[:, -1:] | |||
return { | |||
'decoder_input_ids': input_ids, | |||
'past_key_values': past, | |||
'encoder_outputs': encoder_outputs, | |||
'attention_mask': attention_mask, | |||
'head_mask': head_mask, | |||
'decoder_head_mask': decoder_head_mask, | |||
'cross_attn_head_mask': cross_attn_head_mask, | |||
'use_cache': use_cache, | |||
} | |||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): | |||
return self._shift_right(labels) | |||
def _reorder_cache(self, past, beam_idx): | |||
# if decoder past is not included in output | |||
# speedy decoding is disabled and no need to reorder | |||
if past is None: | |||
logger.warning( | |||
'You might want to consider setting `use_cache=True` to speed up decoding' | |||
) | |||
return past | |||
reordered_decoder_past = () | |||
for layer_past_states in past: | |||
# get the correct batch idx from layer past batch dim | |||
# batch dim of `past` is at 2nd position | |||
reordered_layer_past_states = () | |||
for layer_past_state in layer_past_states: | |||
# need to set correct `past` for each of the four key / value states | |||
reordered_layer_past_states = reordered_layer_past_states + ( | |||
layer_past_state.index_select( | |||
0, beam_idx.to(layer_past_state.device)), ) | |||
assert reordered_layer_past_states[0].shape == layer_past_states[ | |||
0].shape | |||
assert len(reordered_layer_past_states) == len(layer_past_states) | |||
reordered_decoder_past = reordered_decoder_past + ( | |||
reordered_layer_past_states, ) | |||
return reordered_decoder_past |
@@ -4,80 +4,99 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .backbones import SbertModel | |||
from .bart_for_text_error_correction import BartForTextErrorCorrection | |||
from .bert_for_document_segmentation import BertForDocumentSegmentation | |||
from .csanmt_for_translation import CsanmtForTranslation | |||
from .bart import BartForTextErrorCorrection | |||
from .csanmt import CsanmtForTranslation | |||
from .heads import SequenceClassificationHead | |||
from .gpt3 import GPT3ForTextGeneration | |||
from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, | |||
BertForMaskedLM, DebertaV2ForMaskedLM) | |||
from .ponet_for_masked_language import PoNetForMaskedLM | |||
from .nncrf_for_named_entity_recognition import ( | |||
TransformerCRFForNamedEntityRecognition, | |||
LSTMCRFForNamedEntityRecognition) | |||
from .palm_v2 import PalmForTextGeneration | |||
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | |||
from .star_text_to_sql import StarForTextToSql | |||
from .sequence_classification import (VecoForSequenceClassification, | |||
SbertForSequenceClassification, | |||
BertForSequenceClassification) | |||
from .space import SpaceForDialogIntent | |||
from .space import SpaceForDialogModeling | |||
from .space import SpaceForDialogStateTracking | |||
from .table_question_answering import TableQuestionAnswering | |||
from .task_models import (FeatureExtractionModel, | |||
InformationExtractionModel, | |||
SequenceClassificationModel, | |||
SingleBackboneTaskModelBase, | |||
TokenClassificationModel, | |||
TaskModelForTextGeneration) | |||
from .token_classification import SbertForTokenClassification | |||
from .sentence_embedding import SentenceEmbedding | |||
from .text_ranking import TextRanking | |||
from .T5 import T5ForConditionalGeneration | |||
from .space_T_en import StarForTextToSql | |||
from .space_T_cn import TableQuestionAnswering | |||
from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST | |||
from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig | |||
from .structbert import ( | |||
SbertForFaqQuestionAnswering, | |||
SbertForMaskedLM, | |||
SbertForSequenceClassification, | |||
SbertForTokenClassification, | |||
SbertTokenizer, | |||
SbertTokenizerFast, | |||
) | |||
from .bert import ( | |||
BertForMaskedLM, | |||
BertForTextRanking, | |||
BertForSentenceEmbedding, | |||
BertForSequenceClassification, | |||
BertForTokenClassification, | |||
BertForDocumentSegmentation, | |||
BertModel, | |||
BertConfig, | |||
) | |||
from .veco import VecoModel, VecoConfig, VecoForTokenClassification, \ | |||
VecoForSequenceClassification, VecoForMaskedLM, VecoTokenizer, VecoTokenizerFast | |||
from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model | |||
from .task_models import ( | |||
FeatureExtractionModel, | |||
InformationExtractionModel, | |||
LSTMCRFForNamedEntityRecognition, | |||
SequenceClassificationModel, | |||
SingleBackboneTaskModelBase, | |||
TaskModelForTextGeneration, | |||
TokenClassificationModel, | |||
TransformerCRFForNamedEntityRecognition, | |||
) | |||
from .T5 import T5ForConditionalGeneration | |||
from .gpt_neo import GPTNeoModel | |||
else: | |||
_import_structure = { | |||
'backbones': ['SbertModel'], | |||
'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | |||
'bert_for_document_segmentation': ['BertForDocumentSegmentation'], | |||
'csanmt_for_translation': ['CsanmtForTranslation'], | |||
'bart': ['BartForTextErrorCorrection'], | |||
'csanmt': ['CsanmtForTranslation'], | |||
'heads': ['SequenceClassificationHead'], | |||
'gpt3': ['GPT3ForTextGeneration'], | |||
'masked_language': [ | |||
'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', | |||
'DebertaV2ForMaskedLM' | |||
'structbert': [ | |||
'SbertForFaqQuestionAnswering', | |||
'SbertForMaskedLM', | |||
'SbertForSequenceClassification', | |||
'SbertForTokenClassification', | |||
'SbertTokenizer', | |||
'SbertTokenizerFast', | |||
], | |||
'nncrf_for_named_entity_recognition': [ | |||
'TransformerCRFForNamedEntityRecognition', | |||
'LSTMCRFForNamedEntityRecognition' | |||
], | |||
'ponet_for_masked_language': ['PoNetForMaskedLM'], | |||
'palm_v2': ['PalmForTextGeneration'], | |||
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||
'star_text_to_sql': ['StarForTextToSql'], | |||
'sequence_classification': [ | |||
'VecoForSequenceClassification', 'SbertForSequenceClassification', | |||
'BertForSequenceClassification' | |||
'veco': [ | |||
'VecoModel', 'VecoConfig', 'VecoForTokenClassification', | |||
'VecoForSequenceClassification', 'VecoForMaskedLM', | |||
'VecoTokenizer', 'VecoTokenizerFast' | |||
], | |||
'space': [ | |||
'SpaceForDialogIntent', 'SpaceForDialogModeling', | |||
'SpaceForDialogStateTracking' | |||
'bert': [ | |||
'BertForMaskedLM', | |||
'BertForTextRanking', | |||
'BertForSentenceEmbedding', | |||
'BertForSequenceClassification', | |||
'BertForTokenClassification', | |||
'BertForDocumentSegmentation', | |||
'BertModel', | |||
'BertConfig', | |||
], | |||
'ponet': ['PoNetForMaskedLM', 'PoNetModel', 'PoNetConfig'], | |||
'palm_v2': ['PalmForTextGeneration'], | |||
'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'], | |||
'space_T_en': ['StarForTextToSql'], | |||
'space_T_cn': ['TableQuestionAnswering'], | |||
'space': | |||
['SpaceForDialogIntent', 'SpaceForDialogModeling', 'SpaceForDST'], | |||
'task_models': [ | |||
'FeatureExtractionModel', | |||
'InformationExtractionModel', | |||
'LSTMCRFForNamedEntityRecognition', | |||
'SequenceClassificationModel', | |||
'SingleBackboneTaskModelBase', | |||
'TokenClassificationModel', | |||
'TaskModelForTextGeneration', | |||
'TokenClassificationModel', | |||
'TransformerCRFForNamedEntityRecognition', | |||
], | |||
'token_classification': ['SbertForTokenClassification'], | |||
'table_question_answering': ['TableQuestionAnswering'], | |||
'sentence_embedding': ['SentenceEmbedding'], | |||
'text_ranking': ['TextRanking'], | |||
'T5': ['T5ForConditionalGeneration'], | |||
'gpt_neo': ['GPTNeoModel'], | |||
} | |||
import sys | |||
@@ -1,7 +0,0 @@ | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import BACKBONES | |||
from modelscope.models.nlp.bert import BertModel | |||
from modelscope.utils.constant import Fields | |||
BACKBONES.register_module( | |||
group_key=Fields.nlp, module_name=Models.bert, module_cls=BertModel) |
@@ -1,52 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import BACKBONES | |||
from modelscope.models.nlp.structbert import SbertConfig | |||
from modelscope.models.nlp.structbert import SbertModel as SbertModelTransform | |||
from modelscope.utils.constant import Fields | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger(__name__) | |||
@BACKBONES.register_module(Fields.nlp, module_name=Models.structbert) | |||
class SbertModel(TorchModel, SbertModelTransform): | |||
def __init__(self, model_dir=None, add_pooling_layer=True, **config): | |||
""" | |||
Args: | |||
model_dir (str, optional): The model checkpoint directory. Defaults to None. | |||
add_pooling_layer (bool, optional): to decide if pool the output from hidden layer. Defaults to True. | |||
""" | |||
config = SbertConfig(**config) | |||
super().__init__(model_dir) | |||
self.config = config | |||
SbertModelTransform.__init__(self, config, add_pooling_layer) | |||
def extract_sequence_outputs(self, outputs): | |||
return outputs['last_hidden_state'] | |||
def extract_pooled_outputs(self, outputs): | |||
return outputs['pooler_output'] | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
return SbertModelTransform.forward( | |||
self, input_ids, attention_mask, token_type_ids, position_ids, | |||
head_mask, inputs_embeds, encoder_hidden_states, | |||
encoder_attention_mask, past_key_values, use_cache, | |||
output_attentions, output_hidden_states, return_dict, **kwargs) |
@@ -0,0 +1,2 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .text_error_correction import BartForTextErrorCorrection |
@@ -4,43 +4,33 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .modeling_bert import ( | |||
BertForMaskedLM, | |||
BertForMultipleChoice, | |||
BertForNextSentencePrediction, | |||
BertForPreTraining, | |||
BertForQuestionAnswering, | |||
BertForSequenceClassification, | |||
BertForTokenClassification, | |||
from .backbone import ( | |||
BertLayer, | |||
BertLMHeadModel, | |||
BertModel, | |||
BertPreTrainedModel, | |||
load_tf_weights_in_bert, | |||
) | |||
from .configuration_bert import BertConfig, BertOnnxConfig | |||
from .configuration import BertConfig | |||
from .fill_mask import BertForMaskedLM | |||
from .text_ranking import BertForTextRanking | |||
from .sentence_embedding import BertForSentenceEmbedding | |||
from .text_classification import BertForSequenceClassification | |||
from .token_classification import BertForTokenClassification | |||
from .document_segmentation import BertForDocumentSegmentation | |||
else: | |||
_import_structure = { | |||
'configuration_bert': ['BertConfig', 'BertOnnxConfig'], | |||
'backbone': [ | |||
'BertModel', | |||
'BertPreTrainedModel', | |||
], | |||
'configuration': ['BertConfig'], | |||
'fill_mask': ['BertForMaskedLM'], | |||
'text_ranking': ['BertForTextRanking'], | |||
'sentence_embedding': ['BertForSentenceEmbedding'], | |||
'text_classification': ['BertForSequenceClassification'], | |||
'token_classification': ['BertForTokenClassification'], | |||
'document_segmentation': ['BertForDocumentSegmentation'], | |||
} | |||
_import_structure['modeling_bert'] = [ | |||
'BertForMaskedLM', | |||
'BertForMultipleChoice', | |||
'BertForNextSentencePrediction', | |||
'BertForPreTraining', | |||
'BertForQuestionAnswering', | |||
'BertForSequenceClassification', | |||
'BertForTokenClassification', | |||
'BertLayer', | |||
'BertLMHeadModel', | |||
'BertModel', | |||
'BertPreTrainedModel', | |||
'load_tf_weights_in_bert', | |||
] | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
@@ -0,0 +1,952 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""PyTorch BERT model. """ | |||
import math | |||
import os | |||
from dataclasses import dataclass | |||
from typing import Optional, Tuple | |||
import torch | |||
import torch.utils.checkpoint | |||
from packaging import version | |||
from torch import nn | |||
from transformers.activations import ACT2FN | |||
from transformers.modeling_utils import (PreTrainedModel, | |||
apply_chunking_to_forward, | |||
find_pruneable_heads_and_indices, | |||
prune_linear_layer) | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import (BaseModelOutputWithPastAndCrossAttentions, | |||
BaseModelOutputWithPoolingAndCrossAttentions) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import parse_label_mapping | |||
from modelscope.utils.logger import get_logger | |||
from .configuration import BertConfig | |||
logger = get_logger(__name__) | |||
_CONFIG_FOR_DOC = 'BertConfig' | |||
class BertEmbeddings(nn.Module): | |||
"""Construct the embeddings from word, position and token_type embeddings.""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.word_embeddings = nn.Embedding( | |||
config.vocab_size, | |||
config.hidden_size, | |||
padding_idx=config.pad_token_id) | |||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
config.hidden_size) | |||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, | |||
config.hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model | |||
# variable name and be able to load any TensorFlow checkpoint file | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
# position_ids (1, len position emb) is contiguous in memory and | |||
# exported when serialized | |||
self.position_embedding_type = getattr(config, | |||
'position_embedding_type', | |||
'absolute') | |||
self.register_buffer( | |||
'position_ids', | |||
torch.arange(config.max_position_embeddings).expand((1, -1))) | |||
if version.parse(torch.__version__) > version.parse('1.6.0'): | |||
self.register_buffer( | |||
'token_type_ids', | |||
torch.zeros(self.position_ids.size(), dtype=torch.long), | |||
persistent=False, | |||
) | |||
def forward(self, | |||
input_ids=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
inputs_embeds=None, | |||
past_key_values_length=0): | |||
if input_ids is not None: | |||
input_shape = input_ids.size() | |||
else: | |||
input_shape = inputs_embeds.size()[:-1] | |||
seq_length = input_shape[1] | |||
if position_ids is None: | |||
position_ids = self.position_ids[:, | |||
past_key_values_length:seq_length | |||
+ past_key_values_length] | |||
# Set token_type_ids to the all-zeros buffer registered in the constructor, | |||
# which usually happens when it is auto-generated. The registered buffer helps | |||
# users trace the model without passing token_type_ids (solves issue #5664). | |||
if token_type_ids is None: | |||
if hasattr(self, 'token_type_ids'): | |||
buffered_token_type_ids = self.token_type_ids[:, :seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
input_shape[0], seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, | |||
dtype=torch.long, | |||
device=self.position_ids.device) | |||
if inputs_embeds is None: | |||
inputs_embeds = self.word_embeddings(input_ids) | |||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
embeddings = inputs_embeds + token_type_embeddings | |||
if self.position_embedding_type == 'absolute': | |||
position_embeddings = self.position_embeddings(position_ids) | |||
embeddings += position_embeddings | |||
embeddings = self.LayerNorm(embeddings) | |||
embeddings = self.dropout(embeddings) | |||
return embeddings | |||
class BertSelfAttention(nn.Module): | |||
def __init__(self, config, position_embedding_type=None): | |||
super().__init__() | |||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr( | |||
config, 'embedding_size'): | |||
raise ValueError( | |||
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' | |||
f'heads ({config.num_attention_heads})') | |||
self.num_attention_heads = config.num_attention_heads | |||
self.attention_head_size = int(config.hidden_size | |||
/ config.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||
self.position_embedding_type = position_embedding_type or getattr( | |||
config, 'position_embedding_type', 'absolute') | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
self.max_position_embeddings = config.max_position_embeddings | |||
self.distance_embedding = nn.Embedding( | |||
2 * config.max_position_embeddings - 1, | |||
self.attention_head_size) | |||
self.is_decoder = config.is_decoder | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, | |||
self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
mixed_query_layer = self.query(hidden_states) | |||
# If this is instantiated as a cross-attention module, the keys | |||
# and values come from an encoder; the attention mask needs to be | |||
# such that the encoder's padding tokens are not attended to. | |||
is_cross_attention = encoder_hidden_states is not None | |||
if is_cross_attention and past_key_value is not None: | |||
# reuse k,v, cross_attentions | |||
key_layer = past_key_value[0] | |||
value_layer = past_key_value[1] | |||
attention_mask = encoder_attention_mask | |||
elif is_cross_attention: | |||
key_layer = self.transpose_for_scores( | |||
self.key(encoder_hidden_states)) | |||
value_layer = self.transpose_for_scores( | |||
self.value(encoder_hidden_states)) | |||
attention_mask = encoder_attention_mask | |||
elif past_key_value is not None: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |||
else: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
if self.is_decoder: | |||
# If cross-attention: save Tuple(torch.Tensor, torch.Tensor) of all cross-attention | |||
# key/value_states. Further calls to the cross-attention layer can then reuse them | |||
# (first "if" case). | |||
# If uni-directional self-attention (decoder): save Tuple(torch.Tensor, torch.Tensor) | |||
# of all previous decoder key/value_states. Further calls to uni-directional | |||
# self-attention can concat previous decoder key/value_states to the current | |||
# projected key/value_states (third "elif" case). | |||
# If encoder bi-directional self-attention: `past_key_value` is always `None`. | |||
past_key_value = (key_layer, value_layer) | |||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||
attention_scores = torch.matmul(query_layer, | |||
key_layer.transpose(-1, -2)) | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
seq_length = hidden_states.size()[1] | |||
position_ids_l = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(-1, 1) | |||
position_ids_r = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(1, -1) | |||
distance = position_ids_l - position_ids_r | |||
positional_embedding = self.distance_embedding( | |||
distance + self.max_position_embeddings - 1) | |||
positional_embedding = positional_embedding.to( | |||
dtype=query_layer.dtype) # fp16 compatibility | |||
if self.position_embedding_type == 'relative_key': | |||
relative_position_scores = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores | |||
elif self.position_embedding_type == 'relative_key_query': | |||
relative_position_scores_query = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
relative_position_scores_key = torch.einsum( | |||
'bhrd,lrd->bhlr', key_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |||
attention_scores = attention_scores / math.sqrt( | |||
self.attention_head_size) | |||
if attention_mask is not None: | |||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||
attention_scores = attention_scores + attention_mask | |||
# Normalize the attention scores to probabilities. | |||
attention_probs = nn.functional.softmax(attention_scores, dim=-1) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.dropout(attention_probs) | |||
# Mask heads if we want to | |||
if head_mask is not None: | |||
attention_probs = attention_probs * head_mask | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + ( | |||
self.all_head_size, ) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
outputs = (context_layer, | |||
attention_probs) if output_attentions else (context_layer, ) | |||
if self.is_decoder: | |||
outputs = outputs + (past_key_value, ) | |||
return outputs | |||
class BertSelfOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertAttention(nn.Module): | |||
def __init__(self, config, position_embedding_type=None): | |||
super().__init__() | |||
self.self = BertSelfAttention( | |||
config, position_embedding_type=position_embedding_type) | |||
self.output = BertSelfOutput(config) | |||
self.pruned_heads = set() | |||
def prune_heads(self, heads): | |||
if len(heads) == 0: | |||
return | |||
heads, index = find_pruneable_heads_and_indices( | |||
heads, self.self.num_attention_heads, | |||
self.self.attention_head_size, self.pruned_heads) | |||
# Prune linear layers | |||
self.self.query = prune_linear_layer(self.self.query, index) | |||
self.self.key = prune_linear_layer(self.self.key, index) | |||
self.self.value = prune_linear_layer(self.self.value, index) | |||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) | |||
# Update hyper params and store pruned heads | |||
self.self.num_attention_heads = self.self.num_attention_heads - len( | |||
heads) | |||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads | |||
self.pruned_heads = self.pruned_heads.union(heads) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
self_outputs = self.self( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = self.output(self_outputs[0], hidden_states) | |||
outputs = (attention_output, | |||
) + self_outputs[1:] # add attentions if we output them | |||
return outputs | |||
class BertIntermediate(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||
if isinstance(config.hidden_act, str): | |||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.intermediate_act_fn = config.hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.intermediate_act_fn(hidden_states) | |||
return hidden_states | |||
class BertOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertLayer(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |||
self.seq_len_dim = 1 | |||
self.attention = BertAttention(config) | |||
self.is_decoder = config.is_decoder | |||
self.add_cross_attention = config.add_cross_attention | |||
if self.add_cross_attention: | |||
if not self.is_decoder: | |||
raise ValueError( | |||
f'{self} should be used as a decoder model if cross attention is added' | |||
) | |||
self.crossattention = BertAttention( | |||
config, position_embedding_type='absolute') | |||
self.intermediate = BertIntermediate(config) | |||
self.output = BertOutput(config) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 | |||
self_attn_past_key_value = past_key_value[: | |||
2] if past_key_value is not None else None | |||
self_attention_outputs = self.attention( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
output_attentions=output_attentions, | |||
past_key_value=self_attn_past_key_value, | |||
) | |||
attention_output = self_attention_outputs[0] | |||
# if decoder, the last output is tuple of self-attn cache | |||
if self.is_decoder: | |||
outputs = self_attention_outputs[1:-1] | |||
present_key_value = self_attention_outputs[-1] | |||
else: | |||
outputs = self_attention_outputs[ | |||
1:] # add self attentions if we output attention weights | |||
cross_attn_present_key_value = None | |||
if self.is_decoder and encoder_hidden_states is not None: | |||
if not hasattr(self, 'crossattention'): | |||
raise ValueError( | |||
f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' | |||
f'with cross-attention layers by setting `config.add_cross_attention=True`' | |||
) | |||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple | |||
cross_attn_past_key_value = past_key_value[ | |||
-2:] if past_key_value is not None else None | |||
cross_attention_outputs = self.crossattention( | |||
attention_output, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
cross_attn_past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = cross_attention_outputs[0] | |||
outputs = outputs + cross_attention_outputs[ | |||
1:-1] # add cross attentions if we output attention weights | |||
# add cross-attn cache to positions 3,4 of present_key_value tuple | |||
cross_attn_present_key_value = cross_attention_outputs[-1] | |||
present_key_value = present_key_value + cross_attn_present_key_value | |||
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, | |||
self.chunk_size_feed_forward, | |||
self.seq_len_dim, | |||
attention_output) | |||
outputs = (layer_output, ) + outputs | |||
# if decoder, return the attn key/values as the last output | |||
if self.is_decoder: | |||
outputs = outputs + (present_key_value, ) | |||
return outputs | |||
def feed_forward_chunk(self, attention_output): | |||
intermediate_output = self.intermediate(attention_output) | |||
layer_output = self.output(intermediate_output, attention_output) | |||
return layer_output | |||
class BertEncoder(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
self.layer = nn.ModuleList( | |||
[BertLayer(config) for _ in range(config.num_hidden_layers)]) | |||
self.gradient_checkpointing = False | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=False, | |||
output_hidden_states=False, | |||
return_dict=True, | |||
): | |||
all_hidden_states = () if output_hidden_states else None | |||
all_self_attentions = () if output_attentions else None | |||
all_cross_attentions = ( | |||
) if output_attentions and self.config.add_cross_attention else None | |||
next_decoder_cache = () if use_cache else None | |||
for i, layer_module in enumerate(self.layer): | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
layer_head_mask = head_mask[i] if head_mask is not None else None | |||
past_key_value = past_key_values[ | |||
i] if past_key_values is not None else None | |||
if self.gradient_checkpointing and self.training: | |||
if use_cache: | |||
logger.warning( | |||
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' | |||
) | |||
use_cache = False | |||
def create_custom_forward(module): | |||
def custom_forward(*inputs): | |||
return module(*inputs, past_key_value, | |||
output_attentions) | |||
return custom_forward | |||
layer_outputs = torch.utils.checkpoint.checkpoint( | |||
create_custom_forward(layer_module), | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
) | |||
else: | |||
layer_outputs = layer_module( | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
hidden_states = layer_outputs[0] | |||
if use_cache: | |||
next_decoder_cache += (layer_outputs[-1], ) | |||
if output_attentions: | |||
all_self_attentions = all_self_attentions + ( | |||
layer_outputs[1], ) | |||
if self.config.add_cross_attention: | |||
all_cross_attentions = all_cross_attentions + ( | |||
layer_outputs[2], ) | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
if not return_dict: | |||
return tuple(v for v in [ | |||
hidden_states, | |||
next_decoder_cache, | |||
all_hidden_states, | |||
all_self_attentions, | |||
all_cross_attentions, | |||
] if v is not None) | |||
return BaseModelOutputWithPastAndCrossAttentions( | |||
last_hidden_state=hidden_states, | |||
past_key_values=next_decoder_cache, | |||
hidden_states=all_hidden_states, | |||
attentions=all_self_attentions, | |||
cross_attentions=all_cross_attentions, | |||
) | |||
class BertPooler(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
# We "pool" the model by simply taking the hidden state corresponding | |||
# to the first token. | |||
first_token_tensor = hidden_states[:, 0] | |||
pooled_output = self.dense(first_token_tensor) | |||
pooled_output = self.activation(pooled_output) | |||
return pooled_output | |||
class BertPreTrainedModel(TorchModel, PreTrainedModel): | |||
""" | |||
An abstract class to handle weights initialization and a simple interface | |||
for downloading and loading pretrained models. | |||
""" | |||
config_class = BertConfig | |||
base_model_prefix = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config.name_or_path, **kwargs) | |||
super(Model, self).__init__(config) | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.bias is not None: | |||
module.bias.data.zero_() | |||
elif isinstance(module, nn.Embedding): | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.padding_idx is not None: | |||
module.weight.data[module.padding_idx].zero_() | |||
elif isinstance(module, nn.LayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
def _set_gradient_checkpointing(self, module, value=False): | |||
if isinstance(module, BertEncoder): | |||
module.gradient_checkpointing = value | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
Args: | |||
kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg telling the model how many classes to initialize. | |||
If num_labels is not supplied, the method will call utils.parse_label_mapping; | |||
if it still cannot be determined, the model falls back to the default setting (2 classes). | |||
Returns: | |||
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.get('model_dir', None) | |||
if model_dir is None: | |||
config = BertConfig(**kwargs) | |||
model = cls(config) | |||
else: | |||
model_kwargs = {} | |||
label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) | |||
id2label = kwargs.get( | |||
'id2label', None if label2id is None else | |||
{id: label | |||
for label, id in label2id.items()}) | |||
if id2label is not None and label2id is None: | |||
label2id = {label: id for id, label in id2label.items()} | |||
num_labels = kwargs.get( | |||
'num_labels', None if label2id is None else len(label2id)) | |||
if num_labels is not None: | |||
model_kwargs['num_labels'] = num_labels | |||
if label2id is not None: | |||
model_kwargs['label2id'] = label2id | |||
if id2label is not None: | |||
model_kwargs['id2label'] = id2label | |||
model = super(Model, cls).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, **model_kwargs) | |||
model.model_dir = model_dir | |||
return model | |||
@MODELS.register_module(group_key=Tasks.backbone, module_name=Models.bert) | |||
class BertModel(BertPreTrainedModel): | |||
"""The Bert Model transformer outputting raw hidden-states without any | |||
specific head on top. | |||
This model inherits from [`PreTrainedModel`]. Check the superclass | |||
documentation for the generic methods the library implements for all its | |||
models (such as downloading or saving, resizing the input embeddings, pruning | |||
heads etc.) | |||
This model is also a PyTorch | |||
[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch | |||
documentation for all matter related to general usage and behavior. | |||
Parameters: | |||
config ([`BertConfig`]): Model configuration class with all the | |||
parameters of the model. | |||
Initializing with a config file does not load the weights associated | |||
with the model, only the configuration. Check out the | |||
[`~PreTrainedModel.from_pretrained`] method to load the model | |||
weights. | |||
The model can behave as an encoder (with only self-attention) as well as a | |||
decoder, in which case a layer of cross-attention is added between the | |||
self-attention layers, following the architecture described in [Attention is | |||
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam | |||
Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz | |||
Kaiser and Illia Polosukhin. | |||
To behave as a decoder the model needs to be initialized with the | |||
`is_decoder` argument of the configuration set to `True`. To be used in a | |||
Seq2Seq model, the model needs to be initialized with both the `is_decoder` | |||
argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` | |||
is then expected as an input to the forward pass. | |||
""" | |||
def __init__(self, config, add_pooling_layer=True): | |||
super().__init__(config) | |||
self.embeddings = BertEmbeddings(config) | |||
self.encoder = BertEncoder(config) | |||
self.pooler = BertPooler(config) if add_pooling_layer else None | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
@classmethod | |||
def _instantiate(cls, model_dir=None, add_pooling_layer=True, **config): | |||
config = BertConfig(**config) | |||
model = cls(config, add_pooling_layer) | |||
return model | |||
def get_input_embeddings(self): | |||
return self.embeddings.word_embeddings | |||
def set_input_embeddings(self, value): | |||
self.embeddings.word_embeddings = value | |||
def _prune_heads(self, heads_to_prune): | |||
""" | |||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base | |||
class PreTrainedModel | |||
""" | |||
for layer, heads in heads_to_prune.items(): | |||
self.encoder.layer[layer].attention.prune_heads(heads) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
r""" | |||
Args: | |||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using [`BertTokenizer`]. See | |||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] | |||
for details. | |||
[What are input IDs?](../glossary#input-ids) | |||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Mask to avoid performing attention on padding token indices. Mask | |||
values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
[What are attention masks?](../glossary#attention-mask) | |||
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Segment token indices to indicate first and second portions of the | |||
inputs. Indices are selected in `[0, 1]`: | |||
- 0 corresponds to a *sentence A* token, | |||
- 1 corresponds to a *sentence B* token. | |||
[What are token type IDs?](../glossary#token-type-ids) | |||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Indices of positions of each input sequence tokens in the position | |||
embeddings. Selected in the range `[0, | |||
config.max_position_embeddings - 1]`. | |||
[What are position IDs?](../glossary#position-ids) | |||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, | |||
num_heads)`, *optional*): | |||
Mask to nullify selected heads of the self-attention modules. Mask | |||
values selected in `[0, 1]`: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, | |||
*optional*): | |||
Optionally, instead of passing `input_ids` you can choose to | |||
directly pass an embedded representation. This is useful if you want | |||
more control over how to convert `input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (`bool`, *optional*): | |||
Whether or not to return the attentions tensors of all attention | |||
layers. See `attentions` under returned tensors for more detail. | |||
output_hidden_states (`bool`, *optional*): | |||
Whether or not to return the hidden states of all layers. See | |||
`hidden_states` under returned tensors for more detail. | |||
return_dict (`bool`, *optional*): | |||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a | |||
plain tuple. | |||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, | |||
sequence_length, hidden_size)`, *optional*): | |||
Sequence of hidden-states at the output of the last layer of the | |||
encoder. Used in the cross-attention if the model is configured as a | |||
decoder. | |||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, | |||
sequence_length)`, *optional*): | |||
Mask to avoid performing attention on the padding token indices of | |||
the encoder input. This mask is used in the cross-attention if the | |||
model is configured as a decoder. Mask values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length | |||
`config.n_layers` with each tuple having 4 tensors of shape | |||
`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | |||
Contains precomputed key and value hidden states of the attention | |||
blocks. Can be used to speed up decoding. | |||
If `past_key_values` are used, the user can optionally input only | |||
the last `decoder_input_ids` (those that don't have their past key | |||
value states given to this model) of shape `(batch_size, 1)` instead | |||
of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. | |||
use_cache (`bool`, *optional*): | |||
If set to `True`, `past_key_values` key value states are returned | |||
and can be used to speed up decoding (see `past_key_values`). | |||
Others (**kwargs): | |||
Some additional parameters might be passed in from an upstream pipeline; | |||
they do not influence the results. | |||
""" | |||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||
output_hidden_states = ( | |||
output_hidden_states if output_hidden_states is not None else | |||
self.config.output_hidden_states) | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if self.config.is_decoder: | |||
use_cache = use_cache if use_cache is not None else self.config.use_cache | |||
else: | |||
use_cache = False | |||
if input_ids is not None and inputs_embeds is not None: | |||
raise ValueError( | |||
'You cannot specify both input_ids and inputs_embeds at the same time' | |||
) | |||
elif input_ids is not None: | |||
input_shape = input_ids.size() | |||
elif inputs_embeds is not None: | |||
input_shape = inputs_embeds.size()[:-1] | |||
else: | |||
raise ValueError( | |||
'You have to specify either input_ids or inputs_embeds') | |||
batch_size, seq_length = input_shape | |||
device = input_ids.device if input_ids is not None else inputs_embeds.device | |||
# past_key_values_length | |||
past_key_values_length = past_key_values[0][0].shape[ | |||
2] if past_key_values is not None else 0 | |||
if attention_mask is None: | |||
attention_mask = torch.ones( | |||
((batch_size, seq_length + past_key_values_length)), | |||
device=device) | |||
if token_type_ids is None: | |||
if hasattr(self.embeddings, 'token_type_ids'): | |||
buffered_token_type_ids = self.embeddings.token_type_ids[:, : | |||
seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
batch_size, seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, dtype=torch.long, device=device) | |||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | |||
# ourselves in which case we just need to make it broadcastable to all heads. | |||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( | |||
attention_mask, input_shape, device) | |||
# If a 2D or 3D attention mask is provided for the cross-attention | |||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | |||
if self.config.is_decoder and encoder_hidden_states is not None: | |||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( | |||
) | |||
encoder_hidden_shape = (encoder_batch_size, | |||
encoder_sequence_length) | |||
if encoder_attention_mask is None: | |||
encoder_attention_mask = torch.ones( | |||
encoder_hidden_shape, device=device) | |||
encoder_extended_attention_mask = self.invert_attention_mask( | |||
encoder_attention_mask) | |||
else: | |||
encoder_extended_attention_mask = None | |||
# Prepare head mask if needed | |||
# 1.0 in head_mask indicate we keep the head | |||
# attention_probs has shape bsz x n_heads x N x N | |||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | |||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | |||
head_mask = self.get_head_mask(head_mask, | |||
self.config.num_hidden_layers) | |||
embedding_output = self.embeddings( | |||
input_ids=input_ids, | |||
position_ids=position_ids, | |||
token_type_ids=token_type_ids, | |||
inputs_embeds=inputs_embeds, | |||
past_key_values_length=past_key_values_length, | |||
) | |||
encoder_outputs = self.encoder( | |||
embedding_output, | |||
attention_mask=extended_attention_mask, | |||
head_mask=head_mask, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_extended_attention_mask, | |||
past_key_values=past_key_values, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = encoder_outputs[0] | |||
pooled_output = self.pooler( | |||
sequence_output) if self.pooler is not None else None | |||
if not return_dict: | |||
return (sequence_output, pooled_output) + encoder_outputs[1:] | |||
return BaseModelOutputWithPoolingAndCrossAttentions( | |||
last_hidden_state=sequence_output, | |||
pooler_output=pooled_output, | |||
past_key_values=encoder_outputs.past_key_values, | |||
hidden_states=encoder_outputs.hidden_states, | |||
attentions=encoder_outputs.attentions, | |||
cross_attentions=encoder_outputs.cross_attentions, | |||
) | |||
def extract_sequence_outputs(self, outputs): | |||
return outputs['last_hidden_state'] | |||
def extract_pooled_outputs(self, outputs): | |||
return outputs['pooler_output'] |
@@ -1,3 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# |
@@ -2,6 +2,7 @@ | |||
from typing import Any, Dict | |||
import torch | |||
from torch import nn | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.modeling_outputs import TokenClassifierOutput |
@@ -0,0 +1,299 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.activations import ACT2FN | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionFillMaskModelOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import BertModel, BertPreTrainedModel | |||
from .configuration import BertConfig | |||
logger = logging.get_logger(__name__) | |||
class BertPredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
class BertLMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = BertPredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear( | |||
config.hidden_size, config.vocab_size, bias=False) | |||
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) | |||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` | |||
self.decoder.bias = self.bias | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
class BertOnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = BertLMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores | |||
class BertOnlyNSPHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.seq_relationship = nn.Linear(config.hidden_size, 2) | |||
def forward(self, pooled_output): | |||
seq_relationship_score = self.seq_relationship(pooled_output) | |||
return seq_relationship_score | |||
class BertPreTrainingHeads(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = BertLMPredictionHead(config) | |||
self.seq_relationship = nn.Linear(config.hidden_size, 2) | |||
def forward(self, sequence_output, pooled_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
seq_relationship_score = self.seq_relationship(pooled_output) | |||
return prediction_scores, seq_relationship_score | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert) | |||
class BertForMaskedLM(BertPreTrainedModel): | |||
r"""Bert Model with a `language modeling` head on top. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the fill_mask model of Bert; the preprocessor of this model | |||
is `modelscope.preprocessors.NLPPreprocessor`. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.bert.BertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config: BertConfig, **kwargs): | |||
super().__init__(config) | |||
if config.is_decoder: | |||
logger.warning( | |||
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for ' | |||
'bi-directional self-attention.') | |||
self.bert = BertModel(config, add_pooling_layer=False) | |||
self.cls = BertOnlyMLMHead(config) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
`What are input IDs? <../glossary.html#input-ids>`__ | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
`What are attention masks? <../glossary.html#attention-mask>`__ | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
`What are token type IDs? <../glossary.html#token-type-ids>`_ | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
`What are position IDs? <../glossary.html#position-ids>`_ | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, | |||
*optional*): | |||
Labels for computing the masked language modeling loss. Indices | |||
should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` | |||
docstring) Tokens with indices set to `-100` are ignored (masked), | |||
the loss is only computed for the tokens with labels in `[0, ..., | |||
config.vocab_size]` | |||
Returns: | |||
Returns `modelscope.outputs.AttentionFillMaskModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_bert_backbone_base_std') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_backbone_base_std') | |||
>>> print(model(**preprocessor(('This is a test', 'This is also a test')))) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.bert( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return AttentionFillMaskModelOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
input_ids=input_ids, | |||
) | |||
def prepare_inputs_for_generation(self, | |||
input_ids, | |||
attention_mask=None, | |||
**model_kwargs): | |||
input_shape = input_ids.shape | |||
effective_batch_size = input_shape[0] | |||
# add a dummy token | |||
if self.config.pad_token_id is None: | |||
raise ValueError('The PAD token should be defined for generation') | |||
padding_mask = attention_mask.new_zeros((attention_mask.shape[0], 1)) | |||
attention_mask = torch.cat([attention_mask, padding_mask], dim=-1) | |||
dummy_token = torch.full((effective_batch_size, 1), | |||
self.config.pad_token_id, | |||
dtype=torch.long, | |||
device=input_ids.device) | |||
input_ids = torch.cat([input_ids, dummy_token], dim=1) | |||
return {'input_ids': input_ids, 'attention_mask': attention_mask} |
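Because BertForMaskedLM now echoes input_ids back inside AttentionFillMaskModelOutput, a pipeline (or a user calling the model directly) can locate the masked positions without keeping the preprocessor output around. Below is a hedged sketch of that post-processing step; the helper name decode_masked_tokens and the use of a transformers BertTokenizer for id-to-token conversion are assumptions for illustration, not part of this PR.

# Illustrative helper: picks the argmax token at every [MASK] position.
import torch
from transformers import BertTokenizer

def decode_masked_tokens(output, tokenizer: BertTokenizer):
    """Return the predicted tokens at masked positions, one list per batch item."""
    mask_positions = output.input_ids == tokenizer.mask_token_id   # (batch, seq_len) bool
    predicted_ids = output.logits.argmax(dim=-1)                   # (batch, seq_len)
    return [
        tokenizer.convert_ids_to_tokens(row[mask].tolist())
        for row, mask in zip(predicted_ids, mask_positions)
    ]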
@@ -0,0 +1,113 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import BackboneModelOutput | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import BertModel, BertPreTrainedModel | |||
@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) | |||
class BertForSentenceEmbedding(BertPreTrainedModel): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.config = config | |||
setattr(self, self.base_model_prefix, | |||
BertModel(config, add_pooling_layer=False)) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
) -> BackboneModelOutput: | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
Returns: | |||
Returns `modelscope.outputs.BackboneModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') | |||
>>> print(model(**preprocessor('This is a test'))) | |||
""" | |||
return self.base_model.forward( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
Args: | |||
kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
Returns: | |||
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.get('model_dir') | |||
model = super( | |||
Model, | |||
cls).from_pretrained(pretrained_model_name_or_path=model_dir) | |||
model.model_dir = model_dir | |||
return model |
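BertForSentenceEmbedding is built with add_pooling_layer=False and simply returns the backbone output, so the pooling strategy is left to the caller. The sketch below shows one common choice, masked mean pooling over last_hidden_state followed by cosine similarity; it is an illustration only and does not claim to match the pooling used by the sentence-embedding pipeline.

# Illustrative pooling over the model's last_hidden_state; not the canonical pipeline behaviour.
import torch
import torch.nn.functional as F

def mean_pooled_embedding(last_hidden_state: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # ignore padding positions
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts                           # (batch, hidden_size)

def sentence_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor) -> torch.Tensor:
    return F.cosine_similarity(emb_a, emb_b, dim=-1)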
@@ -0,0 +1,208 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionTextClassificationModelOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import BertModel, BertPreTrainedModel | |||
logger = logging.get_logger(__name__) | |||
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.nli, module_name=Models.bert) | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert) | |||
@MODELS.register_module( | |||
Tasks.zero_shot_classification, module_name=Models.bert) | |||
class BertForSequenceClassification(BertPreTrainedModel): | |||
r"""Bert Model transformer with a sequence classification/regression head on top | |||
(a linear layer on top of the pooled output) e.g. for GLUE tasks. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the text classification model of Bert; the preprocessor of this model | |||
is `modelscope.preprocessors.SequenceClassificationPreprocessor`. | |||
Trainer: | |||
This model is a normal PyTorch model, and can be trained by various trainers, like EpochBasedTrainer, | |||
NlpEpochBasedTrainer, or trainers from other frameworks. | |||
The preferred trainer in ModelScope is NlpEpochBasedTrainer. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.bert.BertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
setattr(self, self.base_model_prefix, BertModel(config)) | |||
classifier_dropout = ( | |||
config.classifier_dropout if config.classifier_dropout is not None | |||
else config.hidden_dropout_prob) | |||
self.dropout = nn.Dropout(classifier_dropout) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., | |||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), | |||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | |||
Returns: | |||
Returns `modelscope.outputs.AttentionTextClassificationModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') | |||
>>> print(model(**preprocessor(('This is a test', 'This is also a test')))) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.base_model.forward( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
pooled_output = outputs[1] | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
loss = None | |||
if labels is not None: | |||
if self.config.problem_type is None: | |||
if self.num_labels == 1: | |||
self.config.problem_type = 'regression' | |||
elif self.num_labels > 1 and (labels.dtype == torch.long | |||
or labels.dtype == torch.int): | |||
self.config.problem_type = 'single_label_classification' | |||
else: | |||
self.config.problem_type = 'multi_label_classification' | |||
if self.config.problem_type == 'regression': | |||
loss_fct = MSELoss() | |||
if self.num_labels == 1: | |||
loss = loss_fct(logits.squeeze(), labels.squeeze()) | |||
else: | |||
loss = loss_fct(logits, labels) | |||
elif self.config.problem_type == 'single_label_classification': | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
elif self.config.problem_type == 'multi_label_classification': | |||
loss_fct = BCEWithLogitsLoss() | |||
loss = loss_fct(logits, labels) | |||
if not return_dict: | |||
output = (logits, ) + outputs[2:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return AttentionTextClassificationModelOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) |
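The loss branch above dispatches on `config.problem_type`, inferring it on the first labelled batch from `num_labels` and the label dtype. A minimal standalone sketch of that dispatch (plain PyTorch; the tensor shapes below are hypothetical, not taken from the PR):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def classification_loss(logits, labels, num_labels, problem_type=None):
    # Mirrors the inference rule used in BertForSequenceClassification.forward
    if problem_type is None:
        if num_labels == 1:
            problem_type = 'regression'
        elif labels.dtype in (torch.long, torch.int):
            problem_type = 'single_label_classification'
        else:
            problem_type = 'multi_label_classification'
    if problem_type == 'regression':
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(logits.squeeze(), labels.squeeze())
        return loss_fct(logits, labels)
    if problem_type == 'single_label_classification':
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)

# Example: 3-way single-label classification on a batch of 2
print(classification_loss(torch.randn(2, 3), torch.tensor([0, 2]), num_labels=3))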
@@ -0,0 +1,89 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
import torch.utils.checkpoint | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionTextClassificationModelOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import BertModel | |||
from .text_classification import BertForSequenceClassification | |||
logger = logging.get_logger(__name__) | |||
@MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | |||
class BertForTextRanking(BertForSequenceClassification): | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config) | |||
self.train_batch_size = kwargs.get('train_batch_size', 4) | |||
setattr(self, self.base_model_prefix, | |||
BertModel(self.config, add_pooling_layer=True)) | |||
self.register_buffer( | |||
'target_label', | |||
torch.zeros(self.train_batch_size, dtype=torch.long)) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs) -> AttentionTextClassificationModelOutput: | |||
outputs = self.base_model.forward( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict) | |||
# backbone model should return pooled_output as its second output | |||
pooled_output = outputs[1] | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
if self.base_model.training: | |||
scores = logits.view(self.train_batch_size, -1) | |||
loss_fct = torch.nn.CrossEntropyLoss() | |||
loss = loss_fct(scores, self.target_label) | |||
return AttentionTextClassificationModelOutput( | |||
loss=loss, | |||
logits=logits, | |||
) | |||
return AttentionTextClassificationModelOutput(logits=logits, ) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
Args: | |||
kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg to tell the model how many classes to initialize. | |||
If num_labels is not supplied, the model will use the default setting (1 class, i.e. a single ranking score). | |||
Returns: | |||
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
num_labels = kwargs.get('num_labels', 1) | |||
model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
model_dir = kwargs.get('model_dir') | |||
model = super(Model, cls).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, **model_args) | |||
model.model_dir = model_dir | |||
return model |
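During training, BertForTextRanking treats each group of candidate documents for one query as a single softmax: the per-pair logits are reshaped to `(train_batch_size, num_candidates)` and scored against the registered `target_label` buffer of zeros, i.e. the positive candidate is expected at index 0 of every group. A small sketch of that contract (plain PyTorch; the group size below is hypothetical):

import torch

train_batch_size, num_candidates = 4, 8                      # hypothetical grouping from the preprocessor
logits = torch.randn(train_batch_size * num_candidates, 1)   # one relevance score per (query, doc) pair
scores = logits.view(train_batch_size, -1)                   # (4, 8): all candidates of a query in one row
target = torch.zeros(train_batch_size, dtype=torch.long)     # the positive document sits at index 0
loss = torch.nn.CrossEntropyLoss()(scores, target)
print(loss)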
@@ -0,0 +1,225 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import CrossEntropyLoss | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import TokenClassifierOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import BertModel, BertPreTrainedModel | |||
logger = logging.get_logger(__name__) | |||
@MODELS.register_module(Tasks.token_classification, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.part_of_speech, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert) | |||
class BertForTokenClassification(BertPreTrainedModel): | |||
r"""Bert Model with a token classification head on top (a linear layer on top of | |||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks, word-segmentation. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the token classification model of Bert, the preprocessor of this model | |||
is `modelscope.preprocessors.TokenClassificationPreprocessor`. | |||
Trainer: | |||
This model is a normal PyTorch model, and can be trained by various trainers, such as EpochBasedTrainer, | |||
NlpEpochBasedTrainer, or trainers from other frameworks. | |||
The preferred trainer in ModelScope is NlpEpochBasedTrainer. | |||
Parameters: | |||
config (:class:`BertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
setattr(self, self.base_model_prefix, | |||
BertModel(config, add_pooling_layer=False)) | |||
classifier_dropout = ( | |||
config.classifier_dropout if config.classifier_dropout is not None | |||
else config.hidden_dropout_prob) | |||
self.dropout = nn.Dropout(classifier_dropout) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
offset_mapping=None, | |||
label_mask=None, | |||
): | |||
r""" | |||
Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, | |||
sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using | |||
:class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and | |||
:meth:`transformers.PreTrainedTokenizer.__call__` for details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask | |||
values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the | |||
inputs. Indices are selected in ``[0, 1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position | |||
embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or | |||
:obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask | |||
values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to | |||
directly pass an embedded representation. This is useful if you want | |||
more control over how to convert :obj:`input_ids` indices into | |||
associated vectors than the model's internal embedding lookup | |||
matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention | |||
layers. See ``attentions`` under returned tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See | |||
``hidden_states`` under returned tensors for more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` | |||
instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Labels for computing the token classification loss. | |||
Indices should be in :obj:`[0, ..., config.num_labels - 1]`. | |||
offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the sentence. | |||
Selected in the range ``[0, sequence_length - 1]``. | |||
label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask | |||
values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
Returns: | |||
Returns `modelscope.outputs.TokenClassifierOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') | |||
>>> print(model(**preprocessor('这是个测试'))) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.bert( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() | |||
# Only keep active parts of the loss | |||
if attention_mask is not None: | |||
active_loss = attention_mask.view(-1) == 1 | |||
active_logits = logits.view(-1, self.num_labels) | |||
active_labels = torch.where( | |||
active_loss, labels.view(-1), | |||
torch.tensor(loss_fct.ignore_index).type_as(labels)) | |||
loss = loss_fct(active_logits, active_labels) | |||
else: | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
if not return_dict: | |||
output = (logits, ) + outputs[2:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return TokenClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
offset_mapping=offset_mapping, | |||
) |
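The token-classification loss above only counts positions where `attention_mask == 1`; padded positions are replaced with `CrossEntropyLoss().ignore_index` (-100) so they contribute nothing to the gradient. A standalone sketch of that masking (plain PyTorch; shapes are hypothetical):

import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
loss_fct = CrossEntropyLoss()
logits = torch.randn(2, 6, num_labels)                  # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 6))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])
active_loss = attention_mask.view(-1) == 1
active_labels = torch.where(active_loss, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(logits.view(-1, num_labels), active_labels)  # padded positions are ignored
print(loss)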
@@ -0,0 +1,2 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .translation import CsanmtForTranslation |
@@ -22,38 +22,28 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_deberta_v2 import DebertaV2Config | |||
from .tokenization_deberta_v2 import DebertaV2Tokenizer | |||
from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast | |||
from .modeling_deberta_v2 import ( | |||
DebertaV2ForMaskedLM, | |||
DebertaV2ForMultipleChoice, | |||
DebertaV2ForQuestionAnswering, | |||
DebertaV2ForSequenceClassification, | |||
DebertaV2ForTokenClassification, | |||
from .configuration import DebertaV2Config | |||
from .tokenization import DebertaV2Tokenizer | |||
from .tokenization_fast import DebertaV2TokenizerFast | |||
from .backbone import ( | |||
DebertaV2Model, | |||
DebertaV2PreTrainedModel, | |||
) | |||
from .fill_mask import DebertaV2ForMaskedLM | |||
else: | |||
_import_structure = { | |||
'configuration_deberta_v2': | |||
['DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config'], | |||
'tokenization_deberta_v2': ['DebertaV2Tokenizer'] | |||
'configuration': ['DebertaV2Config'], | |||
'tokenization': ['DebertaV2Tokenizer'], | |||
'tokenization_fast': ['DebertaV2TokenizerFast'], | |||
'backbone': [ | |||
'DebertaV2Model', | |||
'DebertaV2PreTrainedModel', | |||
], | |||
'fill_mask': [ | |||
'DebertaV2ForMaskedLM', | |||
] | |||
} | |||
_import_structure['tokenization_deberta_v2_fast'] = [ | |||
'DebertaV2TokenizerFast' | |||
] | |||
_import_structure['modeling_deberta_v2'] = [ | |||
'DebertaV2ForMaskedLM', | |||
'DebertaV2ForMultipleChoice', | |||
'DebertaV2ForQuestionAnswering', | |||
'DebertaV2ForSequenceClassification', | |||
'DebertaV2ForTokenClassification', | |||
'DebertaV2Model', | |||
'DebertaV2PreTrainedModel', | |||
] | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
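The `LazyImportModule` registration above (its argument list is truncated by the diff hunk) defers importing submodules such as `backbone` until one of the names listed in `_import_structure` is first accessed. The same effect can be sketched with the standard library alone via a module-level `__getattr__` (PEP 562); the names below are illustrative, not ModelScope's implementation:

# lazy_pkg/__init__.py -- minimal lazy-import sketch
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Real imports are only seen by type checkers / IDEs.
    from .backbone import DebertaV2Model

_import_structure = {
    'backbone': ['DebertaV2Model', 'DebertaV2PreTrainedModel'],
    'fill_mask': ['DebertaV2ForMaskedLM'],
}
_name_to_module = {name: module
                   for module, names in _import_structure.items()
                   for name in names}

def __getattr__(name):
    # Import the owning submodule only on first attribute access.
    if name in _name_to_module:
        module = importlib.import_module('.' + _name_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')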
@@ -20,28 +20,22 @@ from typing import Optional, Tuple, Union | |||
import torch | |||
import torch.utils.checkpoint | |||
from torch import nn | |||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss | |||
from torch.nn import LayerNorm | |||
from transformers.activations import ACT2FN | |||
from transformers.file_utils import (add_code_sample_docstrings, | |||
add_start_docstrings, | |||
add_start_docstrings_to_model_forward) | |||
from transformers.modeling_outputs import (BaseModelOutput, MaskedLMOutput, | |||
MultipleChoiceModelOutput, | |||
QuestionAnsweringModelOutput, | |||
SequenceClassifierOutput, | |||
TokenClassifierOutput) | |||
from transformers.modeling_outputs import BaseModelOutput | |||
from transformers.modeling_utils import PreTrainedModel | |||
from transformers.pytorch_utils import softmax_backward_data | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionBackboneModelOutput | |||
from modelscope.utils import logger as logging | |||
from .configuration_deberta_v2 import DebertaV2Config | |||
from modelscope.utils.constant import Tasks | |||
from .configuration import DebertaV2Config | |||
logger = logging.get_logger(__name__) | |||
_CONFIG_FOR_DOC = 'DebertaV2Config' | |||
_TOKENIZER_FOR_DOC = 'DebertaV2Tokenizer' | |||
_CHECKPOINT_FOR_DOC = 'nlp_debertav2_fill-mask_chinese-lite' | |||
# Copied from transformers.models.deberta.modeling_deberta.ContextPooler | |||
class ContextPooler(nn.Module): | |||
@@ -1006,7 +1000,7 @@ class DebertaV2Embeddings(nn.Module): | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 | |||
class DebertaV2PreTrainedModel(PreTrainedModel): | |||
class DebertaV2PreTrainedModel(TorchModel, PreTrainedModel): | |||
""" | |||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |||
models. | |||
@@ -1018,6 +1012,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = ['position_embeddings'] | |||
supports_gradient_checkpointing = True | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config.name_or_path, **kwargs) | |||
super(Model, self).__init__(config) | |||
def _init_weights(self, module): | |||
"""Initialize the weights.""" | |||
if isinstance(module, nn.Linear): | |||
@@ -1037,8 +1035,24 @@ class DebertaV2PreTrainedModel(PreTrainedModel): | |||
if isinstance(module, DebertaV2Encoder): | |||
module.gradient_checkpointing = value | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.pop('model_dir', None) | |||
if model_dir is None: | |||
config = DebertaV2Config(**kwargs) | |||
model = cls(config) | |||
else: | |||
model = super( | |||
Model, | |||
cls).from_pretrained(pretrained_model_name_or_path=model_dir) | |||
return model | |||
@MODELS.register_module(Tasks.backbone, module_name=Models.deberta_v2) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 | |||
class DebertaV2Model(DebertaV2PreTrainedModel): | |||
"""The bare DeBERTa_v2 Model transformer outputting raw hidden-states without any specific head on top. | |||
DEBERTA_START_DOCSTRING = r""" | |||
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled | |||
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built | |||
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two | |||
@@ -1048,65 +1062,13 @@ DEBERTA_START_DOCSTRING = r""" | |||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | |||
and behavior. | |||
Parameters: | |||
config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. | |||
config (`DebertaV2Config`): Model configuration class with all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. | |||
""" | |||
DEBERTA_INPUTS_DOCSTRING = r""" | |||
Args: | |||
input_ids (`torch.LongTensor` of shape `({0})`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and | |||
[`PreTrainedTokenizer.__call__`] for details. | |||
[What are input IDs?](../glossary#input-ids) | |||
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
[What are attention masks?](../glossary#attention-mask) | |||
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, | |||
1]`: | |||
- 0 corresponds to a *sentence A* token, | |||
- 1 corresponds to a *sentence B* token. | |||
[What are token type IDs?](../glossary#token-type-ids) | |||
position_ids (`torch.LongTensor` of shape `({0})`, *optional*): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |||
config.max_position_embeddings - 1]`. | |||
[What are position IDs?](../glossary#position-ids) | |||
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): | |||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | |||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the | |||
model's internal embedding lookup matrix. | |||
output_attentions (`bool`, *optional*): | |||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |||
tensors for more detail. | |||
output_hidden_states (`bool`, *optional*): | |||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |||
more detail. | |||
return_dict (`bool`, *optional*): | |||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |||
""" | |||
@add_start_docstrings( | |||
'The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.', | |||
DEBERTA_START_DOCSTRING, | |||
) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 | |||
class DebertaV2Model(DebertaV2PreTrainedModel): | |||
configuration. | |||
""" | |||
def __init__(self, config): | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config) | |||
self.embeddings = DebertaV2Embeddings(config) | |||
@@ -1130,14 +1092,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel): | |||
raise NotImplementedError( | |||
'The prune function is not implemented in DeBERTa model.') | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=BaseModelOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
@@ -1148,7 +1102,53 @@ class DebertaV2Model(DebertaV2PreTrainedModel): | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, BaseModelOutput]: | |||
) -> Union[Tuple, AttentionBackboneModelOutput]: | |||
r""" | |||
Args: | |||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, | |||
1]`: | |||
- 0 corresponds to a *sentence A* token, | |||
- 1 corresponds to a *sentence B* token. | |||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |||
config.max_position_embeddings - 1]`. | |||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | |||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | |||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the | |||
model's internal embedding lookup matrix. | |||
output_attentions (`bool`, *optional*): | |||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |||
tensors for more detail. | |||
output_hidden_states (`bool`, *optional*): | |||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |||
more detail. | |||
return_dict (`bool`, *optional*): | |||
Whether or not to return a dataclass instead of a plain tuple. | |||
Returns: | |||
Returns `modelscope.outputs.AttentionBackboneModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite', task='backbone') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') | |||
>>> print(model(**preprocessor('这是个测试'))) | |||
""" | |||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||
output_hidden_states = ( | |||
output_hidden_states if output_hidden_states is not None else | |||
@@ -1216,574 +1216,9 @@ class DebertaV2Model(DebertaV2PreTrainedModel): | |||
return (sequence_output, ) + encoder_outputs[ | |||
(1 if output_hidden_states else 2):] | |||
return BaseModelOutput( | |||
return AttentionBackboneModelOutput( | |||
last_hidden_state=sequence_output, | |||
hidden_states=encoder_outputs.hidden_states | |||
if output_hidden_states else None, | |||
attentions=encoder_outputs.attentions, | |||
) | |||
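For callers of the backbone, the two return shapes above are the whole contract: with `return_dict=True` the forward returns an `AttentionBackboneModelOutput` with named fields, otherwise a plain tuple whose first element is the last hidden state. A hedged usage sketch reusing the model id from the docstring example above (field availability depends on `output_hidden_states`/`output_attentions`):

from modelscope.models import Model
from modelscope.preprocessors import Preprocessor

model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite', task='backbone')
preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite')
outputs = model(**preprocessor('这是个测试'))
# last_hidden_state is a torch.Tensor of shape (batch_size, seq_len, hidden_size)
print(outputs.last_hidden_state.shape)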
@add_start_docstrings( | |||
"""DeBERTa Model with a `language modeling` head on top.""", | |||
DEBERTA_START_DOCSTRING) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 | |||
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.deberta = DebertaV2Model(config) | |||
self.cls = DebertaV2OnlyMLMHead(config) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=MaskedLMOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
token_type_ids: Optional[torch.Tensor] = None, | |||
position_ids: Optional[torch.Tensor] = None, | |||
inputs_embeds: Optional[torch.Tensor] = None, | |||
labels: Optional[torch.Tensor] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, MaskedLMOutput]: | |||
r""" | |||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., | |||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the | |||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.deberta( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[1:] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return MaskedLMOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) | |||
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta | |||
class DebertaV2PredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta | |||
class DebertaV2LMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = DebertaV2PredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear( | |||
config.hidden_size, config.vocab_size, bias=False) | |||
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) | |||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` | |||
self.decoder.bias = self.bias | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta | |||
class DebertaV2OnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = DebertaV2LMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores | |||
@add_start_docstrings( | |||
""" | |||
DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the | |||
pooled output) e.g. for GLUE tasks. | |||
""", | |||
DEBERTA_START_DOCSTRING, | |||
) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 | |||
class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
num_labels = getattr(config, 'num_labels', 2) | |||
self.num_labels = num_labels | |||
self.deberta = DebertaV2Model(config) | |||
self.pooler = ContextPooler(config) | |||
output_dim = self.pooler.output_dim | |||
self.classifier = nn.Linear(output_dim, num_labels) | |||
drop_out = getattr(config, 'cls_dropout', None) | |||
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |||
self.dropout = StableDropout(drop_out) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def get_input_embeddings(self): | |||
return self.deberta.get_input_embeddings() | |||
def set_input_embeddings(self, new_embeddings): | |||
self.deberta.set_input_embeddings(new_embeddings) | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=SequenceClassifierOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
token_type_ids: Optional[torch.Tensor] = None, | |||
position_ids: Optional[torch.Tensor] = None, | |||
inputs_embeds: Optional[torch.Tensor] = None, | |||
labels: Optional[torch.Tensor] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, SequenceClassifierOutput]: | |||
r""" | |||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., | |||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If | |||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.deberta( | |||
input_ids, | |||
token_type_ids=token_type_ids, | |||
attention_mask=attention_mask, | |||
position_ids=position_ids, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
encoder_layer = outputs[0] | |||
pooled_output = self.pooler(encoder_layer) | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
loss = None | |||
if labels is not None: | |||
if self.config.problem_type is None: | |||
if self.num_labels == 1: | |||
# regression task | |||
loss_fn = nn.MSELoss() | |||
logits = logits.view(-1).to(labels.dtype) | |||
loss = loss_fn(logits, labels.view(-1)) | |||
elif labels.dim() == 1 or labels.size(-1) == 1: | |||
label_index = (labels >= 0).nonzero() | |||
labels = labels.long() | |||
if label_index.size(0) > 0: | |||
labeled_logits = torch.gather( | |||
logits, 0, | |||
label_index.expand( | |||
label_index.size(0), logits.size(1))) | |||
labels = torch.gather(labels, 0, label_index.view(-1)) | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct( | |||
labeled_logits.view(-1, self.num_labels).float(), | |||
labels.view(-1)) | |||
else: | |||
loss = torch.tensor(0).to(logits) | |||
else: | |||
log_softmax = nn.LogSoftmax(-1) | |||
loss = -((log_softmax(logits) * labels).sum(-1)).mean() | |||
elif self.config.problem_type == 'regression': | |||
loss_fct = MSELoss() | |||
if self.num_labels == 1: | |||
loss = loss_fct(logits.squeeze(), labels.squeeze()) | |||
else: | |||
loss = loss_fct(logits, labels) | |||
elif self.config.problem_type == 'single_label_classification': | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
elif self.config.problem_type == 'multi_label_classification': | |||
loss_fct = BCEWithLogitsLoss() | |||
loss = loss_fct(logits, labels) | |||
if not return_dict: | |||
output = (logits, ) + outputs[1:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return SequenceClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions) | |||
@add_start_docstrings( | |||
""" | |||
DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for | |||
Named-Entity-Recognition (NER) tasks. | |||
""", | |||
DEBERTA_START_DOCSTRING, | |||
) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 | |||
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.deberta = DebertaV2Model(config) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=TokenClassifierOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
token_type_ids: Optional[torch.Tensor] = None, | |||
position_ids: Optional[torch.Tensor] = None, | |||
inputs_embeds: Optional[torch.Tensor] = None, | |||
labels: Optional[torch.Tensor] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, TokenClassifierOutput]: | |||
r""" | |||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.deberta( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
if not return_dict: | |||
output = (logits, ) + outputs[1:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return TokenClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions) | |||
@add_start_docstrings( | |||
""" | |||
DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear | |||
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). | |||
""", | |||
DEBERTA_START_DOCSTRING, | |||
) | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 | |||
class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.deberta = DebertaV2Model(config) | |||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=QuestionAnsweringModelOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
token_type_ids: Optional[torch.Tensor] = None, | |||
position_ids: Optional[torch.Tensor] = None, | |||
inputs_embeds: Optional[torch.Tensor] = None, | |||
start_positions: Optional[torch.Tensor] = None, | |||
end_positions: Optional[torch.Tensor] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, QuestionAnsweringModelOutput]: | |||
r""" | |||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |||
Labels for position (index) of the start of the labelled span for computing the token classification loss. | |||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence | |||
are not taken into account for computing the loss. | |||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |||
Labels for position (index) of the end of the labelled span for computing the token classification loss. | |||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence | |||
are not taken into account for computing the loss. | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.deberta( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
logits = self.qa_outputs(sequence_output) | |||
start_logits, end_logits = logits.split(1, dim=-1) | |||
start_logits = start_logits.squeeze(-1).contiguous() | |||
end_logits = end_logits.squeeze(-1).contiguous() | |||
total_loss = None | |||
if start_positions is not None and end_positions is not None: | |||
# If we are on multi-GPU, split add a dimension | |||
if len(start_positions.size()) > 1: | |||
start_positions = start_positions.squeeze(-1) | |||
if len(end_positions.size()) > 1: | |||
end_positions = end_positions.squeeze(-1) | |||
# sometimes the start/end positions are outside our model inputs, we ignore these terms | |||
ignored_index = start_logits.size(1) | |||
start_positions = start_positions.clamp(0, ignored_index) | |||
end_positions = end_positions.clamp(0, ignored_index) | |||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) | |||
start_loss = loss_fct(start_logits, start_positions) | |||
end_loss = loss_fct(end_logits, end_positions) | |||
total_loss = (start_loss + end_loss) / 2 | |||
if not return_dict: | |||
output = (start_logits, end_logits) + outputs[1:] | |||
return ((total_loss, ) | |||
+ output) if total_loss is not None else output | |||
return QuestionAnsweringModelOutput( | |||
loss=total_loss, | |||
start_logits=start_logits, | |||
end_logits=end_logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) | |||
@add_start_docstrings( | |||
""" | |||
DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a | |||
softmax) e.g. for RocStories/SWAG tasks. | |||
""", | |||
DEBERTA_START_DOCSTRING, | |||
) | |||
class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
num_labels = getattr(config, 'num_labels', 2) | |||
self.num_labels = num_labels | |||
self.deberta = DebertaV2Model(config) | |||
self.pooler = ContextPooler(config) | |||
output_dim = self.pooler.output_dim | |||
self.classifier = nn.Linear(output_dim, 1) | |||
drop_out = getattr(config, 'cls_dropout', None) | |||
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out | |||
self.dropout = StableDropout(drop_out) | |||
self.init_weights() | |||
def get_input_embeddings(self): | |||
return self.deberta.get_input_embeddings() | |||
def set_input_embeddings(self, new_embeddings): | |||
self.deberta.set_input_embeddings(new_embeddings) | |||
@add_start_docstrings_to_model_forward( | |||
DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=MultipleChoiceModelOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., | |||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See | |||
`input_ids` above) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
num_choices = input_ids.shape[ | |||
1] if input_ids is not None else inputs_embeds.shape[1] | |||
flat_input_ids = input_ids.view( | |||
-1, input_ids.size(-1)) if input_ids is not None else None | |||
flat_position_ids = position_ids.view( | |||
-1, position_ids.size(-1)) if position_ids is not None else None | |||
flat_token_type_ids = token_type_ids.view( | |||
-1, | |||
token_type_ids.size(-1)) if token_type_ids is not None else None | |||
flat_attention_mask = attention_mask.view( | |||
-1, | |||
attention_mask.size(-1)) if attention_mask is not None else None | |||
flat_inputs_embeds = ( | |||
inputs_embeds.view(-1, inputs_embeds.size(-2), | |||
inputs_embeds.size(-1)) | |||
if inputs_embeds is not None else None) | |||
outputs = self.deberta( | |||
flat_input_ids, | |||
position_ids=flat_position_ids, | |||
token_type_ids=flat_token_type_ids, | |||
attention_mask=flat_attention_mask, | |||
inputs_embeds=flat_inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
encoder_layer = outputs[0] | |||
pooled_output = self.pooler(encoder_layer) | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
reshaped_logits = logits.view(-1, num_choices) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(reshaped_logits, labels) | |||
if not return_dict: | |||
output = (reshaped_logits, ) + outputs[1:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return MultipleChoiceModelOutput( | |||
loss=loss, | |||
logits=reshaped_logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) |
@@ -13,8 +13,6 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" DeBERTa-v2 model configuration, mainly copied from :class:`~transformers.DeBERTaV2Config""" | |||
from collections import OrderedDict | |||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union | |||
from transformers import PretrainedConfig | |||
@@ -0,0 +1,230 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2020 Microsoft and the Hugging Face Inc. team. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Optional, Tuple, Union | |||
import torch | |||
import torch.utils.checkpoint | |||
from torch import nn | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.activations import ACT2FN | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionFillMaskModelOutput | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import DebertaV2Model, DebertaV2PreTrainedModel | |||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) | |||
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): | |||
r"""DeBERTa_v2 Model with a `language modeling` head on top. | |||
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled | |||
Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built | |||
on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two | |||
improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data. | |||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. | |||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | |||
and behavior. | |||
Preprocessor: | |||
This is the fill_mask model of Deberta_v2, the preprocessor of this model | |||
is `modelscope.preprocessors.NLPPreprocessor`. | |||
Parameters: | |||
config (`DebertaV2Config`): Model configuration class with all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config) | |||
self.deberta = DebertaV2Model(config) | |||
self.cls = DebertaV2OnlyMLMHead(config) | |||
# Initialize weights and apply final processing | |||
self.post_init() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
def forward( | |||
self, | |||
input_ids: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
token_type_ids: Optional[torch.Tensor] = None, | |||
position_ids: Optional[torch.Tensor] = None, | |||
inputs_embeds: Optional[torch.Tensor] = None, | |||
labels: Optional[torch.Tensor] = None, | |||
output_attentions: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
) -> Union[Tuple, AttentionFillMaskModelOutput]: | |||
r""" | |||
Args: | |||
input_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`): | |||
Indices of input sequence tokens in the vocabulary. | |||
attention_mask (`torch.FloatTensor` of shape `('batch_size, sequence_length')`, *optional*): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, | |||
1]`: | |||
- 0 corresponds to a *sentence A* token, | |||
- 1 corresponds to a *sentence B* token. | |||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Indices of positions of each input sequence tokens in the position embeddings. | |||
Selected in the range `[0, config.max_position_embeddings - 1]`. | |||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | |||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert *input_ids* indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (`bool`, *optional*): | |||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |||
tensors for more detail. | |||
output_hidden_states (`bool`, *optional*): | |||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |||
more detail. | |||
return_dict (`bool`, *optional*): | |||
Whether or not to return a dataclass instead of a plain tuple. | |||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., | |||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are | |||
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` | |||
Returns: | |||
Returns `modelscope.outputs.AttentionFillMaskModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') | |||
>>> # Call the model, return some tensors | |||
>>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) | |||
>>> # Call the pipeline | |||
>>> from modelscope.pipelines import pipeline | |||
>>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) | |||
>>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.deberta( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[1:] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return AttentionFillMaskModelOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
input_ids=input_ids, | |||
attentions=outputs.attentions, | |||
hidden_states=outputs.hidden_states) | |||
# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta | |||
class DebertaV2PredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta | |||
class DebertaV2LMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = DebertaV2PredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear( | |||
config.hidden_size, config.vocab_size, bias=False) | |||
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) | |||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` | |||
self.decoder.bias = self.bias | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta | |||
class DebertaV2OnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = DebertaV2LMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores |
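A minimal usage sketch of the `AttentionFillMaskModelOutput` returned above, reusing the model id from the docstring example; attribute names follow the return statement in `forward`, and attribute-style access on the output dataclass is assumed:
>>> from modelscope.models import Model
>>> from modelscope.preprocessors import Preprocessor
>>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite')
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite')
>>> outputs = model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))
>>> outputs.logits.shape  # torch.Tensor of shape (batch_size, sequence_length, vocab_size)
>>> outputs.input_ids     # echoed back so downstream code can locate the masked positions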
@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast | |||
from modelscope.utils import logger as logging | |||
if is_sentencepiece_available(): | |||
from .tokenization_deberta_v2 import DebertaV2Tokenizer | |||
from .tokenization import DebertaV2Tokenizer | |||
else: | |||
DebertaV2Tokenizer = None | |||
@@ -4,16 +4,16 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_gpt3 import GPT3Config | |||
from .modeling_gpt3 import GPT3Model | |||
from .gpt3_for_text_generation import GPT3ForTextGeneration | |||
from .tokenizer_gpt3 import JiebaBPETokenizer | |||
from .configuration import GPT3Config | |||
from .backbone import GPT3Model | |||
from .text_generation import GPT3ForTextGeneration | |||
from .tokenizer import JiebaBPETokenizer | |||
else: | |||
_import_structure = { | |||
'configuration_gpt3': ['GPT3Config'], | |||
'modeling_gpt3': ['GPT3Model'], | |||
'gpt3_for_text_generation': ['GPT3ForTextGeneration'], | |||
'tokenizer_gpt3': ['JiebaBPETokenizer'], | |||
'configuration': ['GPT3Config'], | |||
'backbone': ['GPT3Model'], | |||
'text_generation': ['GPT3ForTextGeneration'], | |||
'tokenizer': ['JiebaBPETokenizer'], | |||
} | |||
import sys | |||
@@ -24,7 +24,7 @@ from torch.nn import functional as F | |||
from transformers.modeling_utils import PreTrainedModel | |||
from modelscope.utils.constant import ModelFile | |||
from .configuration_gpt3 import GPT3Config | |||
from .configuration import GPT3Config | |||
class GPT3SelfAttention(nn.Module): |
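A hedged import sketch for the renamed GPT-3 submodules; the package path `modelscope.models.nlp.gpt3` is assumed rather than shown in this hunk, and name resolution goes through the `LazyImportModule` table above on first access:
>>> # assumed package path; '_import_structure' maps 'configuration' -> GPT3Config, 'backbone' -> GPT3Model, etc.
>>> from modelscope.models.nlp.gpt3 import GPT3Config, GPT3Model, GPT3ForTextGeneration, JiebaBPETokenizer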
@@ -4,14 +4,12 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .structbert import SbertModel | |||
from .backbone import GPTNeoModel | |||
else: | |||
_import_structure = { | |||
'structbert': ['SbertModel'], | |||
'backbone': ['GPTNeoModel'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], |
@@ -4,10 +4,11 @@ from transformers import GPTNeoModel as GPTNeoModelTransform | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import BACKBONES | |||
from modelscope.utils.constant import Fields | |||
from modelscope.utils.constant import Tasks | |||
@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.gpt_neo) | |||
@BACKBONES.register_module( | |||
group_key=Tasks.backbone, module_name=Models.gpt_neo) | |||
class GPTNeoModel(GPTNeoModelTransform): | |||
def __init__(self, **kwargs): |
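Since the backbone is now registered under `Tasks.backbone`, it can be obtained through the unified `Model.from_pretrained` call with a task argument, mirroring the PoNet backbone example later in this diff; the checkpoint id below is hypothetical and only for illustration:
>>> from modelscope.models import Model
>>> backbone = Model.from_pretrained('damo/nlp_gptneo_example', task='backbone')  # hypothetical model id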
@@ -37,9 +37,9 @@ class TokenClassificationHead(TorchHead): | |||
sequence_output = inputs | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
return {OutputKeys.LOGITS: logits} | |||
return logits | |||
def compute_loss(self, outputs: Dict[str, torch.Tensor], | |||
labels) -> Dict[str, torch.Tensor]: | |||
logits = outputs[OutputKeys.LOGITS] | |||
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} | |||
return F.cross_entropy(logits, labels) |
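With this change the head hands back plain tensors instead of dicts keyed by `OutputKeys`; a hedged caller-side sketch (names are illustrative):
>>> logits = head(sequence_output)  # now a raw torch.Tensor, no {OutputKeys.LOGITS: ...} wrapper
>>> # compute_loss likewise returns a bare loss tensor (F.cross_entropy) instead of {OutputKeys.LOSS: ...}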
@@ -1,164 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.bert import \ | |||
BertForMaskedLM as BertForMaskedLMTransformer | |||
from modelscope.models.nlp.deberta_v2 import \ | |||
DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer | |||
from modelscope.models.nlp.structbert import SbertForMaskedLM | |||
from modelscope.models.nlp.veco import \ | |||
VecoForMaskedLM as VecoForMaskedLMTransformer | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['BertForMaskedLM', 'StructBertForMaskedLM', 'VecoForMaskedLM'] | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | |||
class StructBertForMaskedLM(TorchModel, SbertForMaskedLM): | |||
"""Structbert for MLM model. | |||
Inherited from structbert.SbertForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super(TorchModel, self).__init__(model_dir) | |||
SbertForMaskedLM.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None): | |||
output = SbertForMaskedLM.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
labels=labels) | |||
output[OutputKeys.INPUT_IDS] = input_ids | |||
return output | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
return super(SbertForMaskedLM, StructBertForMaskedLM).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, model_dir=model_dir) | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert) | |||
class BertForMaskedLM(TorchModel, BertForMaskedLMTransformer): | |||
"""Bert for MLM model. | |||
Inherited from transformers.BertForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super(TorchModel, self).__init__(model_dir) | |||
BertForMaskedLMTransformer.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None): | |||
output = BertForMaskedLMTransformer.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
labels=labels) | |||
output[OutputKeys.INPUT_IDS] = input_ids | |||
return output | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
return super(BertForMaskedLMTransformer, | |||
BertForMaskedLM).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, | |||
model_dir=model_dir) | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) | |||
class VecoForMaskedLM(TorchModel, VecoForMaskedLMTransformer): | |||
"""Veco for MLM model. | |||
Inherited from veco.VecoForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super(TorchModel, self).__init__(model_dir) | |||
VecoForMaskedLMTransformer.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None): | |||
output = VecoForMaskedLMTransformer.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
labels=labels) | |||
output[OutputKeys.INPUT_IDS] = input_ids | |||
return output | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
return super(VecoForMaskedLMTransformer, | |||
VecoForMaskedLM).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, | |||
model_dir=model_dir) | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) | |||
class DebertaV2ForMaskedLM(TorchModel, DebertaV2ForMaskedLMTransformer): | |||
"""Deberta v2 for MLM model. | |||
Inherited from deberta_v2.DebertaV2ForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super(TorchModel, self).__init__(model_dir) | |||
DebertaV2ForMaskedLMTransformer.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None): | |||
output = DebertaV2ForMaskedLMTransformer.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
labels=labels) | |||
output[OutputKeys.INPUT_IDS] = input_ids | |||
return output | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
return super(DebertaV2ForMaskedLMTransformer, | |||
DebertaV2ForMaskedLM).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, | |||
model_dir=model_dir) |
@@ -17,19 +17,19 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_palm import PalmConfig | |||
from .modeling_palm import ( | |||
from .configuration import PalmConfig | |||
from .backbone import ( | |||
AbsSummarizer, | |||
PalmForConditionalGeneration, | |||
Translator, | |||
) | |||
from .palm_for_text_generation import PalmForTextGeneration | |||
from .text_generation import PalmForTextGeneration | |||
else: | |||
_import_structure = { | |||
'configuration_palm': ['PalmConfig'], | |||
'modeling_palm': | |||
'configuration': ['PalmConfig'], | |||
'backbone': | |||
['AbsSummarizer', 'PalmForConditionalGeneration', 'Translator'], | |||
'palm_for_text_generation': ['PalmForTextGeneration'], | |||
'text_generation': ['PalmForTextGeneration'], | |||
} | |||
import sys | |||
@@ -35,7 +35,7 @@ from transformers.activations import ACT2FN | |||
from transformers.modeling_utils import PreTrainedModel | |||
from modelscope.utils import logger as logging | |||
from .configuration_palm import PalmConfig | |||
from .configuration import PalmConfig | |||
from .dureader_eval import compute_bleu_rouge, normalize | |||
CONFIG_NAME = 'config.json' |
@@ -4,13 +4,13 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_plug import PlugNLGConfig | |||
from .modeling_plug import PlugModel | |||
from .configuration import PlugNLGConfig | |||
from .backbone import PlugModel | |||
from .distributed_plug import DistributedPlug | |||
else: | |||
_import_structure = { | |||
'configuration_plug': ['PlugNLGConfig'], | |||
'modeling_plug': ['PlugModel'], | |||
'configuration': ['PlugNLGConfig'], | |||
'backbone': ['PlugModel'], | |||
'distributed_plug': ['DistributedPlug'], | |||
} | |||
@@ -28,7 +28,7 @@ from torch import nn | |||
from modelscope.utils.nlp.distributed import (normal_init_method, | |||
scaled_init_method) | |||
from .configuration_plug import PlugNLGConfig, PlugNLUConfig | |||
from .configuration import PlugNLGConfig, PlugNLUConfig | |||
logger = logging.getLogger(__name__) | |||
@@ -1,3 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Dict | |||
@@ -14,7 +15,7 @@ from modelscope.utils.nlp.distributed import initialize_distributed | |||
from modelscope.utils.nlp.load_checkpoint import pre_load | |||
from modelscope.utils.torch_utils import set_random_seed_mpu | |||
from . import PlugModel | |||
from .configuration_plug import PlugNLGConfig | |||
from .configuration import PlugNLGConfig | |||
logger = get_logger(__name__) | |||
@@ -18,16 +18,16 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_ponet import PoNetConfig | |||
from .modeling_ponet import (PoNetForMaskedLM, PoNetModel, | |||
PoNetPreTrainedModel) | |||
from .tokenization_ponet import PoNetTokenizer | |||
from .configuration import PoNetConfig | |||
from .backbone import (PoNetModel, PoNetPreTrainedModel) | |||
from .tokenization import PoNetTokenizer | |||
from .fill_mask import PoNetForMaskedLM | |||
else: | |||
_import_structure = { | |||
'configuration_ponet': ['PoNetConfig'], | |||
'modeling_ponet': | |||
['PoNetForMaskedLM', 'PoNetModel', 'PoNetPreTrainedModel'], | |||
'tokenization_ponet': ['PoNetTokenizer'], | |||
'configuration': ['PoNetConfig'], | |||
'backbone': ['PoNetModel', 'PoNetPreTrainedModel'], | |||
'fill_mask': ['PoNetForMaskedLM'], | |||
'tokenization': ['PoNetTokenizer'], | |||
} | |||
import sys | |||
@@ -16,43 +16,32 @@ | |||
"""PyTorch PoNet model. """ | |||
import math | |||
from dataclasses import dataclass | |||
from distutils.version import LooseVersion | |||
from typing import Optional, Tuple | |||
import torch | |||
import torch.utils.checkpoint | |||
from packaging import version | |||
from torch import nn | |||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |||
from transformers.activations import ACT2FN | |||
from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, | |||
add_start_docstrings, | |||
add_start_docstrings_to_model_forward, | |||
replace_return_docstrings) | |||
from transformers.modeling_outputs import ( | |||
BaseModelOutputWithPastAndCrossAttentions, | |||
BaseModelOutputWithPoolingAndCrossAttentions, | |||
CausalLMOutputWithCrossAttentions, MaskedLMOutput, | |||
SequenceClassifierOutput, TokenClassifierOutput) | |||
from transformers.modeling_outputs import \ | |||
BaseModelOutputWithPastAndCrossAttentions | |||
from transformers.modeling_utils import (PreTrainedModel, | |||
apply_chunking_to_forward, | |||
find_pruneable_heads_and_indices, | |||
prune_linear_layer) | |||
from transformers.models.bert.modeling_bert import \ | |||
load_tf_weights_in_bert as load_tf_weights_in_ponet | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionBackboneModelOutput | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .configuration_ponet import PoNetConfig | |||
from .configuration import PoNetConfig | |||
logger = get_logger(__name__) | |||
is_pytorch_12plus = LooseVersion(torch.__version__) >= LooseVersion('1.12.0') | |||
_CHECKPOINT_FOR_DOC = 'ponet-base-uncased' | |||
_CONFIG_FOR_DOC = 'PoNetConfig' | |||
_TOKENIZER_FOR_DOC = 'PoNetTokenizer' | |||
CLS_ID = 101 | |||
EOS_ID = 102 | |||
@@ -609,82 +598,20 @@ class PoNetPooler(nn.Module): | |||
return pooled_output | |||
class PoNetPredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
class PoNetLMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = PoNetPredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear( | |||
config.hidden_size, config.vocab_size, bias=False) | |||
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) | |||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` | |||
self.decoder.bias = self.bias | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
class PoNetOnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = PoNetLMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores | |||
class PoNetPreTrainingHeads(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = PoNetLMPredictionHead(config) | |||
self.seq_relationship = nn.Linear(config.hidden_size, 3) | |||
def forward(self, sequence_output, pooled_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
seq_relationship_score = self.seq_relationship(pooled_output) | |||
return prediction_scores, seq_relationship_score | |||
class PoNetPreTrainedModel(PreTrainedModel): | |||
class PoNetPreTrainedModel(TorchModel, PreTrainedModel): | |||
""" | |||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |||
models. | |||
""" | |||
config_class = PoNetConfig | |||
load_tf_weights = load_tf_weights_in_ponet | |||
base_model_prefix = 'ponet' | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config.name_or_path, **kwargs) | |||
super(Model, self).__init__(config) | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
@@ -703,51 +630,22 @@ class PoNetPreTrainedModel(PreTrainedModel): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
@dataclass | |||
class PoNetForPreTrainingOutput(ModelOutput): | |||
""" | |||
Output type of :class:`~transformers.PoNetForPreTraining`. | |||
Args: | |||
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): | |||
Total loss as the sum of the masked language modeling loss and the next sequence prediction | |||
(classification) loss. | |||
mlm_loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): | |||
Masked language modeling loss. | |||
sop_loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): | |||
sop loss. | |||
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): | |||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |||
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): | |||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation | |||
before SoftMax). | |||
hidden_states | |||
(:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed | |||
or when ``config.output_hidden_states=True``): | |||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | |||
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | |||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed | |||
or when ``config.output_attentions=True``): | |||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, | |||
sequence_length, sequence_length)`. | |||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |||
heads. | |||
""" | |||
loss: Optional[torch.FloatTensor] = None | |||
mlm_loss: Optional[torch.FloatTensor] = None | |||
sop_loss: Optional[torch.FloatTensor] = None | |||
prediction_logits: torch.FloatTensor = None | |||
seq_relationship_logits: torch.FloatTensor = None | |||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None | |||
attentions: Optional[Tuple[torch.FloatTensor]] = None | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.pop('model_dir', None) | |||
if model_dir is None: | |||
ponet_config = PoNetConfig(**kwargs) | |||
model = cls(ponet_config) | |||
else: | |||
model = super( | |||
Model, | |||
cls).from_pretrained(pretrained_model_name_or_path=model_dir) | |||
return model | |||
PONET_START_DOCSTRING = r""" | |||
@MODELS.register_module(Tasks.backbone, module_name=Models.ponet) | |||
class PoNetModel(PoNetPreTrainedModel): | |||
"""The bare PoNet Model transformer outputting raw hidden-states without any specific head on top. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
@@ -763,65 +661,6 @@ PONET_START_DOCSTRING = r""" | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
PONET_INPUTS_DOCSTRING = r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
`What are input IDs? <../glossary.html#input-ids>`__ | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
`What are attention masks? <../glossary.html#attention-mask>`__ | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
`What are token type IDs? <../glossary.html#token-type-ids>`_ | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
`What are position IDs? <../glossary.html#position-ids>`_ | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |||
""" | |||
@add_start_docstrings( | |||
'The bare PoNet Model transformer outputting raw hidden-states without any specific head on top.', | |||
PONET_START_DOCSTRING, | |||
) | |||
class PoNetModel(PoNetPreTrainedModel): | |||
""" | |||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of | |||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is | |||
@@ -834,8 +673,8 @@ class PoNetModel(PoNetPreTrainedModel): | |||
input to the forward pass. | |||
""" | |||
def __init__(self, config, add_pooling_layer=True): | |||
super().__init__(config) | |||
def __init__(self, config, add_pooling_layer=True, **kwargs): | |||
super().__init__(config, **kwargs) | |||
self.config = config | |||
self.embeddings = PoNetEmbeddings(config) | |||
@@ -859,14 +698,6 @@ class PoNetModel(PoNetPreTrainedModel): | |||
for layer, heads in heads_to_prune.items(): | |||
self.encoder.layer[layer].attention.prune_heads(heads) | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=BaseModelOutputWithPoolingAndCrossAttentions, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
@@ -885,6 +716,49 @@ class PoNetModel(PoNetPreTrainedModel): | |||
return_dict=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |||
encoder_hidden_states | |||
(:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if | |||
@@ -906,6 +780,16 @@ class PoNetModel(PoNetPreTrainedModel): | |||
use_cache (:obj:`bool`, `optional`): | |||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up | |||
decoding (see :obj:`past_key_values`). | |||
Returns: | |||
Returns `modelscope.outputs.AttentionBackboneModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base', task='backbone') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') | |||
>>> print(model(**preprocessor('这是个测试'))) | |||
""" | |||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||
output_hidden_states = ( | |||
@@ -1006,7 +890,7 @@ class PoNetModel(PoNetPreTrainedModel): | |||
if not return_dict: | |||
return (sequence_output, pooled_output) + encoder_outputs[1:] | |||
return BaseModelOutputWithPoolingAndCrossAttentions( | |||
return AttentionBackboneModelOutput( | |||
last_hidden_state=sequence_output, | |||
pooler_output=pooled_output, | |||
past_key_values=encoder_outputs.past_key_values, | |||
@@ -1014,578 +898,3 @@ class PoNetModel(PoNetPreTrainedModel): | |||
attentions=encoder_outputs.attentions, | |||
cross_attentions=encoder_outputs.cross_attentions, | |||
) | |||
@add_start_docstrings( | |||
""" | |||
PoNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next | |||
sentence prediction (classification)` head. | |||
""", | |||
PONET_START_DOCSTRING, | |||
) | |||
class PoNetForPreTraining(PoNetPreTrainedModel): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.ponet = PoNetModel(config) | |||
self.cls = PoNetPreTrainingHeads(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@replace_return_docstrings( | |||
output_type=PoNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
next_sentence_label=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): | |||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., | |||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored | |||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` | |||
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): | |||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair | |||
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: | |||
- 0 indicates sequence B is a continuation of sequence A, | |||
- 1 indicates sequence B is a random sequence. | |||
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): | |||
Used to hide legacy arguments that have been deprecated. | |||
Returns: | |||
Example:: | |||
>>> from transformers import PoNetTokenizer, PoNetForPreTraining | |||
>>> import torch | |||
>>> tokenizer = PoNetTokenizer.from_pretrained('ponet-base-uncased') | |||
>>> model = PoNetForPreTraining.from_pretrained('ponet-base-uncased') | |||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") | |||
>>> outputs = model(**inputs) | |||
>>> prediction_logits = outputs.prediction_logits | |||
>>> seq_relationship_logits = outputs.seq_relationship_logits | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output, pooled_output = outputs[:2] | |||
prediction_scores, seq_relationship_score = self.cls( | |||
sequence_output, pooled_output) | |||
total_loss = None | |||
masked_lm_loss = None | |||
next_sentence_loss = None | |||
if labels is not None and next_sentence_label is not None: | |||
loss_fct = CrossEntropyLoss() | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
next_sentence_loss = loss_fct( | |||
seq_relationship_score.view(-1, 3), | |||
next_sentence_label.view(-1)) | |||
total_loss = masked_lm_loss + next_sentence_loss | |||
if not return_dict: | |||
output = (prediction_scores, seq_relationship_score) + outputs[2:] | |||
return ((total_loss, masked_lm_loss, next_sentence_loss) | |||
+ output) if total_loss is not None else output | |||
return PoNetForPreTrainingOutput( | |||
loss=total_loss, | |||
mlm_loss=masked_lm_loss, | |||
sop_loss=next_sentence_loss, | |||
prediction_logits=prediction_scores, | |||
seq_relationship_logits=seq_relationship_score, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) | |||
@add_start_docstrings( | |||
"""PoNet Model with a `language modeling` head on top for CLM fine-tuning. """, | |||
PONET_START_DOCSTRING) | |||
class PoNetLMHeadModel(PoNetPreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
if not config.is_decoder: | |||
logger.warning( | |||
'If you want to use `PoNetLMHeadModel` as a standalone, add `is_decoder=True.`' | |||
) | |||
self.ponet = PoNetModel(config, add_pooling_layer=False) | |||
self.cls = PoNetOnlyMLMHead(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@replace_return_docstrings( | |||
output_type=CausalLMOutputWithCrossAttentions, | |||
config_class=_CONFIG_FOR_DOC) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj: | |||
`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if | |||
the model is configured as a decoder. | |||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in | |||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in | |||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are | |||
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` | |||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` | |||
with each tuple having 4 tensors of shape : | |||
obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | |||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. | |||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` | |||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` | |||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. | |||
use_cache (:obj:`bool`, `optional`): | |||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up | |||
decoding (see :obj:`past_key_values`). | |||
Returns: | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if labels is not None: | |||
use_cache = False | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
past_key_values=past_key_values, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
lm_loss = None | |||
if labels is not None: | |||
# we are doing next-token prediction; shift prediction scores and input ids by one | |||
shifted_prediction_scores = prediction_scores[:, : | |||
-1, :].contiguous() | |||
labels = labels[:, 1:].contiguous() | |||
loss_fct = CrossEntropyLoss() | |||
lm_loss = loss_fct( | |||
shifted_prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:] | |||
return ((lm_loss, ) + output) if lm_loss is not None else output | |||
return CausalLMOutputWithCrossAttentions( | |||
loss=lm_loss, | |||
logits=prediction_scores, | |||
past_key_values=outputs.past_key_values, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
cross_attentions=outputs.cross_attentions, | |||
) | |||
def prepare_inputs_for_generation(self, | |||
input_ids, | |||
past=None, | |||
attention_mask=None, | |||
**model_kwargs): | |||
input_shape = input_ids.shape | |||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly | |||
if attention_mask is None: | |||
attention_mask = input_ids.new_ones(input_shape) | |||
# cut decoder_input_ids if past is used | |||
if past is not None: | |||
input_ids = input_ids[:, -1:] | |||
return { | |||
'input_ids': input_ids, | |||
'attention_mask': attention_mask, | |||
'past_key_values': past | |||
} | |||
def _reorder_cache(self, past, beam_idx): | |||
reordered_past = () | |||
for layer_past in past: | |||
reordered_past += (tuple( | |||
past_state.index_select(0, beam_idx) | |||
for past_state in layer_past), ) | |||
return reordered_past | |||
@add_start_docstrings( | |||
"""PoNet Model with a `language modeling` head on top. """, | |||
PONET_START_DOCSTRING) | |||
class PoNetForMaskedLM(PoNetPreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
if config.is_decoder: | |||
logger.warning( | |||
'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for ' | |||
'bi-directional self-attention.') | |||
self.ponet = PoNetModel(config, add_pooling_layer=False) | |||
self.cls = PoNetOnlyMLMHead(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=MaskedLMOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
segment_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., | |||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored | |||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return MaskedLMOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) | |||
@add_start_docstrings( | |||
""" | |||
PoNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled | |||
output) e.g. for GLUE tasks. | |||
""", | |||
PONET_START_DOCSTRING, | |||
) | |||
class PoNetForSequenceClassification(PoNetPreTrainedModel): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
self.ponet = PoNetModel(config) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
self.init_weights() | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=SequenceClassifierOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., | |||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), | |||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
pooled_output = outputs[1] | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
loss = None | |||
if labels is not None: | |||
if self.config.problem_type is None: | |||
if self.num_labels == 1: | |||
self.config.problem_type = 'regression' | |||
elif self.num_labels > 1 and (labels.dtype == torch.long | |||
or labels.dtype == torch.int): | |||
self.config.problem_type = 'single_label_classification' | |||
else: | |||
self.config.problem_type = 'multi_label_classification' | |||
if self.config.problem_type == 'regression': | |||
loss_fct = MSELoss() | |||
if self.num_labels == 1: | |||
loss = loss_fct(logits.squeeze(), labels.squeeze()) | |||
else: | |||
loss = loss_fct(logits, labels) | |||
elif self.config.problem_type == 'single_label_classification': | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
elif self.config.problem_type == 'multi_label_classification': | |||
loss_fct = BCEWithLogitsLoss() | |||
loss = loss_fct(logits, labels) | |||
if not return_dict: | |||
output = (logits, ) + outputs[2:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return SequenceClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) | |||
@add_start_docstrings( | |||
""" | |||
PoNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for | |||
Named-Entity-Recognition (NER) tasks. | |||
""", | |||
PONET_START_DOCSTRING, | |||
) | |||
class PoNetForTokenClassification(PoNetPreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.ponet = PoNetModel(config, add_pooling_layer=False) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
self.init_weights() | |||
@add_start_docstrings_to_model_forward( | |||
PONET_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint=_CHECKPOINT_FOR_DOC, | |||
output_type=TokenClassifierOutput, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - | |||
1]``. | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() | |||
# Only keep active parts of the loss | |||
if attention_mask is not None: | |||
active_loss = attention_mask.view(-1) == 1 | |||
active_logits = logits.view(-1, self.num_labels) | |||
active_labels = torch.where( | |||
active_loss, labels.view(-1), | |||
torch.tensor(loss_fct.ignore_index).type_as(labels)) | |||
loss = loss_fct(active_logits, active_labels) | |||
else: | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
if not return_dict: | |||
output = (logits, ) + outputs[2:] | |||
return ((loss, ) + output) if loss is not None else output | |||
return TokenClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) |
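A hedged sketch of reading the `AttentionBackboneModelOutput` fields produced above, extending the docstring example inside this hunk; attribute-style access on the output is assumed:
>>> from modelscope.models import Model
>>> from modelscope.preprocessors import Preprocessor
>>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base', task='backbone')
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
>>> outputs = model(**preprocessor('这是个测试'))
>>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
>>> outputs.pooler_output.shape      # (batch_size, hidden_size)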
@@ -34,8 +34,7 @@ class PoNetConfig(PretrainedConfig): | |||
Args: | |||
vocab_size (:obj:`int`, `optional`, defaults to 30522): | |||
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the | |||
:obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or | |||
:class:`~transformers.TFBertModel`. | |||
:obj:`inputs_ids` passed. | |||
hidden_size (:obj:`int`, `optional`, defaults to 768): | |||
Dimensionality of the encoder layers and the pooler layer. | |||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12): | |||
@@ -55,8 +54,7 @@ class PoNetConfig(PretrainedConfig): | |||
The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
just in case (e.g., 512 or 1024 or 2048). | |||
type_vocab_size (:obj:`int`, `optional`, defaults to 2): | |||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or | |||
:class:`~transformers.TFBertModel`. | |||
The vocabulary size of the :obj:`token_type_ids` passed. | |||
initializer_range (:obj:`float`, `optional`, defaults to 0.02): | |||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): |
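A minimal hedged sketch of building a `PoNetConfig` with the arguments documented above; the import path follows the `modelscope.models.nlp.ponet` references in the docstrings, the values are the documented defaults, and constructing a config does not load any weights:
>>> from modelscope.models.nlp.ponet import PoNetConfig
>>> config = PoNetConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12, type_vocab_size=2)
>>> config.layer_norm_eps  # 1e-12 unless overridden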
@@ -0,0 +1,252 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch.utils.checkpoint | |||
from torch import nn | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.activations import ACT2FN | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionFillMaskModelOutput | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .backbone import PoNetModel, PoNetPreTrainedModel | |||
logger = get_logger(__name__) | |||
class PoNetPredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
class PoNetLMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = PoNetPredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear( | |||
config.hidden_size, config.vocab_size, bias=False) | |||
self.bias = nn.Parameter(torch.zeros(config.vocab_size)) | |||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` | |||
self.decoder.bias = self.bias | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
class PoNetOnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = PoNetLMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) | |||
class PoNetForMaskedLM(PoNetPreTrainedModel): | |||
r"""PoNet Model with a `language modeling` head on top. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads, etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the fill_mask model of PoNet. The preprocessor of this model | |||
is `modelscope.preprocessors.FillMaskPoNetPreprocessor`. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`): | |||
Model configuration class with all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config) | |||
if config.is_decoder: | |||
logger.warning( | |||
'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for ' | |||
'bi-directional self-attention.') | |||
self.ponet = PoNetModel(config, add_pooling_layer=False) | |||
self.cls = PoNetOnlyMLMHead(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
segment_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, | |||
`optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., | |||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored | |||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` | |||
Returns: | |||
Returns `modelscope.outputs.AttentionFillMaskModelOutput` | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') | |||
>>> # Call the model, return some tensors | |||
>>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) | |||
>>> # Call the pipeline | |||
>>> from modelscope.pipelines import pipeline | |||
>>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) | |||
>>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.ponet( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return AttentionFillMaskModelOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
input_ids=input_ids, | |||
) |
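After the model returns the `AttentionFillMaskModelOutput` above, converting the raw `logits` tensor into filled-in token ids is the pipeline's job. A rough sketch of that post-processing step with dummy tensors; the vocabulary size and the [MASK] token id used here are illustrative assumptions, not values read from the PoNet checkpoint:

import torch

# Stand-ins for the fields of AttentionFillMaskModelOutput returned above:
# logits is (batch_size, sequence_length, vocab_size); input_ids is echoed back
# so the pipeline knows which positions were masked.
vocab_size, mask_token_id = 21128, 103        # assumed values, for illustration only
input_ids = torch.tensor([[101, 2769, 103, 102]])
logits = torch.randn(1, 4, vocab_size)

mask_positions = input_ids == mask_token_id
predicted_ids = logits.argmax(dim=-1)
# keep the original tokens and fill only the [MASK] slots with the top-1 prediction
filled_ids = torch.where(mask_positions, predicted_ids, input_ids)
print(filled_ids)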
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | |||
from transformers.file_utils import PaddingStrategy | |||
from transformers.models.bert.tokenization_bert import BertTokenizer | |||
from transformers.tokenization_utils import BatchEncoding, EncodedInput | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger |
@@ -1,53 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.ponet import \ | |||
PoNetForMaskedLM as PoNetForMaskedLMTransformer | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['PoNetForMaskedLM'] | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) | |||
class PoNetForMaskedLM(TorchModel, PoNetForMaskedLMTransformer): | |||
"""PoNet for MLM model.'. | |||
Inherited from ponet.PoNetForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super(TorchModel, self).__init__(model_dir) | |||
PoNetForMaskedLMTransformer.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None): | |||
output = PoNetForMaskedLMTransformer.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
segment_ids=segment_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
labels=labels) | |||
output[OutputKeys.INPUT_IDS] = input_ids | |||
return output | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model_dir = kwargs.get('model_dir') | |||
return super(PoNetForMaskedLMTransformer, | |||
PoNetForMaskedLM).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, | |||
model_dir=model_dir) |
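The removed `_instantiate` above relies on the two-argument form `super(PoNetForMaskedLMTransformer, PoNetForMaskedLM)`, which starts the attribute lookup *after* `PoNetForMaskedLMTransformer` in the MRO, so the inherited `from_pretrained` is reached without re-entering the wrapper. A tiny, self-contained sketch of that lookup rule with hypothetical classes (no ModelScope code involved):

class Base:
    @classmethod
    def create(cls):
        return f'Base.create -> {cls.__name__}'

class Middle(Base):
    @classmethod
    def create(cls):
        return f'Middle.create -> {cls.__name__}'

class Child(Middle):
    pass

# Two-argument super skips Middle in Child's MRO, so Base.create is resolved,
# while cls is still bound to Child -- the same trick the wrapper above uses
# to reach the transformer base class' from_pretrained from its override.
print(super(Middle, Child).create())   # Base.create -> Child
print(Child.create())                  # Middle.create -> Child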
@@ -1,74 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import numpy as np | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['SentenceEmbedding'] | |||
@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) | |||
class SentenceEmbedding(TorchModel, SbertPreTrainedModel): | |||
base_model_prefix: str = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, model_dir): | |||
super().__init__(model_dir) | |||
self.config = config | |||
setattr(self, self.base_model_prefix, self.build_base_model()) | |||
def build_base_model(self): | |||
from .structbert import SbertModel | |||
return SbertModel(self.config, add_pooling_layer=False) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Any]): the preprocessed data | |||
Returns: | |||
Dict[str, np.ndarray]: results | |||
Example: | |||
{ | |||
'predictions': array([1]), # lable 0-negative 1-positive | |||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
} | |||
""" | |||
return self.base_model(**input) | |||
def postprocess(self, inputs: Dict[str, np.ndarray], | |||
**kwargs) -> Dict[str, np.ndarray]: | |||
embs = inputs['last_hidden_state'][:, 0].cpu().numpy() | |||
num_sent = embs.shape[0] | |||
if num_sent >= 2: | |||
scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], | |||
(1, 0))).tolist()[0] | |||
else: | |||
scores = [] | |||
result = {'text_embedding': embs, 'scores': scores} | |||
return result | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
@param kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_args = {} | |||
return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained( | |||
pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
model_dir=kwargs.get('model_dir'), | |||
**model_args) |
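The deleted `postprocess` above treats the first sentence's [CLS] embedding as the query and scores it against the remaining embeddings with a plain dot product. A small numpy sketch of that scoring convention, with random vectors standing in for real `last_hidden_state[:, 0]` embeddings:

import numpy as np

# embs: one row per sentence, row 0 is the query, rows 1..n are candidates
embs = np.random.rand(4, 768).astype(np.float32)

num_sent = embs.shape[0]
if num_sent >= 2:
    # (1, 768) @ (768, n-1) -> one similarity score per candidate sentence
    scores = np.dot(embs[0:1, :], np.transpose(embs[1:, :], (1, 0))).tolist()[0]
else:
    scores = []
print({'text_embedding': embs.shape, 'scores': scores})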
@@ -1,287 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from abc import abstractmethod | |||
from torch import nn | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.bert import BertPreTrainedModel | |||
from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
from modelscope.models.nlp.veco import \ | |||
VecoForSequenceClassification as VecoForSequenceClassificationTransform | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import parse_label_mapping | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
__all__ = [ | |||
'SbertForSequenceClassification', 'VecoForSequenceClassification', | |||
'BertForSequenceClassification' | |||
] | |||
class SequenceClassificationBase(TorchModel): | |||
"""A sequence classification base class for all the fitted sequence classification models. | |||
""" | |||
base_model_prefix: str = 'bert' | |||
def __init__(self, config, model_dir): | |||
super().__init__(model_dir) | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
setattr(self, self.base_model_prefix, self.build_base_model()) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
@abstractmethod | |||
def build_base_model(self): | |||
"""Build the backbone model. | |||
Returns: the backbone instance. | |||
""" | |||
pass | |||
@property | |||
def base_model(self): | |||
return getattr(self, self.base_model_prefix) | |||
def forward(self, **kwargs): | |||
labels = None | |||
if OutputKeys.LABEL in kwargs: | |||
labels = kwargs.pop(OutputKeys.LABEL) | |||
elif OutputKeys.LABELS in kwargs: | |||
labels = kwargs.pop(OutputKeys.LABELS) | |||
outputs = self.base_model.forward(**kwargs) | |||
# backbone model should return pooled_output as its second output | |||
pooled_output = outputs[1] | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
if labels is not None: | |||
loss_fct = nn.CrossEntropyLoss() | |||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss} | |||
return {OutputKeys.LOGITS: logits} | |||
def postprocess(self, input, **kwargs): | |||
logits = input[OutputKeys.LOGITS] | |||
probs = torch_nested_numpify(torch_nested_detach(logits.softmax(-1))) | |||
pred = torch_nested_numpify(torch_nested_detach(logits.argmax(-1))) | |||
logits = torch_nested_numpify(torch_nested_detach(logits)) | |||
res = { | |||
OutputKeys.PREDICTIONS: pred, | |||
OutputKeys.PROBABILITIES: probs, | |||
OutputKeys.LOGITS: logits | |||
} | |||
return res | |||
@MODELS.register_module( | |||
Tasks.sentence_similarity, module_name=Models.structbert) | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=Models.structbert) | |||
@MODELS.register_module(Tasks.nli, module_name=Models.structbert) | |||
@MODELS.register_module( | |||
Tasks.zero_shot_classification, module_name=Models.structbert) | |||
class SbertForSequenceClassification(SequenceClassificationBase, | |||
SbertPreTrainedModel): | |||
"""Sbert sequence classification model. | |||
Inherited from SequenceClassificationBase. | |||
""" | |||
base_model_prefix: str = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, model_dir): | |||
if hasattr(config, 'base_model_prefix'): | |||
SbertForSequenceClassification.base_model_prefix = config.base_model_prefix | |||
super().__init__(config, model_dir) | |||
def build_base_model(self): | |||
from .structbert import SbertModel | |||
return SbertModel(self.config, add_pooling_layer=True) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
labels=None, | |||
**kwargs): | |||
return super().forward( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
labels=labels) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
@param kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg to tell the model how many classes to initialize. | |||
Method will call utils.parse_label_mapping if num_labels not supplied. | |||
If num_labels is not found, the model will use the default setting (2 classes). | |||
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.get('model_dir') | |||
num_labels = kwargs.get('num_labels') | |||
if num_labels is None: | |||
label2id = parse_label_mapping(model_dir) | |||
if label2id is not None and len(label2id) > 0: | |||
num_labels = len(label2id) | |||
cls.id2label = {id: label for label, id in label2id.items()} | |||
model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
return super(SbertPreTrainedModel, | |||
SbertForSequenceClassification).from_pretrained( | |||
pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
model_dir=kwargs.get('model_dir'), | |||
**model_args) | |||
@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.veco) | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=Models.veco) | |||
@MODELS.register_module(Tasks.nli, module_name=Models.veco) | |||
class VecoForSequenceClassification(TorchModel, | |||
VecoForSequenceClassificationTransform): | |||
"""Veco sequence classification model. | |||
Inherited from VecoForSequenceClassification and TorchModel, so this class can be registered into the model set. | |||
This model cannot be inherited from SequenceClassificationBase, because Veco/XlmRoberta's classification structure | |||
is different. | |||
""" | |||
def __init__(self, config, model_dir): | |||
super().__init__(model_dir) | |||
VecoForSequenceClassificationTransform.__init__(self, config) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
**kwargs): | |||
return VecoForSequenceClassificationTransform.forward( | |||
self, | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
labels=labels) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
@param kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg to tell the model how many classes to initialize. | |||
Method will call utils.parse_label_mapping if num_labels not supplied. | |||
If num_labels is not found, the model will use the default setting (2 classes). | |||
@return: The loaded model, which is initialized by veco.VecoForSequenceClassification.from_pretrained | |||
""" | |||
model_dir = kwargs.get('model_dir') | |||
num_labels = kwargs.get('num_labels') | |||
if num_labels is None: | |||
label2id = parse_label_mapping(model_dir) | |||
if label2id is not None and len(label2id) > 0: | |||
num_labels = len(label2id) | |||
model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
return super(VecoForSequenceClassificationTransform, | |||
VecoForSequenceClassification).from_pretrained( | |||
pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
model_dir=kwargs.get('model_dir'), | |||
**model_args) | |||
@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert) | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.nli, module_name=Models.bert) | |||
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert) | |||
class BertForSequenceClassification(SequenceClassificationBase, | |||
BertPreTrainedModel): | |||
"""Bert sequence classification model. | |||
Inherited from SequenceClassificationBase. | |||
""" | |||
base_model_prefix: str = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, model_dir): | |||
if hasattr(config, 'base_model_prefix'): | |||
BertForSequenceClassification.base_model_prefix = config.base_model_prefix | |||
super().__init__(config, model_dir) | |||
def build_base_model(self): | |||
from .bert import BertModel | |||
return BertModel(self.config, add_pooling_layer=True) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
return super().forward( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
labels=labels, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
@param kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg to tell the model how many classes to initialize. | |||
Method will call utils.parse_label_mapping if num_labels not supplied. | |||
If num_labels is not found, the model will use the default setting (2 classes). | |||
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.get('model_dir') | |||
num_labels = kwargs.get('num_labels') | |||
if num_labels is None: | |||
label2id = parse_label_mapping(model_dir) | |||
if label2id is not None and len(label2id) > 0: | |||
num_labels = len(label2id) | |||
model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
return super(BertPreTrainedModel, | |||
BertForSequenceClassification).from_pretrained( | |||
pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
model_dir=kwargs.get('model_dir'), | |||
**model_args) |
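All three removed `_instantiate` hooks above resolve `num_labels` the same way: when the caller does not pass it, the label mapping shipped with the checkpoint decides the width of the classifier head. A short sketch of that resolution logic, with a hypothetical `label2id` standing in for whatever `parse_label_mapping(model_dir)` would return:

# Hypothetical mapping; in the real code it comes from parse_label_mapping(model_dir).
label2id = {'negative': 0, 'positive': 1}

num_labels = None                      # i.e. not supplied by the caller
id2label = {}
if num_labels is None and label2id:
    num_labels = len(label2id)
    id2label = {idx: label for label, idx in label2id.items()}

model_args = {} if num_labels is None else {'num_labels': num_labels}
print(num_labels, id2label, model_args)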
@@ -1,20 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .model import SpaceGenerator | |||
from .model import SpaceModelBase, SpaceTokenizer, SpaceConfig | |||
from .space_for_dialog_intent_prediction import SpaceForDialogIntent | |||
from .space_for_dialog_modeling import SpaceForDialogModeling | |||
from .space_for_dialog_state_tracking import SpaceForDialogStateTracking | |||
from .model import SpaceModelBase, SpaceTokenizer | |||
from .dialog_intent_prediction import SpaceForDialogIntent | |||
from .dialog_modeling import SpaceForDialogModeling | |||
from .dialog_state_tracking import SpaceForDST | |||
from .configuration import SpaceConfig | |||
else: | |||
_import_structure = { | |||
'model': | |||
['SpaceGenerator', 'SpaceModelBase', 'SpaceTokenizer', 'SpaceConfig'], | |||
'space_for_dialog_intent_prediction': ['SpaceForDialogIntent'], | |||
'space_for_dialog_modeling': ['SpaceForDialogModeling'], | |||
'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'], | |||
'model': ['SpaceGenerator', 'SpaceModelBase', 'SpaceTokenizer'], | |||
'dialog_intent_prediction': ['SpaceForDialogIntent'], | |||
'dialog_modeling': ['SpaceForDialogModeling'], | |||
'dialog_state_tracking': ['SpaceForDST'], | |||
'configuration': ['SpaceConfig'] | |||
} | |||
import sys | |||
@@ -8,7 +8,7 @@ from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase | |||
from modelscope.preprocessors.space import IntentBPETextField | |||
from modelscope.preprocessors.nlp import IntentBPETextField | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@@ -24,6 +24,10 @@ class SpaceForDialogIntent(TorchModel): | |||
Args: | |||
model_dir (str): the model path. | |||
text_field (`BPETextField`, *optional*, defaults to `IntentBPETextField`): | |||
The text field. | |||
config (`Config`, *optional*, defaults to config in model hub): | |||
The config. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
@@ -72,10 +76,21 @@ class SpaceForDialogIntent(TorchModel): | |||
Example: | |||
{ | |||
'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 | |||
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 | |||
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 | |||
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) | |||
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 | |||
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 | |||
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32), | |||
} | |||
Example: | |||
>>> from modelscope.hub.snapshot_download import snapshot_download | |||
>>> from modelscope.models.nlp import SpaceForDialogIntent | |||
>>> from modelscope.preprocessors import DialogIntentPredictionPreprocessor | |||
>>> cache_path = snapshot_download('damo/nlp_space_dialog-intent-prediction') | |||
>>> preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||
>>> model = SpaceForDialogIntent( | |||
model_dir=cache_path, | |||
text_field=preprocessor.text_field, | |||
config=preprocessor.config) | |||
>>> print(model(preprocessor("What do I need to do for the card activation?"))) | |||
""" | |||
import numpy as np | |||
pred = self.trainer.forward(input) |
@@ -8,7 +8,7 @@ from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase | |||
from modelscope.preprocessors.space import MultiWOZBPETextField | |||
from modelscope.preprocessors.nlp import MultiWOZBPETextField | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@@ -23,7 +23,12 @@ class SpaceForDialogModeling(TorchModel): | |||
"""initialize the test generation model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
model_dir (`str`): | |||
The model path. | |||
text_field (`BPETextField`, *optional*, defaults to `MultiWOZBPETextField`): | |||
The text field. | |||
config (`Config`, *optional*, defaults to config in model hub): | |||
The config. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
@@ -82,6 +87,19 @@ class SpaceForDialogModeling(TorchModel): | |||
'aspn': array([47,8345,32,29,1983]), | |||
'db': array([19, 24, 20]), | |||
} | |||
Examples: | |||
>>> from modelscope.hub.snapshot_download import snapshot_download | |||
>>> from modelscope.models.nlp import SpaceForDialogModeling | |||
>>> from modelscope.preprocessors import DialogModelingPreprocessor | |||
>>> cache_path = snapshot_download('damo/nlp_space_dialog-modeling') | |||
>>> preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | |||
>>> model = SpaceForDialogModeling(model_dir=cache_path, | |||
text_field=preprocessor.text_field, | |||
config=preprocessor.config) | |||
>>> print(model(preprocessor({ | |||
'user_input': 'i would like a taxi from saint john \'s college to pizza hut fen ditton .', | |||
'history': {} | |||
}))) | |||
""" | |||
first_turn = input['first_turn'] |
@@ -1,6 +1,6 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
@@ -16,14 +16,22 @@ | |||
# limitations under the License. | |||
"""PyTorch Space model. mainly copied from :module:`~transformers.modeling_xlm_roberta`""" | |||
from typing import Dict | |||
import torch | |||
from torch import nn | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.file_utils import add_start_docstrings | |||
from transformers.modeling_utils import PreTrainedModel | |||
from modelscope.models.nlp.structbert.modeling_sbert import ( | |||
SbertForMaskedLM, SbertModel, SbertPreTrainedModel) | |||
from .configuration_space import SpaceConfig | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.structbert import (SbertForMaskedLM, SbertModel, | |||
SbertPreTrainedModel) | |||
from modelscope.utils.constant import Tasks | |||
from .configuration import SpaceConfig | |||
SPACE_START_DOCSTRING = r""" | |||
@@ -57,6 +65,63 @@ class SpaceModel(SbertModel): | |||
config_class = SpaceConfig | |||
class SpacePreTrainedModel(TorchModel, PreTrainedModel): | |||
""" | |||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |||
models. | |||
""" | |||
config_class = SpaceConfig | |||
base_model_prefix = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config.name_or_path, **kwargs) | |||
super(Model, self).__init__(config) | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.bias is not None: | |||
module.bias.data.zero_() | |||
elif isinstance(module, nn.Embedding): | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.padding_idx is not None: | |||
module.weight.data[module.padding_idx].zero_() | |||
elif isinstance(module, nn.LayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
@param kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg to tell the model how many classes to initialize. | |||
Method will call utils.parse_label_mapping if num_labels is not input. | |||
label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). | |||
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.pop('model_dir', None) | |||
if model_dir is None: | |||
config = SpaceConfig(**kwargs) | |||
model = cls(config) | |||
else: | |||
model_kwargs = {} | |||
model = super(Model, cls).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, **model_kwargs) | |||
return model | |||
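# A rough usage sketch for the `_instantiate` hook above. It is hedged: the hook is
# normally reached through snapshot_download/from_pretrained as in the example further
# below, and the config-only branch is shown with an assumed kwarg for illustration.
#
# >>> from modelscope.hub.snapshot_download import snapshot_download
# >>> cache_path = snapshot_download('damo/nlp_space_dialog-state-tracking')
# >>> model = SpaceForDST.from_pretrained(cache_path)      # model_dir branch: weights loaded
# >>> blank = SpaceForDST._instantiate(hidden_size=768)    # no model_dir: SpaceConfig(**kwargs), random init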
@add_start_docstrings( | |||
""" | |||
Space Model transformer with Dialog state tracking heads on top (an inform projection | |||
@@ -65,7 +130,9 @@ class SpaceModel(SbertModel): | |||
""", | |||
SPACE_START_DOCSTRING, | |||
) | |||
class SpaceForDST(SbertPreTrainedModel): | |||
@MODELS.register_module( | |||
Tasks.task_oriented_conversation, module_name=Models.space_dst) | |||
class SpaceForDST(SpacePreTrainedModel): | |||
def __init__(self, config): | |||
super(SpaceForDST, self).__init__(config) | |||
@@ -113,18 +180,105 @@ class SpaceForDST(SbertPreTrainedModel): | |||
self.init_weights() | |||
def forward(self, | |||
input_ids, | |||
input_mask=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
start_pos=None, | |||
end_pos=None, | |||
inform_slot_id=None, | |||
refer_id=None, | |||
class_label_id=None, | |||
diag_state=None): | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data | |||
Returns: | |||
Dict[str, Tensor]: results | |||
Example: | |||
{ | |||
'inputs': dict(input_ids, input_masks, start_pos), # tracking states | |||
'outputs': dict(slots_logits), | |||
'unique_ids': str(test-example.json-0), # default value | |||
'input_ids_unmasked': array([101, 7632, 1010, 0, 0, 0]), | |||
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
'prefix': str('final'), #default value | |||
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) | |||
} | |||
Example: | |||
>>> from modelscope.hub.snapshot_download import snapshot_download | |||
>>> from modelscope.models.nlp import SpaceForDST | |||
>>> from modelscope.preprocessors import DialogStateTrackingPreprocessor | |||
>>> cache_path = snapshot_download('damo/nlp_space_dialog-state-tracking') | |||
>>> model = SpaceForDST.from_pretrained(cache_path) | |||
>>> preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||
>>> print(model(preprocessor({ | |||
'utter': { | |||
'User-1': "Hi, I'm looking for a train that is going" | |||
"to cambridge and arriving there by 20:45, is there anything like that?" | |||
}, | |||
'history_states': [{}] | |||
}))) | |||
""" | |||
import numpy as np | |||
import torch | |||
# self.model.eval() ???? | |||
batch = input['batch'] | |||
features = input['features'] | |||
diag_state = input['diag_state'] | |||
turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] | |||
reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] | |||
for slot in self.config.dst_slot_list: | |||
for i in reset_diag_state: | |||
diag_state[slot][i] = 0 | |||
with torch.no_grad(): | |||
inputs = { | |||
'input_ids': batch[0], | |||
'input_mask': batch[1], | |||
'segment_ids': batch[2], | |||
'start_pos': batch[3], | |||
'end_pos': batch[4], | |||
'inform_slot_id': batch[5], | |||
'refer_id': batch[6], | |||
'diag_state': diag_state, | |||
'class_label_id': batch[8] | |||
} | |||
unique_ids = [features[i.item()].guid for i in batch[9]] | |||
values = [features[i.item()].values for i in batch[9]] | |||
input_ids_unmasked = [ | |||
features[i.item()].input_ids_unmasked for i in batch[9] | |||
] | |||
inform = [features[i.item()].inform for i in batch[9]] | |||
outputs = self._forward(**inputs) | |||
# Update dialog state for next turn. | |||
for slot in self.config.dst_slot_list: | |||
updates = outputs[2][slot].max(1)[1] | |||
for i, u in enumerate(updates): | |||
if u != 0: | |||
diag_state[slot][i] = u | |||
return { | |||
'inputs': inputs, | |||
'outputs': outputs, | |||
'unique_ids': unique_ids, | |||
'input_ids_unmasked': input_ids_unmasked, | |||
'values': values, | |||
'inform': inform, | |||
'prefix': 'final', | |||
'ds': input['ds'] | |||
} | |||
def _forward(self, | |||
input_ids, | |||
input_mask=None, | |||
segment_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
start_pos=None, | |||
end_pos=None, | |||
inform_slot_id=None, | |||
refer_id=None, | |||
class_label_id=None, | |||
diag_state=None): | |||
outputs = self.bert( | |||
input_ids, | |||
attention_mask=input_mask, | |||
@@ -132,8 +286,8 @@ class SpaceForDST(SbertPreTrainedModel): | |||
position_ids=position_ids, | |||
head_mask=head_mask) | |||
sequence_output = outputs[0] | |||
pooled_output = outputs[1] | |||
sequence_output = outputs.last_hidden_state | |||
pooled_output = outputs.pooler_output | |||
sequence_output = self.dropout(sequence_output) | |||
pooled_output = self.dropout(pooled_output) | |||
@@ -233,36 +387,6 @@ class SpaceForDST(SbertPreTrainedModel): | |||
per_slot_start_logits, | |||
per_slot_end_logits, | |||
per_slot_refer_logits, | |||
) + outputs[2:] | |||
) + (outputs.embedding_output, ) | |||
return outputs | |||
@add_start_docstrings( | |||
'The Space Model Model with a `language modeling` head on tops', | |||
SPACE_START_DOCSTRING, | |||
) | |||
class SpaceForMaskedLM(SbertForMaskedLM): | |||
""" | |||
This class overrides [`SbertForMaskedLM`]. Please check the superclass for the | |||
appropriate documentation alongside usage examples. | |||
""" | |||
config_class = SpaceConfig | |||
@add_start_docstrings( | |||
""" | |||
Space Model with only one head on top as done during the pretraining: a `masked language modeling` head. | |||
""", | |||
SPACE_START_DOCSTRING, | |||
) | |||
class SpaceForPreTraining(SbertPreTrainedModel): | |||
def __init__(self, model_name_or_path: str): | |||
super(SpaceForPreTraining, self).__init__() | |||
self.bert_model = SpaceForMaskedLM.from_pretrained(model_name_or_path) | |||
def forward(self, input_ids: torch.tensor, mlm_labels: torch.tensor): | |||
outputs = self.bert_model(input_ids, masked_lm_labels=mlm_labels) | |||
return outputs[0] |
@@ -1,10 +1,8 @@ | |||
from .configuration_space import SpaceConfig | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .gen_unified_transformer import GenUnifiedTransformer | |||
from .generator import SpaceGenerator | |||
from .intent_unified_transformer import IntentUnifiedTransformer | |||
from .model_base import SpaceModelBase | |||
from .modeling_space import (SpaceForDST, SpaceForMaskedLM, | |||
SpaceForPreTraining, SpaceModel) | |||
from .tokenization_space import (BasicTokenizer, SpaceTokenizer, | |||
WordpieceTokenizer) | |||
from .unified_transformer import UnifiedTransformer |
@@ -71,14 +71,11 @@ class SpaceGenerator(object): | |||
return | |||
def __call__(self, step_fn, state): | |||
""" | |||
Running generation. | |||
@param : step_fn : decoding one step | |||
@type : function | |||
"""Running generation. | |||
@param : state : initial state | |||
@type : dict | |||
Args: | |||
step_fn (`function`) : decoding one step | |||
state(`dict`) : initial state | |||
""" | |||
raise NotImplementedError | |||
@@ -104,11 +101,9 @@ class BeamSearch(SpaceGenerator): | |||
""" | |||
Running beam search. | |||
@param : step_fn : decoding one step | |||
@type : function | |||
@param : state : initial state | |||
@type : dict | |||
Args: | |||
step_fn(`function`) : decoding one step | |||
state(`dict`) : initial state | |||
""" | |||
if prev_input is not None: | |||
@@ -64,8 +64,8 @@ class SpaceModelBase(nn.Module): | |||
""" | |||
Forward process, include real forward, collect metrices and optimize(optional) | |||
@params : inputs : input data | |||
@type : dict of numpy.ndarray/int/float/... | |||
Args: | |||
inputs(`dict` of numpy.ndarray/int/float/...) : input data | |||
""" | |||
if is_training: | |||
self.train() | |||
@@ -85,11 +85,10 @@ class SpaceModelBase(nn.Module): | |||
eos_id=None, | |||
max_gen_len=None, | |||
prev_input=None): | |||
""" | |||
Inference process. | |||
"""Inference process. | |||
@params : inputs : input data | |||
@type : dict of numpy.ndarray/int/float/... | |||
Args: | |||
inputs(`dict` of numpy.ndarray/int/float/...) : input data | |||
""" | |||
self.eval() | |||
results = self._infer( | |||
@@ -1,5 +1,5 @@ | |||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
@@ -119,15 +119,12 @@ class UnifiedTransformer(SpaceModelBase): | |||
input_mask, | |||
append_head=False, | |||
auto_regressive=False): | |||
""" | |||
Create attention mask. | |||
"""Create attention mask. | |||
from sequence to matrix:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||
@param : input_mask | |||
@type : Variable(shape: [batch_size, max_seq_len]) | |||
@param : auto_regressive | |||
@type : bool | |||
Args: | |||
input_mask (Variable(shape: [batch_size, max_seq_len])) | |||
auto_regressive(bool) | |||
""" | |||
seq_len = input_mask.shape[1] | |||
@@ -150,15 +147,12 @@ class UnifiedTransformer(SpaceModelBase): | |||
return mask | |||
def _join_mask(self, mask1, mask2): | |||
""" | |||
Merge source attention mask and target attention mask. | |||
"""Merge source attention mask and target attention mask. | |||
There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb) | |||
@param : mask1 : source attention mask | |||
@type : Variable(shape: [batch_size, max_src_len, max_src_len]) | |||
@param : mask1 : target attention mask | |||
@type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) | |||
Args: | |||
mask1(Variable(shape: [batch_size, max_src_len, max_src_len])) : source attention mask | |||
mask2(Variable(shape: [batch_size, max_tgt_len, max_tgt_len])) : target attention mask | |||
""" | |||
batch_size = mask1.shape[0] | |||
seq_len1 = mask1.shape[1] | |||
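The two helpers above expand a per-token padding mask into a square attention matrix and then stitch the source and target masks together. Below is a compact torch sketch of the first step under the documented shapes (`[batch_size, max_seq_len]` in, `[batch_size, max_seq_len, max_seq_len]` out); the auto-regressive branch is illustrated with a lower-triangular matrix, an assumption about the detail rather than a copy of the real `_create_mask`:

import torch

def create_mask(input_mask, auto_regressive=False):
    """input_mask: [batch_size, max_seq_len], 1 for real tokens, 0 for padding."""
    input_mask = input_mask.float()
    seq_len = input_mask.shape[1]
    # outer product: position i may attend to position j only if both are real tokens
    mask = input_mask.unsqueeze(-1) * input_mask.unsqueeze(1)   # [b, L, L]
    if auto_regressive:
        # additionally forbid attending to future positions
        causal = torch.tril(torch.ones(seq_len, seq_len))
        mask = mask * causal
    return mask

pad_mask = torch.tensor([[1, 1, 1, 0]])
print(create_mask(pad_mask).shape)                      # torch.Size([1, 4, 4])
print(create_mask(pad_mask, auto_regressive=True)[0])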
@@ -30,18 +30,13 @@ class TransformerBlock(nn.Module): | |||
return | |||
def forward(self, inp, mask=None, cache=None): | |||
""" | |||
Forward process on one transformer layer. | |||
@param : x | |||
@type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
@param : memory | |||
@type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
@param : mask | |||
"""Forward process on one transformer layer. | |||
@param : cache | |||
Args: | |||
x(Variable(shape: [batch_size, seq_len, hidden_size])) | |||
memory(Variable(shape: [batch_size, seq_len, hidden_size])) | |||
mask | |||
cache | |||
""" | |||
attn_out = self.attn(inp, mask, cache) | |||
attn_out = self.dropout_layer(attn_out) | |||
@@ -1,101 +0,0 @@ | |||
from typing import Dict | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['SpaceForDialogStateTracking'] | |||
@MODELS.register_module( | |||
Tasks.task_oriented_conversation, module_name=Models.space_dst) | |||
class SpaceForDialogStateTracking(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the test generation model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from modelscope.models.nlp.space.model import SpaceForDST, SpaceConfig | |||
self.model_dir = model_dir | |||
self.config = SpaceConfig.from_pretrained(self.model_dir) | |||
self.model = SpaceForDST.from_pretrained(self.model_dir) | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data | |||
Returns: | |||
Dict[str, Tensor]: results | |||
Example: | |||
{ | |||
'inputs': dict(input_ids, input_masks,start_pos), # tracking states | |||
'outputs': dict(slots_logits), | |||
'unique_ids': str(test-example.json-0), # default value | |||
'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) | |||
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
'prefix': str('final'), #default value | |||
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) | |||
} | |||
""" | |||
import numpy as np | |||
import torch | |||
self.model.eval() | |||
batch = input['batch'] | |||
features = input['features'] | |||
diag_state = input['diag_state'] | |||
turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] | |||
reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] | |||
for slot in self.config.dst_slot_list: | |||
for i in reset_diag_state: | |||
diag_state[slot][i] = 0 | |||
with torch.no_grad(): | |||
inputs = { | |||
'input_ids': batch[0], | |||
'input_mask': batch[1], | |||
'segment_ids': batch[2], | |||
'start_pos': batch[3], | |||
'end_pos': batch[4], | |||
'inform_slot_id': batch[5], | |||
'refer_id': batch[6], | |||
'diag_state': diag_state, | |||
'class_label_id': batch[8] | |||
} | |||
unique_ids = [features[i.item()].guid for i in batch[9]] | |||
values = [features[i.item()].values for i in batch[9]] | |||
input_ids_unmasked = [ | |||
features[i.item()].input_ids_unmasked for i in batch[9] | |||
] | |||
inform = [features[i.item()].inform for i in batch[9]] | |||
outputs = self.model(**inputs) | |||
# Update dialog state for next turn. | |||
for slot in self.config.dst_slot_list: | |||
updates = outputs[2][slot].max(1)[1] | |||
for i, u in enumerate(updates): | |||
if u != 0: | |||
diag_state[slot][i] = u | |||
return { | |||
'inputs': inputs, | |||
'outputs': outputs, | |||
'unique_ids': unique_ids, | |||
'input_ids_unmasked': input_ids_unmasked, | |||
'values': values, | |||
'inform': inform, | |||
'prefix': 'final', | |||
'ds': input['ds'] | |||
} |
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .table_question_answering import TableQuestionAnswering | |||
else: | |||
_import_structure = { | |||
'table_question_answering': ['TableQuestionAnswering'] | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -1,6 +1,6 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -27,8 +27,7 @@ import numpy as np | |||
import torch | |||
from torch import nn | |||
from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \ | |||
SpaceTCnConfig | |||
from modelscope.models.nlp.space_T_cn.configuration import SpaceTCnConfig | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
@@ -11,11 +11,11 @@ from transformers import BertTokenizer | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Model, Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.preprocessors.space_T_cn.fields.struct import Constant | |||
from modelscope.preprocessors.nlp.space_T_cn.fields.struct import Constant | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.device import verify_device | |||
from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig | |||
from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel | |||
from .backbone import Seq2SQL, SpaceTCnModel | |||
from .configuration import SpaceTCnConfig | |||
__all__ = ['TableQuestionAnswering'] | |||
@@ -732,9 +732,41 @@ class TableQuestionAnswering(Model): | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data | |||
Returns: | |||
Dict[str, Tensor]: results | |||
Example: | |||
{ | |||
'result': | |||
{ | |||
'question_tok': ['有', '哪', '些', '风', '险', '类', '型', '?'], | |||
'question': '有哪些风险类型?', | |||
'table_id': 'fund', | |||
'sql': { | |||
'cond_conn_op': 0, | |||
'sel': [5], | |||
'agg': [0], | |||
'conds': [[10, 2, 'Nulll']] | |||
}, | |||
'action': 10, | |||
'model_out': [ | |||
[6, 0, 0, 0], | |||
[0, 0, 0, 0], | |||
[0, 0, 0, 0, 0, 0], | |||
[2, 0, 0, 0, 0, 0], | |||
[0, 0, 0, 0, 0, 0], | |||
[0, 0, 0, 0, 0, 0] | |||
] | |||
}, | |||
'history_sql': None | |||
} | |||
Example: | |||
>>> from modelscope.models.nlp import TableQuestionAnswering | |||
>>> from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
>>> model = TableQuestionAnswering.from_pretrained('damo/nlp_convai_text2sql_pretrain_cn') | |||
>>> preprocessor = TableQuestionAnsweringPreprocessor(model_dir=model.model_dir) | |||
>>> print(model(preprocessor({'question': '有哪些风险类型?'}))) | |||
""" | |||
result = self.predict(input['datas'])[0] | |||
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .text_to_sql import StarForTextToSql | |||
else: | |||
_import_structure = { | |||
'text_to_sql': ['StarForTextToSql'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -4,14 +4,13 @@ import os | |||
from typing import Dict, Optional | |||
import torch | |||
import torch.nn as nn | |||
from text2sql_lgesql.asdl.asdl import ASDLGrammar | |||
from text2sql_lgesql.asdl.transition_system import TransitionSystem | |||
from text2sql_lgesql.model.model_constructor import Text2SQL | |||
from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Model, Tensor | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@@ -21,7 +20,7 @@ __all__ = ['StarForTextToSql'] | |||
@MODELS.register_module( | |||
Tasks.table_question_answering, module_name=Models.space_T_en) | |||
class StarForTextToSql(Model): | |||
class StarForTextToSql(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the star model from the `model_dir` path. | |||
@@ -59,6 +58,33 @@ class StarForTextToSql(Model): | |||
Returns: | |||
Dict[str, Tensor]: results | |||
Example: | |||
>>> from modelscope.hub.snapshot_download import snapshot_download | |||
>>> from modelscope.models.nlp import StarForTextToSql | |||
>>> from modelscope.preprocessors import ConversationalTextToSqlPreprocessor | |||
>>> test_case = { | |||
'database_id': 'employee_hire_evaluation', | |||
'local_db_path': None, | |||
'utterance': [ | |||
"I'd like to see Shop names.", 'Which of these are hiring?', | |||
'Which shop is hiring the highest number of employees?' | |||
' | do you want the name of the shop ? | Yes' | |||
] | |||
} | |||
>>> cache_path = snapshot_download('damo/nlp_star_conversational-text-to-sql') | |||
>>> preprocessor = ConversationalTextToSqlPreprocessor( | |||
model_dir=cache_path, | |||
database_id=test_case['database_id'], | |||
db_content=True) | |||
>>> model = StarForTextToSql(cache_path, config=preprocessor.config) | |||
>>> print(model(preprocessor({ | |||
'utterance': "I'd like to see Shop names.", | |||
'history': [], | |||
'last_sql': '', | |||
'database_id': 'employee_hire_evaluation', | |||
'local_db_path': None | |||
}))) | |||
""" | |||
self.model.eval() | |||
hyps = self.model.parse(input['batch'], self.beam_size) # |
@@ -18,20 +18,26 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_sbert import SbertConfig | |||
from .modeling_sbert import (SbertForMaskedLM, SbertModel, | |||
SbertPreTrainedModel) | |||
from .tokenization_sbert import (BasicTokenizer, SbertTokenizer, | |||
WordpieceTokenizer) | |||
from .tokenization_sbert_fast import SbertTokenizerFast | |||
from .backbone import (SbertModel, SbertPreTrainedModel) | |||
from .configuration import SbertConfig | |||
from .faq_question_answering import SbertForFaqQuestionAnswering | |||
from .fill_mask import SbertForMaskedLM | |||
from .text_classification import SbertForSequenceClassification | |||
from .token_classification import SbertForTokenClassification | |||
from .tokenization import (BasicTokenizer, SbertTokenizer, | |||
WordpieceTokenizer) | |||
from .tokenization_fast import SbertTokenizerFast | |||
else: | |||
_import_structure = { | |||
'configuration_sbert': ['SbertConfig'], | |||
'modeling_sbert': | |||
['SbertForMaskedLM', 'SbertModel', 'SbertPreTrainedModel'], | |||
'tokenization_sbert': | |||
'backbone': ['SbertModel', 'SbertPreTrainedModel'], | |||
'configuration': ['SbertConfig'], | |||
'fill_mask': ['SbertForMaskedLM'], | |||
'faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||
'text_classification': ['SbertForSequenceClassification'], | |||
'token_classification': ['SbertForTokenClassification'], | |||
'tokenization': | |||
['BasicTokenizer', 'SbertTokenizer', 'WordpieceTokenizer'], | |||
'tokenization_sbert_fast': ['SbertTokenizerFast'], | |||
'tokenization_fast': ['SbertTokenizerFast'], | |||
} | |||
import sys | |||
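With the restructured `_import_structure` above, user-facing imports keep working while the files follow the new task-based split; each submodule is only imported when one of its names is first accessed through `LazyImportModule`. A short sketch of imports the new mapping is meant to serve (names taken from the dictionary above; running it of course requires modelscope to be installed):

from modelscope.models.nlp.structbert import (SbertConfig, SbertForMaskedLM,
                                              SbertForSequenceClassification,
                                              SbertModel, SbertTokenizer)

# Accessing the class triggers the lazy import of its submodule; per the mapping
# above this one is expected to resolve to ...structbert.text_classification.
print(SbertForSequenceClassification.__module__)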
@@ -0,0 +1,932 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""PyTorch StructBERT model. mainly copied from :module:`~transformers.modeling_bert`""" | |||
import math | |||
from dataclasses import dataclass | |||
from typing import Optional, Tuple, Union | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from packaging import version | |||
from transformers.activations import ACT2FN | |||
from transformers.modeling_outputs import \ | |||
BaseModelOutputWithPastAndCrossAttentions | |||
from transformers.modeling_utils import (PreTrainedModel, | |||
apply_chunking_to_forward, | |||
find_pruneable_heads_and_indices, | |||
prune_linear_layer) | |||
from modelscope.metainfo import Models | |||
from modelscope.models import Model, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionBackboneModelOutput | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.hub import parse_label_mapping | |||
from modelscope.utils.logger import get_logger | |||
from .configuration import SbertConfig | |||
logger = get_logger(__name__) | |||
class SbertEmbeddings(nn.Module): | |||
"""Construct the embeddings from word, position and token_type embeddings.""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.word_embeddings = nn.Embedding( | |||
config.vocab_size, | |||
config.hidden_size, | |||
padding_idx=config.pad_token_id) | |||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
config.hidden_size) | |||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, | |||
config.hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
# any TensorFlow checkpoint file | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized | |||
self.position_embedding_type = getattr(config, | |||
'position_embedding_type', | |||
'absolute') | |||
self.register_buffer( | |||
'position_ids', | |||
torch.arange(config.max_position_embeddings).expand((1, -1))) | |||
if version.parse(torch.__version__) > version.parse('1.6.0'): | |||
self.register_buffer( | |||
'token_type_ids', | |||
torch.zeros( | |||
self.position_ids.size(), | |||
dtype=torch.long, | |||
device=self.position_ids.device), | |||
persistent=False, | |||
) | |||
def forward(self, | |||
input_ids=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
inputs_embeds=None, | |||
past_key_values_length=0, | |||
return_inputs_embeds=False): | |||
if input_ids is not None: | |||
input_shape = input_ids.size() | |||
else: | |||
input_shape = inputs_embeds.size()[:-1] | |||
seq_length = input_shape[1] | |||
if position_ids is None: | |||
position_ids = self.position_ids[:, | |||
past_key_values_length:seq_length | |||
+ past_key_values_length] | |||
# If token_type_ids is not provided, fall back to the buffer registered in the
# constructor, which is all zeros. This usually happens when token_type_ids are
# auto-generated; the registered buffer lets users trace the model without
# passing token_type_ids and solves issue #5664.
if token_type_ids is None: | |||
if hasattr(self, 'token_type_ids'): | |||
buffered_token_type_ids = self.token_type_ids[:, :seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
input_shape[0], seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, | |||
dtype=torch.long, | |||
device=self.position_ids.device) | |||
if inputs_embeds is None: | |||
inputs_embeds = self.word_embeddings(input_ids) | |||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
embeddings = inputs_embeds + token_type_embeddings | |||
if self.position_embedding_type == 'absolute': | |||
position_embeddings = self.position_embeddings(position_ids) | |||
embeddings += position_embeddings | |||
embeddings = self.LayerNorm(embeddings) | |||
embeddings = self.dropout(embeddings) | |||
if not return_inputs_embeds: | |||
return embeddings | |||
else: | |||
return embeddings, inputs_embeds | |||
class SbertSelfAttention(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr( | |||
config, 'embedding_size'): | |||
raise ValueError( | |||
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' | |||
f'heads ({config.num_attention_heads})') | |||
self.num_attention_heads = config.num_attention_heads | |||
self.attention_head_size = int(config.hidden_size | |||
/ config.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.key = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.value = nn.Linear(config.hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |||
self.position_embedding_type = getattr(config, | |||
'position_embedding_type', | |||
'absolute') | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
self.max_position_embeddings = config.max_position_embeddings | |||
self.distance_embedding = nn.Embedding( | |||
2 * config.max_position_embeddings - 1, | |||
self.attention_head_size) | |||
self.is_decoder = config.is_decoder | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, | |||
self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
mixed_query_layer = self.query(hidden_states) | |||
# If this is instantiated as a cross-attention module, the keys | |||
# and values come from an encoder; the attention mask needs to be | |||
# such that the encoder's padding tokens are not attended to. | |||
is_cross_attention = encoder_hidden_states is not None | |||
if is_cross_attention and past_key_value is not None: | |||
# reuse k,v, cross_attentions | |||
key_layer = past_key_value[0] | |||
value_layer = past_key_value[1] | |||
attention_mask = encoder_attention_mask | |||
elif is_cross_attention: | |||
key_layer = self.transpose_for_scores( | |||
self.key(encoder_hidden_states)) | |||
value_layer = self.transpose_for_scores( | |||
self.value(encoder_hidden_states)) | |||
attention_mask = encoder_attention_mask | |||
elif past_key_value is not None: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |||
else: | |||
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |||
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
if self.is_decoder: | |||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | |||
# Further calls to cross_attention layer can then reuse all cross-attention | |||
# key/value_states (first "if" case) | |||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | |||
# all previous decoder key/value_states. Further calls to uni-directional self-attention | |||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | |||
# if encoder bi-directional self-attention `past_key_value` is always `None` | |||
past_key_value = (key_layer, value_layer) | |||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||
attention_scores = torch.matmul(query_layer, | |||
key_layer.transpose(-1, -2)) | |||
if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': | |||
seq_length = hidden_states.size()[1] | |||
position_ids_l = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(-1, 1) | |||
position_ids_r = torch.arange( | |||
seq_length, dtype=torch.long, | |||
device=hidden_states.device).view(1, -1) | |||
distance = position_ids_l - position_ids_r | |||
positional_embedding = self.distance_embedding( | |||
distance + self.max_position_embeddings - 1) | |||
positional_embedding = positional_embedding.to( | |||
dtype=query_layer.dtype) # fp16 compatibility | |||
if self.position_embedding_type == 'relative_key': | |||
relative_position_scores = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores | |||
elif self.position_embedding_type == 'relative_key_query': | |||
relative_position_scores_query = torch.einsum( | |||
'bhld,lrd->bhlr', query_layer, positional_embedding) | |||
relative_position_scores_key = torch.einsum( | |||
'bhrd,lrd->bhlr', key_layer, positional_embedding) | |||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |||
attention_scores = attention_scores / math.sqrt( | |||
self.attention_head_size) | |||
if attention_mask is not None: | |||
# Apply the attention mask is (precomputed for all layers in SbertModel forward() function) | |||
attention_scores = attention_scores + attention_mask | |||
# Normalize the attention scores to probabilities. | |||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.dropout(attention_probs) | |||
# Mask heads if we want to | |||
if head_mask is not None: | |||
attention_probs = attention_probs * head_mask | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + ( | |||
self.all_head_size, ) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
outputs = (context_layer, | |||
attention_probs) if output_attentions else (context_layer, ) | |||
if self.is_decoder: | |||
outputs = outputs + (past_key_value, ) | |||
return outputs | |||
class SbertSelfOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class SbertAttention(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.self = SbertSelfAttention(config) | |||
self.output = SbertSelfOutput(config) | |||
self.pruned_heads = set() | |||
def prune_heads(self, heads): | |||
if len(heads) == 0: | |||
return | |||
heads, index = find_pruneable_heads_and_indices( | |||
heads, self.self.num_attention_heads, | |||
self.self.attention_head_size, self.pruned_heads) | |||
# Prune linear layers | |||
self.self.query = prune_linear_layer(self.self.query, index) | |||
self.self.key = prune_linear_layer(self.self.key, index) | |||
self.self.value = prune_linear_layer(self.self.value, index) | |||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) | |||
# Update hyper params and store pruned heads | |||
self.self.num_attention_heads = self.self.num_attention_heads - len( | |||
heads) | |||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads | |||
self.pruned_heads = self.pruned_heads.union(heads) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
self_outputs = self.self( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = self.output(self_outputs[0], hidden_states) | |||
outputs = (attention_output, | |||
) + self_outputs[1:] # add attentions if we output them | |||
return outputs | |||
class SbertIntermediate(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |||
if isinstance(config.hidden_act, str): | |||
self.intermediate_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.intermediate_act_fn = config.hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.intermediate_act_fn(hidden_states) | |||
return hidden_states | |||
class SbertOutput(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class SbertLayer(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |||
self.seq_len_dim = 1 | |||
self.attention = SbertAttention(config) | |||
self.is_decoder = config.is_decoder | |||
self.add_cross_attention = config.add_cross_attention | |||
if self.add_cross_attention: | |||
if not self.is_decoder: | |||
raise ValueError( | |||
f'{self} should be used as a decoder model if cross attention is added' | |||
) | |||
self.crossattention = SbertAttention(config) | |||
self.intermediate = SbertIntermediate(config) | |||
self.output = SbertOutput(config) | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_value=None, | |||
output_attentions=False, | |||
): | |||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 | |||
self_attn_past_key_value = past_key_value[: | |||
2] if past_key_value is not None else None | |||
self_attention_outputs = self.attention( | |||
hidden_states, | |||
attention_mask, | |||
head_mask, | |||
output_attentions=output_attentions, | |||
past_key_value=self_attn_past_key_value, | |||
) | |||
attention_output = self_attention_outputs[0] | |||
# if decoder, the last output is tuple of self-attn cache | |||
if self.is_decoder: | |||
outputs = self_attention_outputs[1:-1] | |||
present_key_value = self_attention_outputs[-1] | |||
else: | |||
outputs = self_attention_outputs[ | |||
1:] # add self attentions if we output attention weights | |||
cross_attn_present_key_value = None | |||
if self.is_decoder and encoder_hidden_states is not None: | |||
if not hasattr(self, 'crossattention'): | |||
raise ValueError( | |||
f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention ' | |||
f'layers by setting `config.add_cross_attention=True`') | |||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple | |||
cross_attn_past_key_value = past_key_value[ | |||
-2:] if past_key_value is not None else None | |||
cross_attention_outputs = self.crossattention( | |||
attention_output, | |||
attention_mask, | |||
head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
cross_attn_past_key_value, | |||
output_attentions, | |||
) | |||
attention_output = cross_attention_outputs[0] | |||
outputs = outputs + cross_attention_outputs[ | |||
1:-1] # add cross attentions if we output attention weights | |||
# add cross-attn cache to positions 3,4 of present_key_value tuple | |||
cross_attn_present_key_value = cross_attention_outputs[-1] | |||
present_key_value = present_key_value + cross_attn_present_key_value | |||
layer_output = apply_chunking_to_forward(self.feed_forward_chunk, | |||
self.chunk_size_feed_forward, | |||
self.seq_len_dim, | |||
attention_output) | |||
outputs = (layer_output, ) + outputs | |||
# if decoder, return the attn key/values as the last output | |||
if self.is_decoder: | |||
outputs = outputs + (present_key_value, ) | |||
return outputs | |||
def feed_forward_chunk(self, attention_output): | |||
intermediate_output = self.intermediate(attention_output) | |||
layer_output = self.output(intermediate_output, attention_output) | |||
return layer_output | |||
class SbertEncoder(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
self.layer = nn.ModuleList( | |||
[SbertLayer(config) for _ in range(config.num_hidden_layers)]) | |||
self.gradient_checkpointing = False | |||
def forward( | |||
self, | |||
hidden_states, | |||
attention_mask=None, | |||
head_mask=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=False, | |||
output_hidden_states=False, | |||
return_dict=True, | |||
): | |||
all_hidden_states = () if output_hidden_states else None | |||
all_self_attentions = () if output_attentions else None | |||
all_cross_attentions = ( | |||
) if output_attentions and self.config.add_cross_attention else None | |||
next_decoder_cache = () if use_cache else None | |||
for i, layer_module in enumerate(self.layer): | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
layer_head_mask = head_mask[i] if head_mask is not None else None | |||
past_key_value = past_key_values[ | |||
i] if past_key_values is not None else None | |||
if self.gradient_checkpointing and self.training: | |||
if use_cache: | |||
logger.warning( | |||
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' | |||
) | |||
use_cache = False | |||
def create_custom_forward(module): | |||
def custom_forward(*inputs): | |||
return module(*inputs, past_key_value, | |||
output_attentions) | |||
return custom_forward | |||
layer_outputs = torch.utils.checkpoint.checkpoint( | |||
create_custom_forward(layer_module), | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
) | |||
else: | |||
layer_outputs = layer_module( | |||
hidden_states, | |||
attention_mask, | |||
layer_head_mask, | |||
encoder_hidden_states, | |||
encoder_attention_mask, | |||
past_key_value, | |||
output_attentions, | |||
) | |||
hidden_states = layer_outputs[0] | |||
if use_cache: | |||
next_decoder_cache += (layer_outputs[-1], ) | |||
if output_attentions: | |||
all_self_attentions = all_self_attentions + ( | |||
layer_outputs[1], ) | |||
if self.config.add_cross_attention: | |||
all_cross_attentions = all_cross_attentions + ( | |||
layer_outputs[2], ) | |||
if output_hidden_states: | |||
all_hidden_states = all_hidden_states + (hidden_states, ) | |||
if not return_dict: | |||
return tuple(v for v in [ | |||
hidden_states, | |||
next_decoder_cache, | |||
all_hidden_states, | |||
all_self_attentions, | |||
all_cross_attentions, | |||
] if v is not None) | |||
return BaseModelOutputWithPastAndCrossAttentions( | |||
last_hidden_state=hidden_states, | |||
past_key_values=next_decoder_cache, | |||
hidden_states=all_hidden_states, | |||
attentions=all_self_attentions, | |||
cross_attentions=all_cross_attentions, | |||
) | |||
class SbertPooler(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
# We "pool" the model by simply taking the hidden state corresponding | |||
# to the first token. | |||
first_token_tensor = hidden_states[:, 0] | |||
pooled_output = self.dense(first_token_tensor) | |||
pooled_output = self.activation(pooled_output) | |||
return pooled_output | |||
class SbertPreTrainedModel(TorchModel, PreTrainedModel): | |||
""" | |||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |||
models. | |||
""" | |||
config_class = SbertConfig | |||
base_model_prefix = 'bert' | |||
supports_gradient_checkpointing = True | |||
_keys_to_ignore_on_load_missing = [r'position_ids'] | |||
def __init__(self, config, **kwargs): | |||
super().__init__(config.name_or_path, **kwargs) | |||
super(Model, self).__init__(config) | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.bias is not None: | |||
module.bias.data.zero_() | |||
elif isinstance(module, nn.Embedding): | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.padding_idx is not None: | |||
module.weight.data[module.padding_idx].zero_() | |||
elif isinstance(module, nn.LayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
def _set_gradient_checkpointing(self, module, value=False): | |||
if isinstance(module, SbertEncoder): | |||
module.gradient_checkpointing = value | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
"""Instantiate the model. | |||
Args: | |||
kwargs: Input args. | |||
model_dir: The model dir used to load the checkpoint and the label information. | |||
num_labels: An optional arg telling the model how many classes to initialize.
    If it is not provided, utils.parse_label_mapping will be called to infer it.
label2id: An optional label2id mapping, which overrides the label2id in the configuration (if it exists).
Returns: | |||
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
""" | |||
model_dir = kwargs.pop('model_dir', None) | |||
if model_dir is None: | |||
config = SbertConfig(**kwargs) | |||
model = cls(config) | |||
else: | |||
model_kwargs = {} | |||
label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) | |||
id2label = kwargs.get( | |||
'id2label', None if label2id is None else | |||
{id: label | |||
for label, id in label2id.items()}) | |||
if id2label is not None and label2id is None: | |||
label2id = {label: id for id, label in id2label.items()} | |||
num_labels = kwargs.get( | |||
'num_labels', None if label2id is None else len(label2id)) | |||
if num_labels is not None: | |||
model_kwargs['num_labels'] = num_labels | |||
if label2id is not None: | |||
model_kwargs['label2id'] = label2id | |||
if id2label is not None: | |||
model_kwargs['id2label'] = id2label | |||
model = super(Model, cls).from_pretrained( | |||
pretrained_model_name_or_path=model_dir, **model_kwargs) | |||
return model | |||
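As a hedged illustration only: these kwargs usually arrive via `Model.from_pretrained`; the model id, task and label values below are examples, not taken from this diff.

from modelscope.models import Model

# num_labels / label2id are forwarded down to SbertPreTrainedModel._instantiate;
# when both are omitted, parse_label_mapping(model_dir) is used to recover them.
model = Model.from_pretrained(
    'damo/nlp_structbert_backbone_base_std',  # example model id
    task='text-classification',               # example task
    num_labels=2,
    label2id={'negative': 0, 'positive': 1},
)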
@dataclass | |||
class AttentionBackboneModelOutputWithEmbedding(AttentionBackboneModelOutput): | |||
embedding_output: torch.FloatTensor = None | |||
logits: Optional[Union[tuple, torch.FloatTensor]] = None | |||
kwargs: dict = None | |||
@MODELS.register_module(Tasks.backbone, module_name=Models.structbert) | |||
class SbertModel(SbertPreTrainedModel): | |||
"""The StructBERT Model transformer outputting raw hidden-states without any specific head on top. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
pruning heads etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
general usage and behavior. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of | |||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is | |||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, | |||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. | |||
To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an | |||
input to the forward pass. | |||
""" | |||
def __init__(self, config: SbertConfig, add_pooling_layer=True, **kwargs): | |||
super().__init__(config) | |||
self.config = config | |||
self.embeddings = SbertEmbeddings(config) | |||
self.encoder = SbertEncoder(config) | |||
self.pooler = SbertPooler(config) if add_pooling_layer else None | |||
self.init_weights() | |||
def get_input_embeddings(self): | |||
return self.embeddings.word_embeddings | |||
def set_input_embeddings(self, value): | |||
self.embeddings.word_embeddings = value | |||
def _prune_heads(self, heads_to_prune): | |||
""" | |||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base | |||
class PreTrainedModel | |||
""" | |||
for layer, heads in heads_to_prune.items(): | |||
self.encoder.layer[layer].attention.prune_heads(heads) | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, | |||
`optional`): | |||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if | |||
the model is configured as a decoder. | |||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in | |||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple | |||
having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | |||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. | |||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` | |||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` | |||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. | |||
use_cache (:obj:`bool`, `optional`): | |||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up | |||
decoding (see :obj:`past_key_values`). | |||
Returns: | |||
`AttentionBackboneModelOutputWithEmbedding` (a subclass of `modelscope.outputs.AttentionBackboneModelOutput` defined in this file)
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='backbone') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_backbone_base_std') | |||
>>> print(model(**preprocessor('这是个测试'))) | |||
""" | |||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |||
output_hidden_states = ( | |||
output_hidden_states if output_hidden_states is not None else | |||
self.config.output_hidden_states) | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if self.config.is_decoder: | |||
use_cache = use_cache if use_cache is not None else self.config.use_cache | |||
else: | |||
use_cache = False | |||
if input_ids is not None and inputs_embeds is not None: | |||
raise ValueError( | |||
'You cannot specify both input_ids and inputs_embeds at the same time' | |||
) | |||
elif input_ids is not None: | |||
input_shape = input_ids.size() | |||
elif inputs_embeds is not None: | |||
input_shape = inputs_embeds.size()[:-1] | |||
else: | |||
raise ValueError( | |||
'You have to specify either input_ids or inputs_embeds') | |||
batch_size, seq_length = input_shape | |||
device = input_ids.device if input_ids is not None else inputs_embeds.device | |||
# past_key_values_length | |||
past_key_values_length = past_key_values[0][0].shape[ | |||
2] if past_key_values is not None else 0 | |||
if attention_mask is None: | |||
attention_mask = torch.ones( | |||
((batch_size, seq_length + past_key_values_length)), | |||
device=device) | |||
if token_type_ids is None: | |||
if hasattr(self.embeddings, 'token_type_ids'): | |||
buffered_token_type_ids = self.embeddings.token_type_ids[:, : | |||
seq_length] | |||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand( | |||
batch_size, seq_length) | |||
token_type_ids = buffered_token_type_ids_expanded | |||
else: | |||
token_type_ids = torch.zeros( | |||
input_shape, dtype=torch.long, device=device) | |||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | |||
# ourselves in which case we just need to make it broadcastable to all heads. | |||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( | |||
attention_mask, input_shape, device) | |||
# If a 2D or 3D attention mask is provided for the cross-attention | |||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | |||
if self.config.is_decoder and encoder_hidden_states is not None: | |||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( | |||
) | |||
encoder_hidden_shape = (encoder_batch_size, | |||
encoder_sequence_length) | |||
if encoder_attention_mask is None: | |||
encoder_attention_mask = torch.ones( | |||
encoder_hidden_shape, device=device) | |||
encoder_extended_attention_mask = self.invert_attention_mask( | |||
encoder_attention_mask) | |||
else: | |||
encoder_extended_attention_mask = None | |||
# Prepare head mask if needed | |||
# 1.0 in head_mask indicate we keep the head | |||
# attention_probs has shape bsz x n_heads x N x N | |||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | |||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | |||
head_mask = self.get_head_mask(head_mask, | |||
self.config.num_hidden_layers) | |||
embedding_output, original_embeds = self.embeddings(
input_ids=input_ids, | |||
position_ids=position_ids, | |||
token_type_ids=token_type_ids, | |||
inputs_embeds=inputs_embeds, | |||
past_key_values_length=past_key_values_length, | |||
return_inputs_embeds=True, | |||
) | |||
encoder_outputs = self.encoder( | |||
embedding_output, | |||
attention_mask=extended_attention_mask, | |||
head_mask=head_mask, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_extended_attention_mask, | |||
past_key_values=past_key_values, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = encoder_outputs[0] | |||
pooled_output = self.pooler( | |||
sequence_output) if self.pooler is not None else None | |||
if not return_dict: | |||
return (sequence_output, | |||
pooled_output) + encoder_outputs[1:] + (original_embeds, )
return AttentionBackboneModelOutputWithEmbedding( | |||
last_hidden_state=sequence_output, | |||
pooler_output=pooled_output, | |||
past_key_values=encoder_outputs.past_key_values, | |||
hidden_states=encoder_outputs.hidden_states, | |||
attentions=encoder_outputs.attentions, | |||
cross_attentions=encoder_outputs.cross_attentions, | |||
embedding_output=original_embeds)
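A short sketch of consuming the returned object; field names follow `AttentionBackboneModelOutputWithEmbedding` above and its `AttentionBackboneModelOutput` parent, and `outputs` is assumed to be the value returned by `SbertModel.forward`.

# outputs.last_hidden_state : (batch_size, seq_len, hidden_size) token representations
# outputs.pooler_output     : (batch_size, hidden_size) pooled first-token vector
# outputs.embedding_output  : the word embeddings before position/type embeddings are added
token_vectors = outputs.last_hidden_state
sentence_vector = outputs.pooler_output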
@@ -14,7 +14,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" SBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ | |||
""" StructBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ | |||
from transformers import PretrainedConfig | |||
from modelscope.utils import logger as logging | |||
@@ -26,7 +26,7 @@ class SbertConfig(PretrainedConfig): | |||
r""" | |||
This is the configuration class to store the configuration | |||
of a :class:`~modelscope.models.nlp.structbert.SbertModel`. | |||
It is used to instantiate a SBERT model according to the specified arguments. | |||
It is used to instantiate a StructBERT model according to the specified arguments. | |||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model | |||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. | |||
@@ -74,15 +74,15 @@ class SbertConfig(PretrainedConfig): | |||
relevant if ``config.is_decoder=True``. | |||
classifier_dropout (:obj:`float`, `optional`): | |||
The dropout ratio for the classification head. | |||
adv_grad_factor (:obj:`float`, `optional`): This factor will be multipled by the KL loss grad and then | |||
adv_grad_factor (:obj:`float`, `optional`): This factor will be multiplied by the KL loss grad and then | |||
the result will be added to the original embedding. | |||
For more details, please check: https://arxiv.org/abs/1908.04577
The range of this value always be 1e-3~1e-7 | |||
The range of this value should be between 1e-3 and 1e-7
adv_bound (:obj:`float`, `optional`): adv_bound is used to clip the upper and lower bounds of
the produced embedding.
If not proveded, 2 * sigma will be used as the adv_bound factor | |||
If not provided, 2 * sigma will be used as the adv_bound factor | |||
sigma (:obj:`float`, `optional`): The std factor used to produce a 0 mean normal distribution. | |||
If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor | |||
If adv_bound not provided, 2 * sigma will be used as the adv_bound factor | |||
""" | |||
model_type = 'structbert' |
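A hedged example of setting the adversarial-training fields described in the docstring above; the values are illustrative and every other argument falls back to its BertConfig-style default.

from modelscope.models.nlp.structbert import SbertConfig

config = SbertConfig(
    adv_grad_factor=1e-4,  # multiplied by the KL loss grad; typical range 1e-3 ~ 1e-7
    sigma=0.02,            # std used to produce the zero-mean normal distribution
    # adv_bound omitted on purpose: 2 * sigma is then used as the bound
)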
@@ -1,3 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
from collections import namedtuple | |||
@@ -15,103 +17,6 @@ from modelscope.models.nlp.task_models.task_model import BaseTaskModel | |||
from modelscope.utils.config import Config, ConfigFields | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
__all__ = ['SbertForFaqQuestionAnswering'] | |||
class SbertForFaqQuestionAnsweringBase(BaseTaskModel): | |||
"""base class for faq models | |||
""" | |||
def __init__(self, model_dir, *args, **kwargs): | |||
super(SbertForFaqQuestionAnsweringBase, | |||
self).__init__(model_dir, *args, **kwargs) | |||
backbone_cfg = SbertConfig.from_pretrained(model_dir) | |||
self.bert = SbertModel(backbone_cfg) | |||
model_config = Config.from_file( | |||
os.path.join(model_dir, | |||
ModelFile.CONFIGURATION)).get(ConfigFields.model, {}) | |||
metric = model_config.get('metric', 'cosine') | |||
pooling_method = model_config.get('pooling', 'avg') | |||
Arg = namedtuple('args', [ | |||
'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling' | |||
]) | |||
args = Arg( | |||
metrics=metric, | |||
proj_hidden_size=self.bert.config.hidden_size, | |||
hidden_size=self.bert.config.hidden_size, | |||
dropout=0.0, | |||
pooling=pooling_method) | |||
self.metrics_layer = MetricsLayer(args) | |||
self.pooling = PoolingLayer(args) | |||
def _get_onehot_labels(self, labels, support_size, num_cls): | |||
labels_ = labels.view(support_size, 1) | |||
target_oh = torch.zeros(support_size, num_cls).to(labels) | |||
target_oh.scatter_(dim=1, index=labels_, value=1) | |||
return target_oh.view(support_size, num_cls).float() | |||
def forward_sentence_embedding(self, inputs: Dict[str, Tensor]): | |||
input_ids = inputs['input_ids'] | |||
input_mask = inputs['attention_mask'] | |||
if not isinstance(input_ids, Tensor): | |||
input_ids = torch.IntTensor(input_ids) | |||
if not isinstance(input_mask, Tensor): | |||
input_mask = torch.IntTensor(input_mask) | |||
rst = self.bert(input_ids, input_mask) | |||
last_hidden_states = rst.last_hidden_state | |||
if len(input_mask.shape) == 2: | |||
input_mask = input_mask.unsqueeze(-1) | |||
pooled_representation = self.pooling(last_hidden_states, input_mask) | |||
return pooled_representation | |||
@MODELS.register_module( | |||
Tasks.faq_question_answering, module_name=Models.structbert) | |||
class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase): | |||
_backbone_prefix = '' | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
assert not self.training | |||
query = input['query'] | |||
support = input['support'] | |||
if isinstance(query, list): | |||
query = torch.stack(query) | |||
if isinstance(support, list): | |||
support = torch.stack(support) | |||
n_query = query.shape[0] | |||
n_support = support.shape[0] | |||
query_mask = torch.ne(query, 0).view([n_query, -1]) | |||
support_mask = torch.ne(support, 0).view([n_support, -1]) | |||
support_labels = input['support_labels'] | |||
num_cls = torch.max(support_labels) + 1 | |||
onehot_labels = self._get_onehot_labels(support_labels, n_support, | |||
num_cls) | |||
input_ids = torch.cat([query, support]) | |||
input_mask = torch.cat([query_mask, support_mask], dim=0) | |||
pooled_representation = self.forward_sentence_embedding({ | |||
'input_ids': | |||
input_ids, | |||
'attention_mask': | |||
input_mask | |||
}) | |||
z_query = pooled_representation[:n_query] | |||
z_support = pooled_representation[n_query:] | |||
cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5 | |||
protos = torch.matmul(onehot_labels.transpose(0, 1), | |||
z_support) / cls_n_support.unsqueeze(-1) | |||
scores = self.metrics_layer(z_query, protos).view([n_query, num_cls]) | |||
if self.metrics_layer.name == 'relation': | |||
scores = torch.sigmoid(scores) | |||
return {'scores': scores} | |||
activations = { | |||
'relu': F.relu, | |||
'tanh': torch.tanh, | |||
@@ -247,3 +152,142 @@ class PoolingLayer(nn.Module): | |||
def forward(self, x, mask): | |||
return self.pooling(x, mask) | |||
@MODELS.register_module( | |||
Tasks.faq_question_answering, module_name=Models.structbert) | |||
class SbertForFaqQuestionAnswering(BaseTaskModel): | |||
_backbone_prefix = '' | |||
@classmethod | |||
def _instantiate(cls, **kwargs): | |||
model = cls(kwargs.get('model_dir')) | |||
model.load_checkpoint(kwargs.get('model_dir')) | |||
return model | |||
def __init__(self, model_dir, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
backbone_cfg = SbertConfig.from_pretrained(model_dir) | |||
self.bert = SbertModel(backbone_cfg) | |||
model_config = Config.from_file( | |||
os.path.join(model_dir, | |||
ModelFile.CONFIGURATION)).get(ConfigFields.model, {}) | |||
metric = model_config.get('metric', 'cosine') | |||
pooling_method = model_config.get('pooling', 'avg') | |||
Arg = namedtuple('args', [ | |||
'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling' | |||
]) | |||
args = Arg( | |||
metrics=metric, | |||
proj_hidden_size=self.bert.config.hidden_size, | |||
hidden_size=self.bert.config.hidden_size, | |||
dropout=0.0, | |||
pooling=pooling_method) | |||
self.metrics_layer = MetricsLayer(args) | |||
self.pooling = PoolingLayer(args) | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
""" | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data, which contains the following keys:
query(:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
The query to be predicted. | |||
support(:obj:`torch.LongTensor` of shape :obj:`(support_size, sequence_length)`): | |||
The support set. | |||
support_labels(:obj:`torch.LongTensor` of shape :obj:`(support_size, )`):
The labels of the support set.
Returns: | |||
Dict[str, Tensor]: the result, which contains the following key:
scores(:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_cls)`): | |||
Predicted scores of all classes for each query. | |||
Examples: | |||
>>> from modelscope.hub.snapshot_download import snapshot_download | |||
>>> from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor | |||
>>> from modelscope.models.nlp import SbertForFaqQuestionAnswering | |||
>>> cache_path = snapshot_download('damo/nlp_structbert_faq-question-answering_chinese-base') | |||
>>> preprocessor = FaqQuestionAnsweringPreprocessor.from_pretrained(cache_path) | |||
>>> model = SbertForFaqQuestionAnswering.from_pretrained(cache_path) | |||
>>> param = { | |||
>>> 'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'], | |||
>>> 'support_set': [{ | |||
>>> 'text': '卖品代金券怎么用', | |||
>>> 'label': '6527856' | |||
>>> }, { | |||
>>> 'text': '怎么使用优惠券', | |||
>>> 'label': '6527856' | |||
>>> }, { | |||
>>> 'text': '这个可以一起领吗', | |||
>>> 'label': '1000012000' | |||
>>> }, { | |||
>>> 'text': '付款时送的优惠券哪里领', | |||
>>> 'label': '1000012000' | |||
>>> }, { | |||
>>> 'text': '购物等级怎么长', | |||
>>> 'label': '13421097' | |||
>>> }, { | |||
>>> 'text': '购物等级二心', | |||
>>> 'label': '13421097' | |||
>>> }] | |||
>>> } | |||
>>> result = model(preprocessor(param)) | |||
""" | |||
assert not self.training | |||
query = input['query'] | |||
support = input['support'] | |||
if isinstance(query, list): | |||
query = torch.stack(query) | |||
if isinstance(support, list): | |||
support = torch.stack(support) | |||
n_query = query.shape[0] | |||
n_support = support.shape[0] | |||
query_mask = torch.ne(query, 0).view([n_query, -1]) | |||
support_mask = torch.ne(support, 0).view([n_support, -1]) | |||
support_labels = input['support_labels'] | |||
num_cls = torch.max(support_labels) + 1 | |||
onehot_labels = self._get_onehot_labels(support_labels, n_support, | |||
num_cls) | |||
input_ids = torch.cat([query, support]) | |||
input_mask = torch.cat([query_mask, support_mask], dim=0) | |||
pooled_representation = self.forward_sentence_embedding({ | |||
'input_ids': | |||
input_ids, | |||
'attention_mask': | |||
input_mask | |||
}) | |||
z_query = pooled_representation[:n_query] | |||
z_support = pooled_representation[n_query:] | |||
cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5 | |||
protos = torch.matmul(onehot_labels.transpose(0, 1), | |||
z_support) / cls_n_support.unsqueeze(-1) | |||
scores = self.metrics_layer(z_query, protos).view([n_query, num_cls]) | |||
if self.metrics_layer.name == 'relation': | |||
scores = torch.sigmoid(scores) | |||
return {'scores': scores} | |||
def _get_onehot_labels(self, labels, support_size, num_cls): | |||
labels_ = labels.view(support_size, 1) | |||
target_oh = torch.zeros(support_size, num_cls).to(labels) | |||
target_oh.scatter_(dim=1, index=labels_, value=1) | |||
return target_oh.view(support_size, num_cls).float() | |||
def forward_sentence_embedding(self, inputs: Dict[str, Tensor]): | |||
input_ids = inputs['input_ids'] | |||
input_mask = inputs['attention_mask'] | |||
if not isinstance(input_ids, Tensor): | |||
input_ids = torch.IntTensor(input_ids) | |||
if not isinstance(input_mask, Tensor): | |||
input_mask = torch.IntTensor(input_mask) | |||
rst = self.bert(input_ids, input_mask) | |||
last_hidden_states = rst.last_hidden_state | |||
if len(input_mask.shape) == 2: | |||
input_mask = input_mask.unsqueeze(-1) | |||
pooled_representation = self.pooling(last_hidden_states, input_mask) | |||
return pooled_representation |
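To make the prototype step in `forward` concrete, here is a minimal standalone PyTorch sketch of the class-prototype computation (toy shapes only, not tied to the model weights):

import torch

# 6 support examples over 3 classes, hidden size 4 (toy numbers).
support_labels = torch.tensor([0, 0, 1, 1, 2, 2])
z_support = torch.randn(6, 4)

onehot = torch.zeros(6, 3).scatter_(1, support_labels.view(-1, 1), 1.0)
cls_n_support = onehot.sum(dim=-2) + 1e-5                      # (3,) examples per class
protos = onehot.t() @ z_support / cls_n_support.unsqueeze(-1)  # (3, 4) class prototypes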
@@ -0,0 +1,284 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import CrossEntropyLoss | |||
from transformers.activations import ACT2FN | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionFillMaskModelOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .backbone import SbertModel, SbertPreTrainedModel | |||
from .configuration import SbertConfig | |||
logger = logging.get_logger(__name__) | |||
class SbertPredictionHeadTransform(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |||
if isinstance(config.hidden_act, str): | |||
self.transform_act_fn = ACT2FN[config.hidden_act] | |||
else: | |||
self.transform_act_fn = config.hidden_act | |||
self.LayerNorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layer_norm_eps) | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.transform_act_fn(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states) | |||
return hidden_states | |||
class SbertLMPredictionHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.transform = SbertPredictionHeadTransform(config) | |||
# The output weights are the same as the input embeddings, but there is | |||
# an output-only bias for each token. | |||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size) | |||
def forward(self, hidden_states): | |||
hidden_states = self.transform(hidden_states) | |||
hidden_states = self.decoder(hidden_states) | |||
return hidden_states | |||
class SbertOnlyMLMHead(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = SbertLMPredictionHead(config) | |||
def forward(self, sequence_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
return prediction_scores | |||
class SbertPreTrainingHeads(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.predictions = SbertLMPredictionHead(config) | |||
self.seq_relationship = nn.Linear(config.hidden_size, 2) | |||
def forward(self, sequence_output, pooled_output): | |||
prediction_scores = self.predictions(sequence_output) | |||
seq_relationship_score = self.seq_relationship(pooled_output) | |||
return prediction_scores, seq_relationship_score | |||
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | |||
class SbertForMaskedLM(SbertPreTrainedModel): | |||
r"""StructBERT Model with a `language modeling` head on top. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
pruning heads etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
general usage and behavior. | |||
Preprocessor: | |||
This is the fill_mask model of StructBERT. The preprocessor of this model
is `modelscope.preprocessors.NLPPreprocessor`. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config: SbertConfig, **kwargs): | |||
super().__init__(config) | |||
if config.is_decoder: | |||
logger.warning( | |||
'If you want to use `SbertForMaskedLM` make sure `config.is_decoder=False` for ' | |||
'bi-directional self-attention.') | |||
self.bert = SbertModel(config) | |||
self.cls = SbertOnlyMLMHead(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
`What are input IDs? <../glossary.html#input-ids>`__ | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
`What are attention masks? <../glossary.html#attention-mask>`__ | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
`What are token type IDs? <../glossary.html#token-type-ids>`_ | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., | |||
config.vocab_size]`` (see the ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored | |||
(masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. | |||
Returns: | |||
A `modelscope.outputs.AttentionFillMaskModelOutput` containing the prediction logits and, when labels are provided, the masked LM loss. | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor, NLPPreprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_structbert_fill-mask_chinese-large') | |||
>>> preprocessor = NLPPreprocessor('damo/nlp_structbert_fill-mask_chinese-large') | |||
>>> # Call the model, return some tensors | |||
>>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) | |||
>>> # Call the pipeline | |||
>>> from modelscope.pipelines import pipeline | |||
>>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) | |||
>>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
outputs = self.bert( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
masked_lm_loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() # -100 index = padding token | |||
masked_lm_loss = loss_fct( | |||
prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:-1] | |||
return ((masked_lm_loss, ) | |||
+ output) if masked_lm_loss is not None else output | |||
return AttentionFillMaskModelOutput( | |||
loss=masked_lm_loss, | |||
logits=prediction_scores, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
input_ids=input_ids, | |||
) | |||
def prepare_inputs_for_generation(self, | |||
input_ids, | |||
attention_mask=None, | |||
**model_kwargs): | |||
input_shape = input_ids.shape | |||
effective_batch_size = input_shape[0] | |||
# add a dummy token | |||
assert self.config.pad_token_id is not None, 'The PAD token should be defined for generation' | |||
attention_mask_zero = attention_mask.new_zeros( | |||
(attention_mask.shape[0], 1)) | |||
attention_mask = torch.cat([attention_mask, attention_mask_zero], | |||
dim=-1) | |||
dummy_token = torch.full((effective_batch_size, 1), | |||
self.config.pad_token_id, | |||
dtype=torch.long, | |||
device=input_ids.device) | |||
input_ids = torch.cat([input_ids, dummy_token], dim=1) | |||
return {'input_ids': input_ids, 'attention_mask': attention_mask} |
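The `prepare_inputs_for_generation` hook above appends one dummy PAD token per sequence and a matching zero column to the attention mask, so the model always has a fresh position to predict during generation. A minimal sketch of the effect on input shapes, assuming the fill-mask checkpoint from the docstring example and an illustrative toy batch (the token ids below are arbitrary placeholders):

import torch
from modelscope.models import Model

# Checkpoint taken from the docstring example above.
model = Model.from_pretrained('damo/nlp_structbert_fill-mask_chinese-large')

# Illustrative toy batch: 2 sequences of length 4 (ids are arbitrary placeholders).
input_ids = torch.tensor([[101, 2769, 103, 102],
                          [101, 872, 103, 102]])
attention_mask = torch.ones_like(input_ids)

prepared = model.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask)
print(prepared['input_ids'].shape)       # torch.Size([2, 5]) -- dummy PAD token appended
print(prepared['attention_mask'].shape)  # torch.Size([2, 5]) -- zero column appended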
@@ -0,0 +1,235 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import AttentionTextClassificationModelOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .adv_utils import compute_adv_loss | |||
from .backbone import SbertModel, SbertPreTrainedModel | |||
from .configuration import SbertConfig | |||
logger = logging.get_logger(__name__) | |||
@MODELS.register_module( | |||
Tasks.text_classification, module_name=Models.structbert) | |||
@MODELS.register_module(Tasks.nli, module_name=Models.structbert) | |||
@MODELS.register_module( | |||
Tasks.sentiment_classification, module_name=Models.structbert) | |||
@MODELS.register_module( | |||
Tasks.sentence_similarity, module_name=Models.structbert) | |||
@MODELS.register_module( | |||
Tasks.zero_shot_classification, module_name=Models.structbert) | |||
class SbertForSequenceClassification(SbertPreTrainedModel): | |||
r"""StructBERT Model transformer with a sequence classification/regression head on top | |||
(a linear layer on top of the pooled output) e.g. for GLUE tasks. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads, etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the text classification model of StructBERT. The preprocessor of this model | |||
is `modelscope.preprocessors.SequenceClassificationPreprocessor`. | |||
Trainer: | |||
This model is a normal PyTorch model and can be trained with various trainers, such as EpochBasedTrainer, | |||
NlpEpochBasedTrainer, or trainers from other frameworks. | |||
The preferred trainer in ModelScope is NlpEpochBasedTrainer. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
def __init__(self, config: SbertConfig, **kwargs): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
if self.config.adv_grad_factor is None: | |||
logger.warning( | |||
'Adv parameters not set, skipping compute_adv_loss.') | |||
SbertForSequenceClassification.base_model_prefix = getattr( | |||
config, 'base_model_prefix', | |||
SbertForSequenceClassification.base_model_prefix) | |||
setattr(self, self.base_model_prefix, SbertModel(config)) | |||
classifier_dropout = ( | |||
config.classifier_dropout if config.classifier_dropout is not None | |||
else config.hidden_dropout_prob) | |||
self.dropout = nn.Dropout(classifier_dropout) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
self.init_weights() | |||
def _forward_call(self, **kwargs): | |||
outputs = self.base_model(**kwargs) | |||
pooled_output = outputs[1] | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
outputs['logits'] = logits | |||
outputs.kwargs = kwargs | |||
return outputs | |||
def forward(self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
**kwargs): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., | |||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), | |||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | |||
Returns: | |||
A `modelscope.outputs.AttentionTextClassificationModelOutput` containing the classification logits and, when labels are provided, the loss. | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') | |||
>>> # Call the model, return some tensors | |||
>>> print(model(**preprocessor(('这是个测试', '这也是个测试')))) | |||
>>> # Call the pipeline | |||
>>> from modelscope.pipelines import pipeline | |||
>>> pipeline_ins = pipeline('text-classification', model=model, preprocessor=preprocessor) | |||
>>> print(pipeline_ins(('这是个测试', '这也是个测试'))) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if not return_dict: | |||
logger.error('Returning a plain tuple is not supported in Sbert models; a ModelOutput will be returned instead.') | |||
outputs = self._forward_call( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict) | |||
return self.compute_loss(outputs, labels, **outputs.kwargs) | |||
def compute_loss(self, outputs, labels, **kwargs): | |||
logits = outputs.logits | |||
embedding_output = outputs.embedding_output | |||
loss = None | |||
if labels is not None: | |||
if self.config.problem_type is None: | |||
if self.num_labels == 1: | |||
self.config.problem_type = 'regression' | |||
elif self.num_labels > 1 and (labels.dtype == torch.long | |||
or labels.dtype == torch.int): | |||
self.config.problem_type = 'single_label_classification' | |||
else: | |||
self.config.problem_type = 'multi_label_classification' | |||
if self.config.problem_type == 'regression': | |||
loss_fct = MSELoss() | |||
if self.num_labels == 1: | |||
loss = loss_fct(logits.squeeze(), labels.squeeze()) | |||
else: | |||
loss = loss_fct(logits, labels) | |||
elif self.config.problem_type == 'single_label_classification': | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
if self.config.adv_grad_factor is not None and self.training: | |||
loss = compute_adv_loss( | |||
embedding=embedding_output, | |||
model=self._forward_call, | |||
ori_logits=logits, | |||
ori_loss=loss, | |||
adv_bound=self.config.adv_bound, | |||
adv_grad_factor=self.config.adv_grad_factor, | |||
sigma=self.config.sigma, | |||
**kwargs) | |||
elif self.config.problem_type == 'multi_label_classification': | |||
loss_fct = BCEWithLogitsLoss() | |||
loss = loss_fct(logits, labels) | |||
return AttentionTextClassificationModelOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
) |
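`compute_loss` above selects the loss from `config.problem_type`, inferring it on the first call when unset: regression (MSE) when `num_labels == 1`, single-label classification (cross-entropy) when labels are integer typed, and multi-label classification (BCE-with-logits) otherwise; during training, the single-label branch additionally mixes in the adversarial term from `compute_adv_loss` when `config.adv_grad_factor` is set. Below is a standalone sketch of just that dispatch, not the ModelScope API itself; the adversarial term is omitted.

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def pick_classification_loss(logits, labels, num_labels, problem_type=None):
    """Illustrative restatement of the problem_type dispatch used above."""
    if problem_type is None:
        if num_labels == 1:
            problem_type = 'regression'
        elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
            problem_type = 'single_label_classification'
        else:
            problem_type = 'multi_label_classification'
    if problem_type == 'regression':
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(logits.squeeze(), labels.squeeze())
        return loss_fct(logits, labels)
    if problem_type == 'single_label_classification':
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)

# Toy example: 3 samples, 2 classes, integer labels -> cross-entropy branch.
logits = torch.randn(3, 2)
labels = torch.tensor([0, 1, 1])
print(pick_classification_loss(logits, labels, num_labels=2))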
@@ -0,0 +1,229 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
import torch.nn as nn | |||
import torch.utils.checkpoint | |||
from torch.nn import CrossEntropyLoss | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import TokenClassifierOutput | |||
from modelscope.utils import logger as logging | |||
from modelscope.utils.constant import Tasks | |||
from .adv_utils import compute_adv_loss | |||
from .backbone import SbertModel, SbertPreTrainedModel | |||
from .configuration import SbertConfig | |||
logger = logging.get_logger(__name__) | |||
@MODELS.register_module( | |||
Tasks.token_classification, module_name=Models.structbert) | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | |||
@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) | |||
class SbertForTokenClassification(SbertPreTrainedModel): | |||
r"""StructBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) | |||
e.g. for Named-Entity-Recognition (NER) tasks. | |||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic | |||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, | |||
pruning heads, etc.) | |||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ | |||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to | |||
general usage and behavior. | |||
Preprocessor: | |||
This is the token-classification model of StructBERT. The preprocessor of this model | |||
is `modelscope.preprocessors.TokenClassificationPreprocessor`. | |||
Trainer: | |||
This model is a normal PyTorch model and can be trained with various trainers, such as EpochBasedTrainer, | |||
NlpEpochBasedTrainer, or trainers from other frameworks. | |||
The preferred trainer in ModelScope is NlpEpochBasedTrainer. | |||
Parameters: | |||
config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with | |||
all the parameters of the model. | |||
Initializing with a config file does not load the weights associated with the model, only the | |||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model | |||
weights. | |||
""" | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
def __init__(self, config: SbertConfig, **kwargs): | |||
super().__init__(config) | |||
self.num_labels = config.num_labels | |||
self.config = config | |||
if self.config.adv_grad_factor is None: | |||
logger.warning( | |||
'Adv parameters not set, skipping compute_adv_loss.') | |||
setattr(self, self.base_model_prefix, | |||
SbertModel(config, add_pooling_layer=False)) | |||
classifier_dropout = ( | |||
config.classifier_dropout if config.classifier_dropout is not None | |||
else config.hidden_dropout_prob) | |||
self.dropout = nn.Dropout(classifier_dropout) | |||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
self.init_weights() | |||
def _forward_call(self, **kwargs): | |||
outputs = self.bert(**kwargs) | |||
sequence_output = outputs[0] | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
outputs['logits'] = logits | |||
outputs.kwargs = kwargs | |||
return outputs | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
labels=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
offset_mapping=None, | |||
label_mask=None, | |||
): | |||
r""" | |||
Args: | |||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |||
Indices of input sequence tokens in the vocabulary. | |||
Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See | |||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
details. | |||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: | |||
- 1 for tokens that are **not masked**, | |||
- 0 for tokens that are **masked**. | |||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, | |||
1]``: | |||
- 0 corresponds to a `sentence A` token, | |||
- 1 corresponds to a `sentence B` token. | |||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, | |||
config.max_position_embeddings - 1]``. | |||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: | |||
- 1 indicates the head is **not masked**, | |||
- 0 indicates the head is **masked**. | |||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated | |||
vectors than the model's internal embedding lookup matrix. | |||
output_attentions (:obj:`bool`, `optional`): | |||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned | |||
tensors for more detail. | |||
output_hidden_states (:obj:`bool`, `optional`): | |||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for | |||
more detail. | |||
return_dict (:obj:`bool`, `optional`): | |||
Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. | |||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - | |||
1]``. | |||
offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Indices of positions of each input sequence token in the sentence. | |||
Selected in the range ``[0, sequence_length - 1]``. | |||
label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, | |||
sequence_length)`, `optional`): | |||
Mask marking the tokens that carry a valid label. Mask | |||
values selected in ``[0, 1]``: | |||
- 1 for tokens that **have a label**, | |||
- 0 for tokens that **do not** (e.g. padding or sub-word continuations). | |||
Returns: | |||
A `modelscope.outputs.TokenClassifierOutput` containing the token classification logits and, when labels are provided, the loss. | |||
Examples: | |||
>>> from modelscope.models import Model | |||
>>> from modelscope.preprocessors import Preprocessor | |||
>>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') | |||
>>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') | |||
>>> print(model(**preprocessor('This is a test'))) | |||
""" | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if not return_dict: | |||
logger.error('Returning a plain tuple is not supported in Sbert models; a ModelOutput will be returned instead.') | |||
outputs = self._forward_call( | |||
input_ids=input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict) | |||
logits = outputs.logits | |||
embedding_output = outputs.embedding_output | |||
loss = None | |||
if labels is not None: | |||
loss_fct = CrossEntropyLoss() | |||
# Only keep active parts of the loss | |||
if attention_mask is not None: | |||
active_loss = attention_mask.view(-1) == 1 | |||
active_logits = logits.view(-1, self.num_labels) | |||
active_labels = torch.where( | |||
active_loss, labels.view(-1), | |||
torch.tensor(loss_fct.ignore_index).type_as(labels)) | |||
loss = loss_fct(active_logits, active_labels) | |||
else: | |||
loss = loss_fct( | |||
logits.view(-1, self.num_labels), labels.view(-1)) | |||
if self.config.adv_grad_factor is not None and self.training: | |||
loss = compute_adv_loss( | |||
embedding=embedding_output, | |||
model=self._forward_call, | |||
ori_logits=logits, | |||
ori_loss=loss, | |||
adv_bound=self.config.adv_bound, | |||
adv_grad_factor=self.config.adv_grad_factor, | |||
sigma=self.config.sigma, | |||
with_attention_mask=attention_mask is not None, | |||
**outputs.kwargs) | |||
return TokenClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
offset_mapping=offset_mapping, | |||
) |
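The masked-loss computation above keeps only the "active" positions: when an attention mask is given, padded positions have their label replaced by `CrossEntropyLoss().ignore_index` (-100) so they contribute nothing to the loss. A self-contained sketch of the same trick on a toy batch (all values are illustrative):

import torch
from torch.nn import CrossEntropyLoss

# Toy batch: 2 sequences, 4 tokens, 3 label classes; the last token of the
# second sequence is padding (attention_mask == 0).
logits = torch.randn(2, 4, 3)
labels = torch.tensor([[0, 1, 2, 1], [2, 0, 1, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, 3)
# Padded positions get ignore_index, so they are excluded from the loss.
active_labels = torch.where(active_loss, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(active_logits, active_labels)
print(loss)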
@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from .tokenization_sbert import SbertTokenizer | |||
from .tokenization import SbertTokenizer | |||
logger = get_logger(__name__) | |||
@@ -7,6 +7,9 @@ if TYPE_CHECKING: | |||
from .information_extraction import InformationExtractionModel | |||
from .feature_extraction import FeatureExtractionModel | |||
from .fill_mask import FillMaskModel | |||
from .nncrf_for_named_entity_recognition import ( | |||
TransformerCRFForNamedEntityRecognition, | |||
LSTMCRFForNamedEntityRecognition) | |||
from .sequence_classification import SequenceClassificationModel | |||
from .task_model import SingleBackboneTaskModelBase | |||
from .token_classification import TokenClassificationModel | |||
@@ -17,6 +20,10 @@ else: | |||
'information_extraction': ['InformationExtractionModel'], | |||
'feature_extraction': ['FeatureExtractionModel'], | |||
'fill_mask': ['FillMaskModel'], | |||
'nncrf_for_named_entity_recognition': [ | |||
'TransformerCRFForNamedEntityRecognition', | |||
'LSTMCRFForNamedEntityRecognition' | |||
], | |||
'sequence_classification': ['SequenceClassificationModel'], | |||
'task_model': ['SingleBackboneTaskModelBase'], | |||
'token_classification': ['TokenClassificationModel'], | |||
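The `TYPE_CHECKING` imports and the string map above feed the lazy-import machinery, so the two CRF-based NER task models only pay their import cost when actually requested. Assuming this `__init__.py` belongs to the nlp task-model package introduced by this refactor (the exact path below is an assumption), they become importable directly:

# Assumed package path; adjust to the actual location of this __init__.py.
from modelscope.models.nlp.task_models import (
    LSTMCRFForNamedEntityRecognition, TransformerCRFForNamedEntityRecognition)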
@@ -1,3 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import numpy as np | |||
@@ -31,13 +32,8 @@ class FeatureExtractionModel(SingleBackboneTaskModelBase): | |||
self.build_backbone(self.backbone_cfg) | |||
def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
# backbone do not need labels, only head need for loss compute | |||
labels = input.pop(OutputKeys.LABELS, None) | |||
input.pop(OutputKeys.LABELS, None) | |||
outputs = super().forward(input) | |||
sequence_output, pooled_output = self.extract_backbone_outputs(outputs) | |||
if labels is not None: | |||
input[OutputKeys.LABELS] = labels | |||
sequence_output = outputs.last_hidden_state | |||
return {OutputKeys.TEXT_EMBEDDING: sequence_output} |
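With the change above, `FeatureExtractionModel.forward` simply drops any labels from the inputs, runs the backbone, and returns the backbone's `last_hidden_state` under `OutputKeys.TEXT_EMBEDDING`. A hedged usage sketch; the checkpoint id is hypothetical and used only for illustration:

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import Preprocessor

# Hypothetical checkpoint id; any feature-extraction checkpoint that uses this
# task model should behave the same way.
model_id = 'damo/nlp_structbert_feature-extraction_chinese-base'
model = Model.from_pretrained(model_id)
preprocessor = Preprocessor.from_pretrained(model_id)

outputs = model(**preprocessor('这是一个测试'))
# Per-token embeddings taken from the backbone's last hidden state.
print(outputs[OutputKeys.TEXT_EMBEDDING].shape)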