@@ -12,6 +12,7 @@ from http.cookiejar import CookieJar
from os.path import expanduser
from typing import List, Optional, Tuple, Union

import attrs
import requests

from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
@@ -21,9 +22,14 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
                                      API_RESPONSE_FIELD_USERNAME,
                                      DEFAULT_CREDENTIALS_PATH, Licenses,
                                      ModelVisibility)
from modelscope.hub.deploy import (DeleteServiceParameters,
                                   DeployServiceParameters,
                                   GetServiceParameters, ListServiceParameters,
                                   ServiceParameters, ServiceResourceConfig,
                                   Vendor)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
                                   NotLoginException, RequestError,
                                   datahub_raise_on_error,
                                   NotLoginException, NotSupportError,
                                   RequestError, datahub_raise_on_error,
                                   handle_http_post_error,
                                   handle_http_response, is_ok, raise_on_error)
from modelscope.hub.git import GitCommandWrapper
@@ -306,6 +312,169 @@ class HubApi:
            r.raise_for_status()
        return None
    def deploy_model(self, model_id: str, revision: str, instance_name: str,
                     resource: ServiceResourceConfig,
                     provider: ServiceParameters):
        """Deploy a model to the cloud. Currently only Alibaba Cloud PAI EAS is
        supported. This call is asynchronous and may take a long time to
        complete; please check the instance status through the cloud console
        or by querying it with `get_deployed_model_instance`.

        Args:
            model_id (str): The id of the model to deploy.
            revision (str): The model revision.
            instance_name (str): The name of the deployed model instance.
            resource (ServiceResourceConfig): The requested resource configuration.
            provider (ServiceParameters): The cloud service provider parameters.

        Raises:
            NotLoginException: To use this api, you need to login first.
            NotSupportError: The target platform is not supported.
            RequestError: The server returned an error.

        Returns:
            Dict: The created instance information.
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise NotLoginException(
                'Token does not exist, please login first.')
        if provider.vendor != Vendor.EAS:
            raise NotSupportError(
                'Unsupported vendor: %s, only EAS is currently supported.' %
                provider.vendor)
        create_params = DeployServiceParameters(
            instance_name=instance_name,
            model_id=model_id,
            revision=revision,
            resource=resource,
            provider=provider)
        path = f'{self.endpoint}/api/v1/deployer/endpoint'
        body = attrs.asdict(create_params)
        r = requests.post(
            path,
            json=body,
            cookies=cookies,
        )
        handle_http_response(r, logger, cookies, 'create_eas_instance')
        if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES:
            if is_ok(r.json()):
                data = r.json()[API_RESPONSE_FIELD_DATA]
                return data
            else:
                raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            r.raise_for_status()
        return None
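    # A minimal usage sketch for deploy_model, assuming a prior
    # `api.login(...)`; the access token, keys and model id below are
    # placeholders, not values from this change:
    #
    #     from modelscope.hub.api import HubApi
    #     from modelscope.hub.deploy import (EASDeployParameters, EASRegion,
    #                                        EASCpuInstanceType,
    #                                        ServiceResourceConfig,
    #                                        ServiceScalingConfig)
    #
    #     api = HubApi()
    #     api.login('<access-token>')
    #     resource = ServiceResourceConfig(
    #         instance_type=EASCpuInstanceType.small,
    #         scaling=ServiceScalingConfig(min_replica=1, max_replica=1))
    #     provider = EASDeployParameters(
    #         region=EASRegion.hangzhou,
    #         access_key_id='<ak-id>',
    #         access_key_secret='<ak-secret>')
    #     info = api.deploy_model(
    #         model_id='<owner>/<model>',
    #         revision='v1.0.0',
    #         instance_name='my-eas-service',
    #         resource=resource,
    #         provider=provider)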
    def list_deployed_model_instances(self,
                                      provider: ServiceParameters,
                                      skip: int = 0,
                                      limit: int = 100):
        """List the deployed model instances.

        Args:
            provider (ServiceParameters): The cloud service provider parameters;
                for EAS, access_key_id and access_key_secret are required.
            skip: Start offset of the list; not supported yet.
            limit: Maximum number of instances to return; not supported yet.

        Raises:
            NotLoginException: To use this api, you need to login first.
            RequestError: The server returned an error.

        Returns:
            List: A list of instance information.
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise NotLoginException(
                'Token does not exist, please login first.')
        params = ListServiceParameters(
            provider=provider, skip=skip, limit=limit)
        path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint,
                                                   params.to_query_str())
        r = requests.get(path, cookies=cookies)
        handle_http_response(r, logger, cookies, 'list_deployed_model')
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
                data = r.json()[API_RESPONSE_FIELD_DATA]
                return data
            else:
                raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            r.raise_for_status()
        return None
    def get_deployed_model_instance(self, instance_name: str,
                                    provider: ServiceParameters):
        """Query the information of the specified instance.

        Args:
            instance_name (str): The deployed instance name.
            provider (ServiceParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret
                are required.

        Raises:
            NotLoginException: To use this api, you need to login first.
            RequestError: The server returned an error.

        Returns:
            Dict: The requested instance information.
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise NotLoginException(
                'Token does not exist, please login first.')
        params = GetServiceParameters(provider=provider)
        path = '%s/api/v1/deployer/endpoint/%s?%s' % (
            self.endpoint, instance_name, params.to_query_str())
        r = requests.get(path, cookies=cookies)
        handle_http_response(r, logger, cookies, 'get_deployed_model')
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
                data = r.json()[API_RESPONSE_FIELD_DATA]
                return data
            else:
                raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            r.raise_for_status()
        return None
    def delete_deployed_model_instance(self, instance_name: str,
                                       provider: ServiceParameters):
        """Delete a deployed model instance. This api only sends the delete
        command and returns; deletion takes some time, so please check the
        result through the cloud console.

        Args:
            instance_name (str): The name of the instance to delete.
            provider (ServiceParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret
                are required.

        Raises:
            NotLoginException: To use this api, you need to login first.
            RequestError: The server returned an error.

        Returns:
            Dict: The deleted instance information.
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise NotLoginException(
                'Token does not exist, please login first.')
        params = DeleteServiceParameters(provider=provider)
        path = '%s/api/v1/deployer/endpoint/%s?%s' % (
            self.endpoint, instance_name, params.to_query_str())
        r = requests.delete(path, cookies=cookies)
        handle_http_response(r, logger, cookies, 'delete_deployed_model')
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
                data = r.json()[API_RESPONSE_FIELD_DATA]
                return data
            else:
                raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            r.raise_for_status()
        return None
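    # A hedged sketch of the query/delete lifecycle above; keys and the
    # service name are placeholders, and `api` is a logged-in HubApi:
    #
    #     provider = EASListParameters(
    #         access_key_id='<ak-id>',
    #         access_key_secret='<ak-secret>',
    #         region=EASRegion.hangzhou)
    #     services = api.list_deployed_model_instances(provider=provider)
    #     info = api.get_deployed_model_instance('my-eas-service', provider)
    #     api.delete_deployed_model_instance('my-eas-service', provider)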
    def _check_cookie(self,
                      use_cookies: Union[bool,
                                         CookieJar] = False) -> CookieJar:
@@ -0,0 +1,189 @@
import json
import urllib.parse
from abc import ABC, abstractmethod
from typing import Optional, Union

from attrs import asdict, define, field, validators
class Accelerator(object):
    CPU = 'cpu'
    GPU = 'gpu'


class Vendor(object):
    EAS = 'eas'


class EASRegion(object):
    beijing = 'cn-beijing'
    hangzhou = 'cn-hangzhou'


class EASCpuInstanceType(object):
    """EAS CPU instance types, ref(https://help.aliyun.com/document_detail/144261.html)
    """
    tiny = 'ecs.c6.2xlarge'
    small = 'ecs.c6.4xlarge'
    medium = 'ecs.c6.6xlarge'
    large = 'ecs.c6.8xlarge'


class EASGpuInstanceType(object):
    """EAS GPU instance types, ref(https://help.aliyun.com/document_detail/144261.html)
    """
    tiny = 'ecs.gn5-c28g1.7xlarge'
    small = 'ecs.gn5-c8g1.4xlarge'
    medium = 'ecs.gn6i-c24g1.12xlarge'
    large = 'ecs.gn6e-c12g1.3xlarge'
def min_smaller_than_max(instance, attribute, value):
    """attrs validator ensuring min_replica does not exceed max_replica."""
    if value > instance.max_replica:
        raise ValueError(
            "'min_replica' value: %s must not be greater than 'max_replica' value: %s!"
            % (value, instance.max_replica))
@define
class ServiceScalingConfig(object):
    """Service scaling config.

    Currently max_replica is ignored.

    Args:
        max_replica: Maximum number of replicas.
        min_replica: Minimum number of replicas.
    """
    max_replica: int = field(default=1, validator=validators.ge(1))
    min_replica: int = field(
        default=1, validator=[validators.ge(1), min_smaller_than_max])


@define
class ServiceResourceConfig(object):
    """EAS resource request.

    Args:
        accelerator: The accelerator (cpu|gpu).
        instance_type: The instance type.
        scaling: The instance scaling config.
    """
    instance_type: str
    scaling: ServiceScalingConfig
    accelerator: str = field(
        default=Accelerator.CPU,
        validator=validators.in_([Accelerator.CPU, Accelerator.GPU]))
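# A hedged construction example for the two configs above; the attrs
# validators reject min_replica > max_replica and unknown accelerator values:
#
#     gpu_resource = ServiceResourceConfig(
#         instance_type=EASGpuInstanceType.small,
#         scaling=ServiceScalingConfig(min_replica=1, max_replica=2),
#         accelerator=Accelerator.GPU)
#     ServiceScalingConfig(min_replica=3, max_replica=2)  # raises ValueError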
@define
class ServiceParameters(ABC):
    pass


@define
class EASDeployParameters(ServiceParameters):
    """Parameters for an EAS deployment.

    Args:
        resource_group: The resource group to deploy to; currently the default group is used.
        region: The EAS instance region (e.g. cn-hangzhou).
        access_key_id: The EAS account access key id.
        access_key_secret: The EAS account access key secret.
        vendor: Must be 'eas'.
    """
    region: str
    access_key_id: str
    access_key_secret: str
    resource_group: Optional[str] = None
    vendor: str = field(
        default=Vendor.EAS, validator=validators.in_([Vendor.EAS]))
""" | |||
def __init__(self, | |||
instance_name: str, | |||
access_key_id: str, | |||
access_key_secret: str, | |||
region = EASRegion.beijing, | |||
instance_type: str = EASCpuInstances.small, | |||
accelerator: str = Accelerator.CPU, | |||
resource_group: Optional[str] = None, | |||
scaling: Optional[str] = None): | |||
self.instance_name=instance_name | |||
self.access_key_id=self.access_key_id | |||
self.access_key_secret = access_key_secret | |||
self.region = region | |||
self.instance_type = instance_type | |||
self.accelerator = accelerator | |||
self.resource_group = resource_group | |||
self.scaling = scaling | |||
""" | |||
@define
class EASListParameters(ServiceParameters):
    """EAS instance list parameters.

    Args:
        resource_group: The resource group to query; currently the default group is used.
        region: The EAS instance region (e.g. cn-hangzhou).
        access_key_id: The EAS account access key id.
        access_key_secret: The EAS account access key secret.
        vendor: Must be 'eas'.
    """
    access_key_id: str
    access_key_secret: str
    region: Optional[str] = None
    resource_group: Optional[str] = None
    vendor: str = field(
        default=Vendor.EAS, validator=validators.in_([Vendor.EAS]))
@define
class DeployServiceParameters(object):
    """Deploy service parameters.

    Args:
        instance_name: The name of the service.
        model_id: The ModelScope model id.
        revision: The ModelScope model revision.
        resource: The resource requirements.
        provider: The cloud service provider.
    """
    instance_name: str
    model_id: str
    revision: str
    resource: ServiceResourceConfig
    provider: ServiceParameters
class AttrsToQueryString(ABC):
    """Base class that serializes the `provider` attrs field of a subclass
    into a URL-encoded JSON query parameter.
    """

    def to_query_str(self):
        self_dict = asdict(
            self.provider, filter=lambda attr, value: value is not None)
        json_str = json.dumps(self_dict)
        safe_str = urllib.parse.quote_plus(json_str)
        query_param = 'provider=%s' % safe_str
        return query_param
@define
class ListServiceParameters(AttrsToQueryString):
    provider: ServiceParameters
    skip: int = 0
    limit: int = 100


@define
class GetServiceParameters(AttrsToQueryString):
    provider: ServiceParameters


@define
class DeleteServiceParameters(AttrsToQueryString):
    provider: ServiceParameters
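# A hedged illustration of to_query_str: None-valued provider fields are
# filtered out before JSON encoding, then the JSON is percent-encoded,
# yielding roughly:
#
#     params = GetServiceParameters(provider=EASListParameters(
#         access_key_id='id', access_key_secret='secret'))
#     params.to_query_str()
#     # -> 'provider=%7B%22access_key_id%22%3A+%22id%22%2C+...%22vendor%22%3A+%22eas%22%7D'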
@@ -9,6 +9,10 @@ from modelscope.utils.logger import get_logger
logger = get_logger()


class NotSupportError(Exception):
    pass


class NotExistError(Exception):
    pass


@@ -66,6 +70,7 @@ def handle_http_response(response, logger, cookies, model_id):
        logger.error(
            f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
private. Please login first.')
        logger.error('Response details: %s' % response.content)
    raise error
@@ -67,8 +67,9 @@ class Models(object):
    space_dst = 'space-dst'
    space_intent = 'space-intent'
    space_modeling = 'space-modeling'
    star = 'star'
    star3 = 'star3'
    space_T_en = 'space-T-en'
    space_T_cn = 'space-T-cn'
    tcrf = 'transformer-crf'
    transformer_softmax = 'transformer-softmax'
    lcrf = 'lstm-crf'
@@ -16,6 +16,7 @@ from modelscope.models.builder import MODELS
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .utils import timestamp_format
from .yolox.data.data_augment import ValTransform
from .yolox.exp import get_exp_by_name
from .yolox.utils import postprocess
@@ -99,14 +100,17 @@ class RealtimeVideoDetector(TorchModel):
    def inference_video(self, v_path):
        outputs = []
        desc = 'Detecting video: {}'.format(v_path)
        for frame, result in tqdm(
                self.inference_video_iter(v_path), desc=desc):
        for frame_idx, (frame, result) in enumerate(
                tqdm(self.inference_video_iter(v_path), desc=desc)):
            result = result + (timestamp_format(seconds=frame_idx
                                                / self.fps), )
            outputs.append(result)
        return outputs

    def inference_video_iter(self, v_path):
        capture = cv2.VideoCapture(v_path)
        self.fps = capture.get(cv2.CAP_PROP_FPS)
        while capture.isOpened():
            ret, frame = capture.read()
            if not ret:
@@ -0,0 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.


def timestamp_format(seconds):
    """Format a duration in seconds as an HH:MM:SS.mmm timestamp string."""
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    time = '%02d:%02d:%06.3f' % (h, m, s)
    return time
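# Hedged examples of the format above:
#
#     timestamp_format(seconds=3.5)      # -> '00:00:03.500'
#     timestamp_format(seconds=3661.25)  # -> '01:01:01.250'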
@@ -24,8 +24,8 @@ import json
logger = logging.getLogger(__name__)


class Star3Config(object):
    """Configuration class to store the configuration of a `Star3Model`.
class SpaceTCnConfig(object):
    """Configuration class to store the configuration of a `SpaceTCnModel`.
    """

    def __init__(self,
@@ -40,10 +40,10 @@ class Star3Config(object):
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs Star3Config.
        """Constructs SpaceTCnConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`.
            vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `SpaceTCnModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
@@ -59,7 +59,7 @@ class Star3Config(object):
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`.
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
@@ -89,15 +89,15 @@ class Star3Config(object):
    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `Star3Config` from a Python dictionary of parameters."""
        config = Star3Config(vocab_size_or_config_json_file=-1)
        """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters."""
        config = SpaceTCnConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `Star3Config` from a json file of parameters."""
        """Constructs a `SpaceTCnConfig` from a json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))
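    # A hedged usage sketch for the two constructors above; the dict keys
    # mirror a typical BERT config and are illustrative only:
    #
    #     config = SpaceTCnConfig.from_dict({
    #         'vocab_size': 21128,
    #         'hidden_size': 768,
    #         'num_hidden_layers': 12,
    #         'num_attention_heads': 12,
    #     })
    #     config = SpaceTCnConfig.from_json_file('/path/to/config.json')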
@@ -27,7 +27,8 @@ import numpy as np
import torch
from torch import nn

from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \
    SpaceTCnConfig
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
@@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module):
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, Star3Config):
        if not isinstance(config, SpaceTCnConfig):
            raise ValueError(
                'Parameter config in `{}(config)` should be an instance of class `Star3Config`. '
                'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. '
                'To create a model from a Google pretrained model use '
                '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
                    self.__class__.__name__, self.__class__.__name__))
@@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module):
        serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = Star3Config.from_json_file(config_file)
        config = SpaceTCnConfig.from_json_file(config_file)
        logger.info('Model config {}'.format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
@@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module):
        return model
class Star3Model(PreTrainedBertModel):
    """Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0").
class SpaceTCnModel(PreTrainedBertModel):
    """SpaceTCnModel ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN").

    Params:
        config: a Star3Config class instance with the configuration to build a new model
        config: a SpaceTCnConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel):
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768,
    config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.Star3Model(config=config)
    model = modeling.SpaceTCnModel(config=config)

    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, schema_link_module='none'):
        super(Star3Model, self).__init__(config)
        super(SpaceTCnModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(
            config, schema_link_module=schema_link_module)
@@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql']

@MODELS.register_module(
    Tasks.conversational_text_to_sql, module_name=Models.star)
    Tasks.table_question_answering, module_name=Models.space_T_en)
class StarForTextToSql(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
@@ -3,27 +3,25 @@
import os
from typing import Dict

import json
import numpy
import torch
import torch.nn.functional as F
import tqdm
from transformers import BertTokenizer

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model
from modelscope.preprocessors.star3.fields.struct import Constant
from modelscope.preprocessors.space_T_cn.fields.struct import Constant
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import verify_device
from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig
from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel

__all__ = ['TableQuestionAnswering']

@MODELS.register_module(
    Tasks.table_question_answering, module_name=Models.star3)
    Tasks.table_question_answering, module_name=Models.space_T_cn)
class TableQuestionAnswering(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
@@ -43,9 +41,9 @@ class TableQuestionAnswering(Model):
            os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location='cpu')

        self.backbone_config = Star3Config.from_json_file(
        self.backbone_config = SpaceTCnConfig.from_json_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        self.backbone_model = Star3Model(
        self.backbone_model = SpaceTCnModel(
            config=self.backbone_config, schema_link_module='rat')
        self.backbone_model.load_state_dict(state_dict['backbone_model'])
@@ -606,21 +606,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],

    # conversational text-to-sql result for single sample
    # {
    #     "text": "SELECT shop.Name FROM shop."
    # }
    Tasks.conversational_text_to_sql: [OutputKeys.TEXT],

    # table-question-answering result for single sample
    # {
    #     "sql": "SELECT shop.Name FROM shop."
    #     "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]}
    # }
    Tasks.table_question_answering: [
        OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY,
        OutputKeys.QUERT_RESULT
    ],
    Tasks.table_question_answering: [OutputKeys.OUTPUT],

    # ============ audio tasks ===================

    # asr result for single sample
@@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = {
        'damo/nlp_space_dialog-modeling'),
    Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                  'damo/nlp_space_dialog-state-tracking'),
    Tasks.conversational_text_to_sql:
    (Pipelines.conversational_text_to_sql,
     'damo/nlp_star_conversational-text-to-sql'),
    Tasks.table_question_answering:
    (Pipelines.table_question_answering_pipeline,
     'damo/nlp-convai-text2sql-pretrain-cn'),
@@ -113,9 +113,8 @@ class AnimalRecognitionPipeline(Pipeline):
            label_mapping = f.readlines()
        score = torch.max(inputs['outputs'])
        inputs = {
            OutputKeys.SCORES:
            score.item(),
            OutputKeys.SCORES: [score.item()],
            OutputKeys.LABELS:
            label_mapping[inputs['outputs'].argmax()].split('\t')[1]
            [label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
        }
        return inputs
@@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline):
            label_mapping = f.readlines()
        score = torch.max(inputs['outputs'])
        inputs = {
            OutputKeys.SCORES:
            score.item(),
            OutputKeys.SCORES: [score.item()],
            OutputKeys.LABELS:
            label_mapping[inputs['outputs'].argmax()].split('\t')[1]
            [label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
        }
        return inputs
@@ -45,15 +45,17 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline):
                    **kwargs) -> str:
        forward_output = input['forward_output']

        scores, boxes, labels = [], [], []
        scores, boxes, labels, timestamps = [], [], [], []
        for result in forward_output:
            box, score, label = result
            box, score, label, timestamp = result
            scores.append(score)
            boxes.append(box)
            labels.append(label)
            timestamps.append(timestamp)

        return {
            OutputKeys.BOXES: boxes,
            OutputKeys.SCORES: scores,
            OutputKeys.LABELS: labels,
            OutputKeys.TIMESTAMPS: timestamps,
        }
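        # A hedged consumption sketch: each OutputKeys.TIMESTAMPS entry lines
        # up with the corresponding box/score/label, so a caller could do
        # (`pipe` is a hypothetical pipeline instance, not from this change):
        #
        #     result = pipe('video.mp4')
        #     for ts, label in zip(result[OutputKeys.TIMESTAMPS],
        #                          result[OutputKeys.LABELS]):
        #         print(ts, label)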
@@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline']

@PIPELINES.register_module(
    Tasks.conversational_text_to_sql,
    Tasks.table_question_answering,
    module_name=Pipelines.conversational_text_to_sql)
class ConversationalTextToSqlPipeline(Pipeline):
@@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline):
            Dict[str, str]: the prediction results
        """
        sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
        result = {OutputKeys.TEXT: sql}
        result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}}
        return result

    def _collate_fn(self, data):
@@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
                                                               SQLQuery)
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['TableQuestionAnsweringPipeline']
@@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
            OutputKeys.QUERT_RESULT: tabledata,
        }
        return output
        return {OutputKeys.OUTPUT: output}

    def _collate_fn(self, data):
        return data
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
                        DialogStateTrackingPreprocessor)
    from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
    from .star import ConversationalTextToSqlPreprocessor
    from .star3 import TableQuestionAnsweringPreprocessor
    from .space_T_cn import TableQuestionAnsweringPreprocessor
else:
    _import_structure = {
@@ -81,7 +81,7 @@ else:
            'DialogStateTrackingPreprocessor', 'InputFeatures'
        ],
        'star': ['ConversationalTextToSqlPreprocessor'],
        'star3': ['TableQuestionAnsweringPreprocessor'],
        'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
    }

    import sys
@@ -4,7 +4,7 @@ import sqlite3

import json
import tqdm

from modelscope.preprocessors.star3.fields.struct import Trie
from modelscope.preprocessors.space_T_cn.fields.struct import Trie


class Database:
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import re

from modelscope.preprocessors.star3.fields.struct import TypeInfo
from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo


class SchemaLinker:
@@ -8,8 +8,8 @@ from transformers import BertTokenizer
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert
@@ -123,7 +123,6 @@ class NLPTasks(object):
    backbone = 'backbone'
    text_error_correction = 'text-error-correction'
    faq_question_answering = 'faq-question-answering'
    conversational_text_to_sql = 'conversational-text-to-sql'
    information_extraction = 'information-extraction'
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'
@@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results(
        results = p(case)
        print({'question': item})
        print(results)
        last_sql = results['text']
        last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT]
        history.append(item)
@@ -1,4 +1,5 @@
addict
attrs
datasets
easydict
einops
@@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level
class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.conversational_text_to_sql
        self.task = Tasks.table_question_answering
        self.model_id = 'damo/nlp_star_conversational-text-to-sql'

    model_id = 'damo/nlp_star_conversational-text-to-sql'
@@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):
        pipelines = [pipeline(task=self.task, model=self.model_id)]
        text2sql_tracking_and_print_results(self.test_case, pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipelines = [pipeline(task=self.task)]
        text2sql_tracking_and_print_results(self.test_case, pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()
@@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level
@@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history(
            output_dict = p({
                'question': question,
                'history_sql': historical_queries
            })
            })[OutputKeys.OUTPUT]
            print('question', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history(
    }
    for p in pipelines:
        for question in test_case['utterance']:
            output_dict = p({'question': question})
            output_dict = p({'question': question})[OutputKeys.OUTPUT]
            print('question', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid(
                'question': question,
                'table_id': table_id,
                'history_sql': historical_queries
            })
            })[OutputKeys.OUTPUT]
            print('question', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase):
        ]
        tableqa_tracking_and_print_results_with_tableid(pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_task(self):
        pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
        tableqa_tracking_and_print_results_with_history(pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub_with_other_classes(self):
        model = Model.from_pretrained(self.model_id)