Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9627026master
@@ -1,2 +1,3 @@
from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel
from .tokenization_ofa import OFATokenizer
from .tokenization_ofa import OFATokenizer, OFATokenizerZH
from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast
@@ -134,6 +134,8 @@ class OFAConfig(PretrainedConfig):
                 code_layernorm_embedding=True,
                 code_image_size=128,
                 entangle_position_embedding=False,
                 interpolate_position=False,
                 orig_patch_image_size=224,
                 **kwargs):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
@@ -173,6 +175,8 @@ class OFAConfig(PretrainedConfig):
        self.code_layernorm_embedding = code_layernorm_embedding
        self.code_image_size = code_image_size
        self.entangle_position_embedding = entangle_position_embedding
        self.interpolate_position = interpolate_position
        self.orig_patch_image_size = orig_patch_image_size
        super().__init__(
            pad_token_id=pad_token_id,
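For reference, a minimal sketch of how the two new configuration fields would be set (assuming, as with the fields shown above, that every constructor argument has a default):

# Hypothetical construction; values are illustrative.
config = OFAConfig(interpolate_position=True, orig_patch_image_size=224)
assert config.interpolate_position and config.orig_patch_image_size == 224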
@@ -311,7 +311,6 @@ class OFAAttention(nn.Module):
            self.head_dim * num_heads == self.embed_dim
        ), f'embed_dim must be divisible by num_heads ' \
            f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).'
        # self.scaling = self.head_dim ** -0.5
        # 1. difference
        scale_factor = 2
        self.scaling = float(self.head_dim * scale_factor)**-0.5
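To make the "# 1. difference" comment concrete, a small illustration of the modified scaling (the head dimension is an assumed example value):

head_dim = 64                               # assumed example value
standard_scaling = head_dim ** -0.5         # 0.125, the usual transformer scaling
ofa_scaling = float(head_dim * 2) ** -0.5   # ~0.0884, i.e. the standard value divided by sqrt(2)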
@@ -913,7 +912,6 @@ class OFAEncoder(OFAPreTrainedModel):
        else:
            raise NotImplementedError
        # self.image_proj = nn.Linear(1024, embed_dim)
        self.image_proj = Linear(1024, embed_dim)
        if config.resnet_model_path:
@@ -1075,7 +1073,25 @@ class OFAEncoder(OFAPreTrainedModel):
            image_num_patches = sample_patch_num
            image_padding_mask = image_padding_mask.gather(1, patch_orders)
            image_position_ids = image_position_ids.gather(1, patch_orders)
        image_pos_embed = self.embed_image_positions(image_position_ids)
        orig_num_patches = (self.config.orig_patch_image_size // 16)**2
        orig_hw = self.config.orig_patch_image_size // 16
        if self.config.interpolate_position and image_num_patches > orig_num_patches:
            old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
                torch.arange(orig_hw).unsqueeze(1) * \
                self.config.image_bucket_size + 1  # noqa
            old_image_position_ids = old_image_position_ids.to(device)
            old_image_pos_embed = self.embed_image_positions(
                old_image_position_ids)
            old_image_pos_embed = old_image_pos_embed.reshape(
                1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
            image_pos_embed = F.interpolate(
                old_image_pos_embed, size=(h, w), mode='bilinear')
            image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(
                1, image_num_patches, -1)
            image_pos_embed = image_pos_embed.expand(
                patch_images.size(0), -1, -1)
        else:
            image_pos_embed = self.embed_image_positions(image_position_ids)
        return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
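A self-contained sketch of the bilinear interpolation path added above, with illustrative shapes (the real code derives `orig_hw` from `orig_patch_image_size`, `h`/`w` from the input patch grid, and the channel size from the embedding dim):

import torch
import torch.nn.functional as F

orig_hw, h, w, dim = 14, 30, 30, 256             # e.g. 224 // 16 = 14 and 480 // 16 = 30
old_pos = torch.randn(1, orig_hw, orig_hw, dim)  # (1, H, W, C) grid of position embeddings
old_pos = old_pos.permute(0, 3, 1, 2)            # -> (1, C, H, W) for F.interpolate
new_pos = F.interpolate(old_pos, size=(h, w), mode='bilinear')
new_pos = new_pos.permute(0, 2, 3, 1).reshape(1, h * w, dim)
print(new_pos.shape)                             # torch.Size([1, 900, 256])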
@@ -1250,7 +1266,6 @@ class OFAEncoder(OFAPreTrainedModel):
            position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
                positional embeddings of the input image and tokens.
        """
        image_embed = None
        image_embed_2 = None
        image_pos_embed = None
@@ -1258,14 +1273,7 @@ class OFAEncoder(OFAPreTrainedModel):
        if patch_images is not None:
            image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
                self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device)
            # print("patch_masks.shape")
            # print(patch_masks.shape)
            # print(patch_masks)
            # print("image_padding_mask.shape")
            # print(image_padding_mask.shape)
            # print(image_padding_mask)
            image_padding_mask[~patch_masks] = True
            # print(image_padding_mask)
        if patch_images_2 is not None:
            image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
                self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device)
@@ -1313,10 +1321,6 @@ class OFAEncoder(OFAPreTrainedModel):
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # if output_hidden_states:
        #     # encoder_states.append(x)
        #     encoder_states += (x,)
        # encoder layers
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
@@ -1645,7 +1649,6 @@ class OFADecoder(OFAPreTrainedModel):
    def reorder_incremental_state_scripting(
        self,
        # incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        past_key_values: Optional[torch.Tensor],
        new_order: Tensor,
    ):
@@ -1799,15 +1802,12 @@ class OFADecoder(OFAPreTrainedModel):
            self_attn_bias = self_abs_pos_bias.clone()
            if code_masks is None or not code_masks.any():
                # print("code_masks is None or not code_masks.any()")
                self_attn_bias += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
            elif code_masks is not None and code_masks.all():
                # print("code_masks is not None and code_masks.all()")
                self_attn_bias += self.get_image_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
            else:
                # print("else")
                self_attn_bias[~code_masks] += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
                self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
@@ -1921,7 +1921,7 @@ class OFAModel(OFAPreTrainedModel):
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 新增函数以适配fairseq的generator
    # an adapter for the fairseq generator
    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
@@ -2062,7 +2062,6 @@ class OFAModel(OFAPreTrainedModel):
        return Seq2SeqLMOutput(
            logits=decoder_outputs.last_hidden_state,
            # last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
import collections
import os
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizer
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.bert.tokenization_bert import (BasicTokenizer,
                                                         WordpieceTokenizer)
from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -26,12 +33,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
    'merges_file': {
        'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt',
    },
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}
VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP_ZH = {
    'vocab_file': {
        'bert-base-chinese':
        'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
    }
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
    'ofa-base': 1024,
}
PRETRAINED_INIT_CONFIGURATION_ZH = {
    'bert-base-chinese': {
        'do_lower_case': True
    },
}
class OFATokenizer(BartTokenizer):
    """
@@ -46,3 +78,293 @@ class OFATokenizer(BartTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, 'r', encoding='utf-8') as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
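A quick illustration of the two module-level helpers above (the vocab path is a placeholder and the example tokens are arbitrary):

# load_vocab('vocab.txt') maps each line of the file to its 0-based line index, e.g.
# OrderedDict([('[PAD]', 0), ('[UNK]', 1), ...]) -- the contents depend entirely on the file.
print(whitespace_tokenize('  一个 圆头的  宝可梦 '))  # -> ['一个', '圆头的', '宝可梦']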
class OFATokenizerZH(PreTrainedTokenizer):
    r"""
    Construct an OFA tokenizer. Based on WordPiece.
    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
            <Tip>
            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.
            </Tip>
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
            <Tip>
            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.
            </Tip>
        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """
    vocab_files_names = VOCAB_FILES_NAMES_ZH
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
    def __init__(self,
                 vocab_file,
                 do_lower_case=True,
                 do_basic_tokenize=True,
                 never_split=None,
                 bos_token='<s>',
                 eos_token='</s>',
                 sep_token='</s>',
                 cls_token='<s>',
                 unk_token='<unk>',
                 pad_token='<pad>',
                 mask_token='<mask>',
                 tokenize_chinese_chars=True,
                 strip_accents=None,
                 **kwargs):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a pretrained "
                'model use `tokenizer = OFATokenizerZH.from_pretrained(PRETRAINED_MODEL_NAME)`'
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([
            (ids, tok) for tok, ids in self.vocab.items()
        ])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)
    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case
    @property
    def vocab_size(self):
        return len(self.vocab)
    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)
    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(
                    text, never_split=self.all_special_tokens):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens
    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))
    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string
    def build_inputs_with_special_tokens(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
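A worked example of the two cases above; the integer ids are placeholders, and `cls`/`sep` stand for `cls_token_id`/`sep_token_id` (which map to `<s>`/`</s>` with this class's defaults):

# tokenizer.build_inputs_with_special_tokens([5, 6])         -> [cls, 5, 6, sep]
# tokenizer.build_inputs_with_special_tokens([5, 6], [7, 8]) -> [cls, 5, 6, sep, 7, 8, sep]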
    def get_special_tokens_mask(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None,
            already_has_special_tokens: bool = False) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True)
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + (
                [0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]
    def create_token_type_ids_from_sequences(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```
        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
                                                        + sep) * [1]
    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory,
                (filename_prefix + '-' if filename_prefix else '')
                + VOCAB_FILES_NAMES_ZH['vocab_file'])
        else:
            vocab_file = (filename_prefix
                          + '-' if filename_prefix else '') + save_directory
        with open(vocab_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(
                    self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.'
                        ' Please check that the vocabulary is not corrupted!')
                    index = token_index
                writer.write(token + '\n')
                index += 1
        return (vocab_file, )
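A minimal end-to-end sketch of the new Chinese tokenizer; the checkpoint directory is a placeholder and is assumed to contain a BERT-style vocab.txt:

from modelscope.models.multi_modal.ofa import OFATokenizerZH

tokenizer = OFATokenizerZH.from_pretrained('/path/to/ofa_zh_checkpoint')  # placeholder path
encoding = tokenizer('一个圆头的蓝色宝可梦')
print(encoding['input_ids'])                                   # [cls_id, ..., sep_id]
print(tokenizer.convert_ids_to_tokens(encoding['input_ids']))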
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OFA."""
from typing import List, Optional, Tuple
import json
from tokenizers import normalizers
from transformers import PreTrainedTokenizerFast
from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast
from transformers.utils import logging
from .tokenization_ofa import OFATokenizer
from .tokenization_ofa import OFATokenizer, OFATokenizerZH
logger = logging.get_logger(__name__)
@@ -36,12 +41,37 @@ PRETRAINED_VOCAB_FILES_MAP = {
        'ofa-base':
        'https://huggingface.co/ofa-base/resolve/main/tokenizer.json',
    },
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'ofa-base': 1024,
}
VOCAB_FILES_NAMES_ZH = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP_ZH = {
    'vocab_file': {
        'bert-base-chinese':
        'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt',
    }
    # OFA models are implemented to be compatible with both huggingface
    # and modelscope frameworks. For all OFA models available on huggingface,
    # please refer to https://huggingface.co/models?filter=ofa
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = {
    'ofa-base': 1024,
}
PRETRAINED_INIT_CONFIGURATION_ZH = {
    'bert-base-chinese': {
        'do_lower_case': True
    },
}
class OFATokenizerFast(BartTokenizerFast):
    r"""
@@ -57,3 +87,128 @@ class OFATokenizerFast(BartTokenizerFast):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = OFATokenizer
class OFATokenizerZHFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" OFA tokenizer for Chinese (backed by HuggingFace's *tokenizers* library).
    [`~OFATokenizerZHFast`] runs end-to-end tokenization: punctuation splitting
    and wordpiece.
    Refer to the superclass [`PreTrainedTokenizerFast`] for usage examples and documentation concerning parameters.
    """
    vocab_files_names = VOCAB_FILES_NAMES_ZH
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH
    slow_tokenizer_class = OFATokenizerZH
    def __init__(self,
                 vocab_file=None,
                 tokenizer_file=None,
                 do_lower_case=True,
                 bos_token='<s>',
                 eos_token='</s>',
                 sep_token='</s>',
                 cls_token='<s>',
                 unk_token='<unk>',
                 pad_token='<pad>',
                 mask_token='<mask>',
                 tokenize_chinese_chars=True,
                 strip_accents=None,
                 **kwargs):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
        normalizer_state = json.loads(
            self.backend_tokenizer.normalizer.__getstate__())
        if (normalizer_state.get('lowercase', do_lower_case) != do_lower_case
                or normalizer_state.get('strip_accents', strip_accents)
                != strip_accents or normalizer_state.get(
                    'handle_chinese_chars',
                    tokenize_chinese_chars) != tokenize_chinese_chars):
            normalizer_class = getattr(normalizers,
                                       normalizer_state.pop('type'))
            normalizer_state['lowercase'] = do_lower_case
            normalizer_state['strip_accents'] = strip_accents
            normalizer_state['handle_chinese_chars'] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(
                **normalizer_state)
        self.do_lower_case = do_lower_case
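For context, a sketch of what the serialized normalizer state inspected above typically contains (the exact field set is an assumption about the `tokenizers` BertNormalizer serialization):

# normalizer_state = {'type': 'BertNormalizer', 'clean_text': True,
#                     'handle_chinese_chars': True, 'strip_accents': None, 'lowercase': True}
# Only 'lowercase', 'strip_accents' and 'handle_chinese_chars' are rewritten when they
# disagree with the constructor arguments; 'type' is popped to look up the class to rebuild.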
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        if token_ids_1:
            output += token_ids_1 + [self.sep_token_id]
        return output
    def create_token_type_ids_from_sequences(
            self,
            token_ids_0: List[int],
            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```
        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
                                                        + sep) * [1]
    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(
            save_directory, name=filename_prefix)
        return tuple(files)
@@ -16,7 +16,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_tokens
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.trie import Trie
from .ofa import OFAModel, OFATokenizer
from .ofa import OFAModel, OFATokenizer, OFATokenizerZH
from .ofa.generate import sequence_generator as sg
from .ofa.generate.utils import move_to_device
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks
@@ -41,11 +41,21 @@ class OfaForAllTasks(TorchModel):
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        self.model = model.module if hasattr(model, 'module') else model
        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.language = self.cfg.model.get('language', 'en')
        if self.language == 'en':
            self.tokenizer = OFATokenizer.from_pretrained(model_dir)
        elif self.language in ['zh', 'cn']:
            self.tokenizer = OFATokenizerZH.from_pretrained(model_dir)
        else:
            raise NotImplementedError
        # unlike the original OFA code, there is no need for
        # the `use_bpe` parameter here
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
        self.batch_size = self.cfg.model.get('batch_size', 1)
        self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
        self.max_image_size = self.cfg.model.get('max_image_size', 512)
        self.val_batch_size = self.cfg.model.get('valid_batch_size',
                                                 self.batch_size)
        self.gen_type = self.cfg.model.get('gen_type', 'generation')
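For reference, a hypothetical sketch of the `model` section of the configuration read via `Config.from_file` above; only the new `language` key changes which tokenizer is constructed, and the other keys mirror the `cfg.model.get(...)` defaults:

# {"model": {"language": "zh",            # 'en' -> OFATokenizer, 'zh'/'cn' -> OFATokenizerZH
#            "batch_size": 1,
#            "patch_image_size": 480,
#            "max_image_size": 512,
#            "gen_type": "generation"}}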
@@ -129,8 +139,8 @@ class OfaForAllTasks(TorchModel):
                               - len(self.tokenizer.get_vocab().items())
                               + self.cfg.num_bins)
        region_tensor = torch.stack(region_coord_l, dim=0)
        region_tensor = region_tensor / (
            self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512)
        region_tensor = region_tensor / (self.cfg.num_bins
                                         - 1) * self.max_image_size
        region_tensor[:, ::2] /= input['w_resize_ratios']
        region_tensor[:, 1::2] /= input['h_resize_ratios']
        return {
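A worked example of the bin-to-pixel rescaling above (the numbers are illustrative):

# With num_bins = 1000 and max_image_size = 512:
#   bin 0   -> 0   / 999 * 512 = 0.0    (left/top edge)
#   bin 499 -> 499 / 999 * 512 ≈ 255.7  (roughly the image center)
#   bin 999 -> 999 / 999 * 512 = 512.0  (right/bottom edge)
# The even (x) and odd (y) coordinates are then divided by the width/height resize
# ratios to map the predicted box back to the original image resolution.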
@@ -6,7 +6,7 @@ import json
import numpy as np
import torch
from modelscope.models.multi_modal.ofa import OFATokenizer
from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
from modelscope.utils.trie import Trie
from .utils.random_help import set_torch_seed
@@ -21,7 +21,15 @@ class OfaBasePreprocessor: | |||
            model_dir (str): model path
        """
        self.cfg = cfg
        tokenizer = OFATokenizer.from_pretrained(model_dir)
        self.language = self.cfg.model.get('language', 'en')
        if self.language == 'en':
            tokenizer = OFATokenizer.from_pretrained(model_dir)
        elif self.language in ['zh', 'cn']:
            tokenizer = OFATokenizerZH.from_pretrained(model_dir)
        else:
            raise NotImplementedError
        # unlike the original OFA code, there is no need for
        # the `use_bpe` parameter here
        tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        self.tokenizer = tokenizer
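A note on the vocabulary layout that the two `add_tokens` calls above produce (the base vocabulary size is model-dependent, so the concrete ids are an assumption):

# The 8192 <code_i> tokens are appended first, then the 1000 <bin_i> tokens, so the bin
# tokens occupy the last 1000 ids. This is what the postprocessing in OfaForAllTasks
# relies on when it converts generated ids back into bin indices via
# id - (len(tokenizer.get_vocab()) - num_bins).
# e.g. tokenizer.convert_tokens_to_ids('<bin_0>') == len(tokenizer) - 1000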
@@ -1,17 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from os import path as osp
import cv2
import numpy as np
from PIL import Image
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class OfaTasksTest(unittest.TestCase):
    def setUp(self) -> None:
        self.output_dir = 'unittest_output'
        os.makedirs(self.output_dir, exist_ok=True)
    def save_img(self, image_in, box, image_out):
        image = load_image(image_in)
        img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        cv2.rectangle(img, (int(box[0]), int(box[1])),
                      (int(box[2]), int(box[3])), (0, 255, 0), 3)
        cv2.imwrite(osp.join(self.output_dir, image_out), img)
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_image_captioning_with_model(self):
        model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
@@ -132,6 +148,9 @@ class OfaTasksTest(unittest.TestCase):
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-2]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_en_model_' + image_name + '.png'))
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_visual_grounding_with_name(self):
@@ -143,6 +162,22 @@ class OfaTasksTest(unittest.TestCase):
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-2]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_en_name_' + image_name + '.png'))
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_visual_grounding_zh_with_name(self):
        model = 'damo/ofa_visual-grounding_refcoco_large_zh'
        ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
        image = 'data/test/images/visual_grounding.png'
        text = '一个圆头的蓝色宝可梦'
        input = {'image': image, 'text': text}
        result = ofa_pipe(input)
        print(result)
        image_name = image.split('/')[-1]
        self.save_img(image, result[OutputKeys.BOXES],
                      osp.join('large_zh_name_' + image_name))
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_question_answering_with_model(self): | |||