@@ -1,26 +1,32 @@ | |||
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
pip install -r requirements/tests.txt | |||
echo "Testing envs" | |||
printenv | |||
echo "ENV END" | |||
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html | |||
pip install -r requirements/tests.txt | |||
git config --global --add safe.directory /Maas-lib | |||
git config --global user.email tmp | |||
git config --global user.name tmp.com | |||
git config --global --add safe.directory /Maas-lib | |||
git config --global user.email tmp | |||
git config --global user.name tmp.com | |||
# linter test | |||
# use internal project for pre-commit due to the network problem | |||
if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then | |||
pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
fi | |||
if [ $? -ne 0 ]; then | |||
echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
exit -1 | |||
# linter test | |||
# use internal project for pre-commit due to the network problem | |||
if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then | |||
pre-commit run -c .pre-commit-config_local.yaml --all-files | |||
if [ $? -ne 0 ]; then | |||
echo "linter test failed, please run 'pre-commit run --all-files' to check" | |||
        exit 1
fi | |||
fi | |||
# test with install | |||
python setup.py install | |||
else | |||
echo "Running case in release image, run case directly!" | |||
fi | |||
# test with install | |||
python setup.py install | |||
if [ $# -eq 0 ]; then | |||
ci_command="python tests/run.py --subprocess" | |||
else | |||
@@ -20,28 +20,52 @@ do | |||
# pull image if there are update | |||
docker pull ${IMAGE_NAME}:${IMAGE_VERSION} | |||
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
--cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
--gpus="device=$gpu" \ | |||
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
-v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
-e CI_TEST=True \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
-e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
--workdir=$CODE_DIR_IN_CONTAINER \ | |||
--net host \ | |||
${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
$CI_COMMAND | |||
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then | |||
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
--cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
--gpus="device=$gpu" \ | |||
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
-v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
-e CI_TEST=True \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
-e MODELSCOPE_SDK_DEBUG=True \ | |||
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
-e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
--workdir=$CODE_DIR_IN_CONTAINER \ | |||
--net host \ | |||
${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
$CI_COMMAND | |||
else | |||
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ | |||
--cpuset-cpus=${cpu_sets_arr[$gpu]} \ | |||
--gpus="device=$gpu" \ | |||
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-v $MODELSCOPE_HOME_CACHE/$gpu:/root \ | |||
-v /home/admin/pre-commit:/home/admin/pre-commit \ | |||
-e CI_TEST=True \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ | |||
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ | |||
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ | |||
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ | |||
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ | |||
-e TEST_LEVEL=$TEST_LEVEL \ | |||
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ | |||
-e MODEL_TAG_URL=$MODEL_TAG_URL \ | |||
--workdir=$CODE_DIR_IN_CONTAINER \ | |||
--net host \ | |||
${IMAGE_NAME}:${IMAGE_VERSION} \ | |||
$CI_COMMAND | |||
fi | |||
if [ $? -ne 0 ]; then | |||
echo "Running test case failed, please check the log!" | |||
    exit 1
@@ -1,6 +1,6 @@ | |||
repos: | |||
- repo: https://gitlab.com/pycqa/flake8.git | |||
rev: 3.8.3 | |||
rev: 4.0.0 | |||
hooks: | |||
- id: flake8 | |||
exclude: thirdparty/|examples/ | |||
@@ -1,6 +1,6 @@ | |||
repos: | |||
- repo: /home/admin/pre-commit/flake8 | |||
rev: 3.8.3 | |||
rev: 4.0.0 | |||
hooks: | |||
- id: flake8 | |||
exclude: thirdparty/|examples/ | |||
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f | |||
size 38954 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 | |||
size 93676 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb | |||
size 472012 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b | |||
size 71244 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 | |||
size 181964 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 | |||
size 112078 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 | |||
size 200556 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b | |||
size 162636 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 | |||
size 160204 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||
size 2327764 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||
size 68524 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 | |||
size 53219 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c | |||
size 254845 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae | |||
size 48909 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d | |||
size 1127557 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 | |||
size 2747938 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 | |||
size 27616 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747 | |||
size 959 |
@@ -1,3 +0,0 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d | |||
size 1702339 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd | |||
size 167407 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280 | |||
size 119940 | |||
oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96 | |||
size 121855 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705 | |||
size 119619 | |||
oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71 | |||
size 121563 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a | |||
size 119619 | |||
oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536 | |||
size 121563 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2 | |||
size 152926 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 | |||
size 62231 | |||
oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399 | |||
size 63161 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a | |||
size 62235 | |||
oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2 | |||
size 63165 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572 | |||
size 60801 | |||
oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f | |||
size 63597 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c | |||
size 60801 | |||
oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030 | |||
size 63349 |
@@ -1,3 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85 | |||
size 61589 | |||
oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5 | |||
size 62519 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb | |||
size 2924691 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:f58df1d25590c158ae0a04b3999bd44b610cdaddb17d78afd84c34b3f00d4e87 | |||
size 4068783 |
@@ -76,7 +76,7 @@ RUN pip install --no-cache-dir --upgrade pip && \ | |||
ENV SHELL=/bin/bash | |||
# install special package | |||
RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq | |||
RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq fasttext https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl | |||
RUN if [ "$USE_GPU" = "True" ] ; then \ | |||
pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \ | |||
@@ -1,4 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .version import __version__ | |||
from .version import __release_datetime__, __version__ | |||
__all__ = ['__version__'] | |||
__all__ = ['__version__', '__release_datetime__'] |
@@ -19,10 +19,13 @@ class Exporter(ABC): | |||
def from_model(cls, model: Model, **kwargs): | |||
"""Build the Exporter instance. | |||
@param model: A model instance. it will be used to output the generated file, | |||
Args: | |||
            model: A Model instance. It will be used to generate the intermediate format file,
                and the configuration.json in its model_dir field will be used to create the exporter instance.
@param kwargs: Extra kwargs used to create the Exporter instance. | |||
@return: The Exporter instance | |||
kwargs: Extra kwargs used to create the Exporter instance. | |||
Returns: | |||
The Exporter instance | |||
""" | |||
cfg = Config.from_file( | |||
os.path.join(model.model_dir, ModelFile.CONFIGURATION)) | |||
@@ -44,10 +47,13 @@ class Exporter(ABC): | |||
        In some cases several files may be generated,
        so please return a dict mapping each generated name to its file path.
@param opset: The version of the ONNX operator set to use. | |||
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
@return: A dict contains the model name with the model file path. | |||
Args: | |||
opset: The version of the ONNX operator set to use. | |||
outputs: The output dir. | |||
kwargs: In this default implementation, | |||
kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). | |||
Returns: | |||
            A dict mapping the model name to the model file path.
""" | |||
pass |
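# Illustrative sketch (not part of the diff): typical use of the Exporter API
# documented above. The model id and output dir are placeholders, and the
# import paths are assumptions for illustration.
from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained('damo/some-model')  # placeholder model id
exporter = Exporter.from_model(model)
# export_onnx returns a dict mapping generated names to file paths
files = exporter.export_onnx('/tmp/onnx_out', opset=13)
print(files)  # e.g. {'model': '/tmp/onnx_out/model.onnx'}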
@@ -23,13 +23,18 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
def generate_dummy_inputs(self, | |||
shape: Tuple = None, | |||
pair: bool = False, | |||
**kwargs) -> Dict[str, Any]: | |||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
@param shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. | |||
@return: Dummy inputs. | |||
Args: | |||
shape: A tuple of input shape which should have at most two dimensions. | |||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. | |||
                shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor.
pair(bool, `optional`): Whether to generate sentence pairs or single sentences. | |||
Returns: | |||
Dummy inputs. | |||
""" | |||
cfg = Config.from_file( | |||
@@ -55,7 +60,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): | |||
**sequence_length | |||
}) | |||
preprocessor: Preprocessor = build_preprocessor(cfg, field_name) | |||
if preprocessor.pair: | |||
if pair: | |||
first_sequence = preprocessor.tokenizer.unk_token | |||
second_sequence = preprocessor.tokenizer.unk_token | |||
else: | |||
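# Illustrative sketch (not part of the diff): generating dummy inputs with an
# explicit batch size and sequence length, and with sentence pairs enabled.
# Model construction is assumed; the model variable is a placeholder.
exporter = SbertForSequenceClassificationExporter.from_model(model)
dummy = exporter.generate_dummy_inputs(shape=(8, 128), pair=True)
# dummy holds tokenized placeholder inputs (e.g. input_ids) shaped (8, 128)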
@@ -13,8 +13,8 @@ from modelscope.models import TorchModel | |||
from modelscope.pipelines.base import collate_fn | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.regress_test_utils import compare_arguments_nested | |||
from modelscope.utils.tensor_utils import torch_nested_numpify | |||
from modelscope.utils.regress_test_utils import (compare_arguments_nested, | |||
numpify_tensor_nested) | |||
from .base import Exporter | |||
logger = get_logger(__name__) | |||
@@ -28,49 +28,61 @@ class TorchModelExporter(Exporter): | |||
and to provide implementations for generate_dummy_inputs/inputs/outputs methods. | |||
""" | |||
def export_onnx(self, outputs: str, opset=11, **kwargs): | |||
def export_onnx(self, output_dir: str, opset=13, **kwargs): | |||
"""Export the model as onnx format files. | |||
        In some cases several files may be generated,
        so please return a dict mapping each generated name to its file path.
@param opset: The version of the ONNX operator set to use. | |||
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
you can pass the arguments needed by _torch_export_onnx, other unrecognized args | |||
will be carried to generate_dummy_inputs as extra arguments (such as input shape). | |||
@return: A dict containing the model key - model file path pairs. | |||
Args: | |||
opset: The version of the ONNX operator set to use. | |||
output_dir: The output dir. | |||
kwargs: | |||
                model: A model instance that will replace self.model during export.
                In this default implementation,
                you can pass the arguments needed by _torch_export_onnx; other unrecognized args
                will be carried to generate_dummy_inputs as extra arguments (such as input shape).
        Returns:
            A dict mapping model keys to model file paths.
""" | |||
model = self.model | |||
model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
model = model.model | |||
onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE) | |||
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) | |||
self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) | |||
return {'model': onnx_file} | |||
def export_torch_script(self, outputs: str, **kwargs): | |||
def export_torch_script(self, output_dir: str, **kwargs): | |||
"""Export the model as torch script files. | |||
        In some cases several files may be generated,
        so please return a dict mapping each generated name to its file path.
@param outputs: The output dir. | |||
@param kwargs: In this default implementation, | |||
Args: | |||
output_dir: The output dir. | |||
kwargs: | |||
                model: A model instance that will replace self.model during export.
                In this default implementation,
                you can pass the arguments needed by _torch_export_torch_script; other unrecognized args
                will be carried to generate_dummy_inputs as extra arguments (like input shape).
@return: A dict contains the model name with the model file path. | |||
        Returns:
            A dict mapping the model name to the model file path.
""" | |||
model = self.model | |||
model = self.model if 'model' not in kwargs else kwargs.pop('model') | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
model = model.model | |||
ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE) | |||
ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE) | |||
# generate ts by tracing | |||
self._torch_export_torch_script(model, ts_file, **kwargs) | |||
return {'model': ts_file} | |||
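# Illustrative sketch (not part of the diff): exporting a TorchScript file while
# substituting the exported module via the new 'model' kwarg. The replacement
# module and the output dir are placeholders.
ts_files = exporter.export_torch_script(
    output_dir='/tmp/ts_out', model=replacement_module)  # replacement_module is hypothetical
# -> {'model': '/tmp/ts_out/<TS_MODEL_FILE>'}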
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: | |||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing. | |||
@return: Dummy inputs. | |||
Returns: | |||
Dummy inputs. | |||
""" | |||
return None | |||
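# Illustrative sketch (not part of the diff): a subclass would override the
# default above to return the tensors used for tracing, for example:
#
#     def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
#         # a single placeholder sequence of length 128
#         return {'input_ids': torch.ones(1, 128, dtype=torch.long)}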
@@ -93,7 +105,7 @@ class TorchModelExporter(Exporter): | |||
def _torch_export_onnx(self, | |||
model: nn.Module, | |||
output: str, | |||
opset: int = 11, | |||
opset: int = 13, | |||
device: str = 'cpu', | |||
validation: bool = True, | |||
rtol: float = None, | |||
@@ -101,18 +113,27 @@ class TorchModelExporter(Exporter): | |||
**kwargs): | |||
"""Export the model to an onnx format file. | |||
@param model: A torch.nn.Module instance to export. | |||
@param output: The output file. | |||
@param opset: The version of the ONNX operator set to use. | |||
@param device: The device used to forward. | |||
@param validation: Whether validate the export file. | |||
@param rtol: The rtol used to regress the outputs. | |||
@param atol: The atol used to regress the outputs. | |||
        Args:
            model: A torch.nn.Module instance to export.
            output: The output file.
            opset: The version of the ONNX operator set to use.
            device: The device used to forward.
            validation: Whether to validate the exported file.
            rtol: The rtol used to regress the outputs.
            atol: The atol used to regress the outputs.
            kwargs:
                dummy_inputs: Dummy inputs that replace the result of self.generate_dummy_inputs().
                inputs: An inputs structure that replaces self.inputs.
                outputs: An outputs structure that replaces self.outputs.
""" | |||
dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
inputs = self.inputs | |||
outputs = self.outputs | |||
        if 'dummy_inputs' in kwargs:
            dummy_inputs = kwargs.pop('dummy_inputs')
        else:
            dummy_inputs = self.generate_dummy_inputs(**kwargs)
        inputs = kwargs.pop('inputs') if 'inputs' in kwargs else self.inputs
        outputs = kwargs.pop('outputs') if 'outputs' in kwargs else self.outputs
if dummy_inputs is None or inputs is None or outputs is None: | |||
raise NotImplementedError( | |||
'Model property dummy_inputs,inputs,outputs must be set.') | |||
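# Illustrative sketch (not part of the diff): bypassing the model-derived
# defaults with the overrides introduced above. The tensor names, shapes and
# the dynamic-axes structure are placeholders.
import torch
from collections import OrderedDict

exporter.export_onnx(
    '/tmp/onnx_out',
    dummy_inputs={'input_ids': torch.ones(1, 128, dtype=torch.long)},
    inputs=OrderedDict({'input_ids': {0: 'batch', 1: 'sequence'}}),
    outputs=OrderedDict({'logits': {0: 'batch'}}))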
@@ -125,7 +146,7 @@ class TorchModelExporter(Exporter): | |||
if isinstance(dummy_inputs, Mapping): | |||
dummy_inputs = dict(dummy_inputs) | |||
onnx_outputs = list(self.outputs.keys()) | |||
onnx_outputs = list(outputs.keys()) | |||
with replace_call(): | |||
onnx_export( | |||
@@ -160,11 +181,13 @@ class TorchModelExporter(Exporter): | |||
outputs_origin = model.forward( | |||
*_decide_input_format(model, dummy_inputs)) | |||
if isinstance(outputs_origin, Mapping): | |||
outputs_origin = torch_nested_numpify( | |||
outputs_origin = numpify_tensor_nested( | |||
list(outputs_origin.values())) | |||
elif isinstance(outputs_origin, (tuple, list)): | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
outputs = ort_session.run( | |||
onnx_outputs, | |||
torch_nested_numpify(dummy_inputs), | |||
numpify_tensor_nested(dummy_inputs), | |||
) | |||
tols = {} | |||
@@ -184,19 +207,26 @@ class TorchModelExporter(Exporter): | |||
validation: bool = True, | |||
rtol: float = None, | |||
atol: float = None, | |||
strict: bool = True, | |||
**kwargs): | |||
"""Export the model to a torch script file. | |||
@param model: A torch.nn.Module instance to export. | |||
@param output: The output file. | |||
@param device: The device used to forward. | |||
@param validation: Whether validate the export file. | |||
@param rtol: The rtol used to regress the outputs. | |||
@param atol: The atol used to regress the outputs. | |||
Args: | |||
model: A torch.nn.Module instance to export. | |||
output: The output file. | |||
device: The device used to forward. | |||
            validation: Whether to validate the exported file.
            rtol: The rtol used to regress the outputs.
            atol: The atol used to regress the outputs.
            strict: Whether to use strict mode in torch script tracing.
            kwargs:
                dummy_inputs: Dummy inputs that replace the result of self.generate_dummy_inputs().
""" | |||
model.eval() | |||
dummy_inputs = self.generate_dummy_inputs(**kwargs) | |||
        if 'dummy_inputs' in kwargs:
            dummy_inputs = kwargs.pop('dummy_inputs')
        else:
            dummy_inputs = self.generate_dummy_inputs(**kwargs)
if dummy_inputs is None: | |||
raise NotImplementedError( | |||
'Model property dummy_inputs must be set.') | |||
@@ -207,7 +237,7 @@ class TorchModelExporter(Exporter): | |||
model.eval() | |||
with replace_call(): | |||
traced_model = torch.jit.trace( | |||
model, dummy_inputs, strict=False) | |||
model, dummy_inputs, strict=strict) | |||
torch.jit.save(traced_model, output) | |||
if validation: | |||
@@ -216,9 +246,9 @@ class TorchModelExporter(Exporter): | |||
model.eval() | |||
ts_model.eval() | |||
outputs = ts_model.forward(*dummy_inputs) | |||
outputs = torch_nested_numpify(outputs) | |||
outputs = numpify_tensor_nested(outputs) | |||
outputs_origin = model.forward(*dummy_inputs) | |||
outputs_origin = torch_nested_numpify(outputs_origin) | |||
outputs_origin = numpify_tensor_nested(outputs_origin) | |||
tols = {} | |||
if rtol is not None: | |||
tols['rtol'] = rtol | |||
@@ -240,7 +270,6 @@ def replace_call(): | |||
    problems. Here we restore the call method to the default implementation of torch.nn.Module, and change it
    back after the tracing is done.
""" | |||
TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl | |||
yield | |||
TorchModel.__call__ = TorchModel.call_origin | |||
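# Illustrative sketch (not part of the diff): the swap-and-restore pattern used
# by replace_call above, written as a generic context manager. A try/finally is
# added here so the attribute is restored even if tracing raises.
from contextlib import contextmanager

@contextmanager
def swapped_attr(obj, name, new_value):
    old = getattr(obj, name)
    setattr(obj, name, new_value)
    try:
        yield
    finally:
        setattr(obj, name, old)  # restore the original attribute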
@@ -1,32 +1,47 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# yapf: disable | |||
import datetime | |||
import os | |||
import pickle | |||
import platform | |||
import shutil | |||
import tempfile | |||
import uuid | |||
from collections import defaultdict | |||
from http import HTTPStatus | |||
from http.cookiejar import CookieJar | |||
from os.path import expanduser | |||
from typing import List, Optional, Tuple, Union | |||
from typing import Dict, List, Optional, Tuple, Union | |||
import requests | |||
from modelscope import __version__ | |||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
API_RESPONSE_FIELD_EMAIL, | |||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, | |||
API_RESPONSE_FIELD_MESSAGE, | |||
API_RESPONSE_FIELD_USERNAME, | |||
DEFAULT_CREDENTIALS_PATH) | |||
DEFAULT_CREDENTIALS_PATH, | |||
MODELSCOPE_ENVIRONMENT, ONE_YEAR_SECONDS, | |||
Licenses, ModelVisibility) | |||
from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
NotLoginException, NoValidRevisionError, | |||
RequestError, datahub_raise_on_error, | |||
handle_http_post_error, | |||
handle_http_response, is_ok, | |||
raise_for_http_status, raise_on_error) | |||
from modelscope.hub.git import GitCommandWrapper | |||
from modelscope.hub.repository import Repository | |||
from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
DEFAULT_MODEL_REVISION, | |||
DatasetFormations, DatasetMetaFormats, | |||
DownloadMode) | |||
DEFAULT_REPOSITORY_REVISION, | |||
MASTER_MODEL_BRANCH, DatasetFormations, | |||
DatasetMetaFormats, DownloadMode, | |||
ModelFile) | |||
from modelscope.utils.logger import get_logger | |||
from .errors import (InvalidParameter, NotExistError, RequestError, | |||
datahub_raise_on_error, handle_http_response, is_ok, | |||
raise_on_error) | |||
from .utils.utils import (get_dataset_hub_endpoint, get_endpoint, | |||
from .utils.utils import (get_endpoint, get_release_datetime, | |||
model_id_to_group_owner_name) | |||
logger = get_logger() | |||
@@ -34,10 +49,9 @@ logger = get_logger() | |||
class HubApi: | |||
def __init__(self, endpoint=None, dataset_endpoint=None): | |||
def __init__(self, endpoint=None): | |||
self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else get_dataset_hub_endpoint( | |||
) | |||
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
def login( | |||
self, | |||
@@ -57,8 +71,9 @@ class HubApi: | |||
</Tip> | |||
""" | |||
path = f'{self.endpoint}/api/v1/login' | |||
r = requests.post(path, json={'AccessToken': access_token}) | |||
r.raise_for_status() | |||
r = requests.post( | |||
path, json={'AccessToken': access_token}, headers=self.headers) | |||
raise_for_http_status(r) | |||
d = r.json() | |||
raise_on_error(d) | |||
@@ -105,17 +120,16 @@ class HubApi: | |||
path = f'{self.endpoint}/api/v1/models' | |||
owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
body = { | |||
'Path': owner_or_group, | |||
'Name': name, | |||
'ChineseName': chinese_name, | |||
'Visibility': visibility, # server check | |||
'License': license | |||
} | |||
r = requests.post( | |||
path, | |||
json={ | |||
'Path': owner_or_group, | |||
'Name': name, | |||
'ChineseName': chinese_name, | |||
'Visibility': visibility, # server check | |||
'License': license | |||
}, | |||
cookies=cookies) | |||
r.raise_for_status() | |||
path, json=body, cookies=cookies, headers=self.headers) | |||
handle_http_post_error(r, path, body) | |||
raise_on_error(r.json()) | |||
model_repo_url = f'{get_endpoint()}/{model_id}' | |||
return model_repo_url | |||
@@ -134,8 +148,8 @@ class HubApi: | |||
raise ValueError('Token does not exist, please login first.') | |||
path = f'{self.endpoint}/api/v1/models/{model_id}' | |||
r = requests.delete(path, cookies=cookies) | |||
r.raise_for_status() | |||
r = requests.delete(path, cookies=cookies, headers=self.headers) | |||
raise_for_http_status(r) | |||
raise_on_error(r.json()) | |||
def get_model_url(self, model_id): | |||
@@ -164,7 +178,7 @@ class HubApi: | |||
owner_or_group, name = model_id_to_group_owner_name(model_id) | |||
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' | |||
r = requests.get(path, cookies=cookies) | |||
r = requests.get(path, cookies=cookies, headers=self.headers) | |||
handle_http_response(r, logger, cookies, model_id) | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
@@ -172,13 +186,108 @@ class HubApi: | |||
else: | |||
raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
r.raise_for_status() | |||
raise_for_http_status(r) | |||
def push_model(self, | |||
model_id: str, | |||
model_dir: str, | |||
visibility: int = ModelVisibility.PUBLIC, | |||
license: str = Licenses.APACHE_V2, | |||
chinese_name: Optional[str] = None, | |||
commit_message: Optional[str] = 'upload model', | |||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): | |||
""" | |||
        Upload a model from the given directory to the given repository. A valid model directory
        must contain a configuration.json file.
def list_model(self, | |||
owner_or_group: str, | |||
page_number=1, | |||
page_size=10) -> dict: | |||
"""List model in owner or group. | |||
        This function uploads the files in the given directory to the given repository. If the
        repository does not exist remotely, it will be created automatically with the given
        visibility, license and chinese_name parameters. If the revision does not exist in the
        remote repository, a new branch will be created for it.
        Before calling this function you must log in via HubApi's login with a valid token,
        which can be obtained from ModelScope's website.
        Args:
            model_id (`str`):
                The model id to be uploaded; the caller must have write permission for it.
            model_dir(`str`):
                The absolute path of the finetune result.
            visibility(`int`, defaults to `ModelVisibility.PUBLIC`):
                Visibility of the newly created model (1=private, 5=public). If the model does
                not exist on ModelScope, this function will create it with this visibility,
                so the parameter is required in that case. You can omit it if you are sure
                the model already exists.
            license(`str`, defaults to `Licenses.APACHE_V2`):
                License of the newly created model (see Licenses). If the model does not
                exist on ModelScope, this function will create it with this license, so the
                parameter is required in that case. You can omit it if you are sure the
                model already exists.
            chinese_name(`str`, *optional*, defaults to `None`):
                Chinese name of the newly created model.
            commit_message(`str`, *optional*, defaults to `'upload model'`):
                Commit message of the push request.
            revision (`str`, *optional*, default to DEFAULT_REPOSITORY_REVISION):
                Which branch to push to. If the branch does not exist, it will be created
                and pushed to.
""" | |||
if model_id is None: | |||
raise InvalidParameter('model_id cannot be empty!') | |||
if model_dir is None: | |||
raise InvalidParameter('model_dir cannot be empty!') | |||
if not os.path.exists(model_dir) or os.path.isfile(model_dir): | |||
raise InvalidParameter('model_dir must be a valid directory.') | |||
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
if not os.path.exists(cfg_file): | |||
raise ValueError(f'{model_dir} must contain a configuration.json.') | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies is None: | |||
raise NotLoginException('Must login before upload!') | |||
files_to_save = os.listdir(model_dir) | |||
try: | |||
self.get_model(model_id=model_id) | |||
except Exception: | |||
if visibility is None or license is None: | |||
raise InvalidParameter( | |||
'visibility and license cannot be empty if want to create new repo' | |||
) | |||
logger.info('Create new model %s' % model_id) | |||
self.create_model( | |||
model_id=model_id, | |||
visibility=visibility, | |||
license=license, | |||
chinese_name=chinese_name) | |||
tmp_dir = tempfile.mkdtemp() | |||
git_wrapper = GitCommandWrapper() | |||
try: | |||
repo = Repository(model_dir=tmp_dir, clone_from=model_id) | |||
branches = git_wrapper.get_remote_branches(tmp_dir) | |||
if revision not in branches: | |||
logger.info('Create new branch %s' % revision) | |||
git_wrapper.new_branch(tmp_dir, revision) | |||
git_wrapper.checkout(tmp_dir, revision) | |||
for f in files_to_save: | |||
if f[0] != '.': | |||
src = os.path.join(model_dir, f) | |||
if os.path.isdir(src): | |||
shutil.copytree(src, os.path.join(tmp_dir, f)) | |||
else: | |||
shutil.copy(src, tmp_dir) | |||
if not commit_message: | |||
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') | |||
commit_message = '[automsg] push model %s to hub at %s' % ( | |||
model_id, date) | |||
repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) | |||
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
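# Illustrative sketch (not part of the diff): pushing a finetuned model with
# push_model. The token, model id and local path are placeholders.
api = HubApi()
api.login('<your-access-token>')
api.push_model(
    model_id='my-group/my-model',
    model_dir='/path/to/finetune_output',  # must contain configuration.json
    visibility=ModelVisibility.PUBLIC,
    license=Licenses.APACHE_V2,
    commit_message='upload finetuned model')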
def list_models(self, | |||
owner_or_group: str, | |||
page_number=1, | |||
page_size=10) -> dict: | |||
"""List models in owner or group. | |||
Args: | |||
owner_or_group(`str`): owner or group. | |||
@@ -193,7 +302,8 @@ class HubApi: | |||
path, | |||
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % | |||
(owner_or_group, page_number, page_size), | |||
cookies=cookies) | |||
cookies=cookies, | |||
headers=self.headers) | |||
handle_http_response(r, logger, cookies, 'list_model') | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
@@ -202,7 +312,7 @@ class HubApi: | |||
else: | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
r.raise_for_status() | |||
raise_for_http_status(r) | |||
return None | |||
def _check_cookie(self, | |||
@@ -217,10 +327,70 @@ class HubApi: | |||
raise ValueError('Token does not exist, please login first.') | |||
return cookies | |||
def list_model_revisions( | |||
self, | |||
model_id: str, | |||
cutoff_timestamp: int = None, | |||
use_cookies: Union[bool, CookieJar] = False) -> List[str]: | |||
"""Get model branch and tags. | |||
Args: | |||
model_id (str): The model id | |||
cutoff_timestamp (int): Tags created before the cutoff will be included. | |||
The timestamp is represented by the seconds elasped from the epoch time. | |||
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will | |||
will load cookie from local. Defaults to False. | |||
Returns: | |||
Tuple[List[str], List[str]]: Return list of branch name and tags | |||
""" | |||
cookies = self._check_cookie(use_cookies) | |||
if cutoff_timestamp is None: | |||
cutoff_timestamp = get_release_datetime() | |||
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp | |||
r = requests.get(path, cookies=cookies, headers=self.headers) | |||
handle_http_response(r, logger, cookies, model_id) | |||
d = r.json() | |||
raise_on_error(d) | |||
info = d[API_RESPONSE_FIELD_DATA] | |||
# tags returned from backend are guaranteed to be ordered by create-time | |||
tags = [x['Revision'] for x in info['RevisionMap']['Tags'] | |||
] if info['RevisionMap']['Tags'] else [] | |||
return tags | |||
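# Illustrative sketch (not part of the diff): listing the revisions visible to
# this library version. The model id is a placeholder.
api = HubApi()
tags = api.list_model_revisions('damo/some-model')
# tags are ordered by create-time, so tags[0] is the newest eligible revision
latest = tags[0] if tags else None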
def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): | |||
release_timestamp = get_release_datetime() | |||
current_timestamp = int(round(datetime.datetime.now().timestamp())) | |||
        # For active development in library code (non-release branches), release_timestamp
        # is set to a time far in the future, to ensure that we get the master-HEAD
        # version from the model repo by default (when no revision is provided).
if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: | |||
branches, tags = self.get_model_branches_and_tags( | |||
model_id, use_cookies=False if cookies is None else cookies) | |||
if revision is None: | |||
revision = MASTER_MODEL_BRANCH | |||
logger.info('Model revision not specified, use default: %s in development mode' % revision) | |||
if revision not in branches and revision not in tags: | |||
                raise NotExistError('The model: %s has no branch or tag: %s !' % (model_id, revision))
else: | |||
revisions = self.list_model_revisions( | |||
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) | |||
if revision is None: | |||
if len(revisions) == 0: | |||
raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) | |||
# tags (revisions) returned from backend are guaranteed to be ordered by create-time | |||
# we shall obtain the latest revision created earlier than release version of this branch | |||
revision = revisions[0] | |||
logger.info('Model revision not specified, use the latest revision: %s' % revision) | |||
else: | |||
if revision not in revisions: | |||
raise NotExistError( | |||
'The model: %s has no revision: %s !' % (model_id, revision)) | |||
return revision | |||
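# Illustrative sketch (not part of the diff): resolving a usable revision. With
# no revision given, a release build picks the newest tag created before the
# library's release time, while a development build falls back to the master branch.
revision = HubApi().get_valid_revision('damo/some-model')  # placeholder model id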
def get_model_branches_and_tags( | |||
self, | |||
model_id: str, | |||
use_cookies: Union[bool, CookieJar] = False | |||
use_cookies: Union[bool, CookieJar] = False, | |||
) -> Tuple[List[str], List[str]]: | |||
"""Get model branch and tags. | |||
@@ -234,7 +404,7 @@ class HubApi: | |||
cookies = self._check_cookie(use_cookies) | |||
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' | |||
r = requests.get(path, cookies=cookies) | |||
r = requests.get(path, cookies=cookies, headers=self.headers) | |||
handle_http_response(r, logger, cookies, model_id) | |||
d = r.json() | |||
raise_on_error(d) | |||
@@ -275,7 +445,11 @@ class HubApi: | |||
if root is not None: | |||
path = path + f'&Root={root}' | |||
r = requests.get(path, cookies=cookies, headers=headers) | |||
r = requests.get( | |||
path, cookies=cookies, headers={ | |||
**headers, | |||
**self.headers | |||
}) | |||
handle_http_response(r, logger, cookies, model_id) | |||
d = r.json() | |||
@@ -290,11 +464,10 @@ class HubApi: | |||
return files | |||
def list_datasets(self): | |||
path = f'{self.dataset_endpoint}/api/v1/datasets' | |||
headers = None | |||
path = f'{self.endpoint}/api/v1/datasets' | |||
params = {} | |||
r = requests.get(path, params=params, headers=headers) | |||
r.raise_for_status() | |||
r = requests.get(path, params=params, headers=self.headers) | |||
raise_for_http_status(r) | |||
dataset_list = r.json()[API_RESPONSE_FIELD_DATA] | |||
return [x['Name'] for x in dataset_list] | |||
@@ -317,14 +490,14 @@ class HubApi: | |||
cache_dir): | |||
shutil.rmtree(cache_dir) | |||
os.makedirs(cache_dir, exist_ok=True) | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}' | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' | |||
r = requests.get(datahub_url) | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
dataset_id = resp['Data']['Id'] | |||
dataset_type = resp['Data']['Type'] | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
r = requests.get(datahub_url) | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' | |||
r = requests.get(datahub_url, headers=self.headers) | |||
resp = r.json() | |||
datahub_raise_on_error(datahub_url, resp) | |||
file_list = resp['Data'] | |||
@@ -341,10 +514,10 @@ class HubApi: | |||
file_path = file_info['Path'] | |||
extension = os.path.splitext(file_path)[-1] | |||
if extension in dataset_meta_format: | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
f'Revision={revision}&FilePath={file_path}' | |||
r = requests.get(datahub_url) | |||
r.raise_for_status() | |||
raise_for_http_status(r) | |||
local_path = os.path.join(cache_dir, file_path) | |||
if os.path.exists(local_path): | |||
logger.warning( | |||
@@ -365,7 +538,7 @@ class HubApi: | |||
namespace: str, | |||
revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
if file_name.endswith('.csv'): | |||
file_name = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ | |||
f'Revision={revision}&FilePath={file_name}' | |||
return file_name | |||
@@ -374,7 +547,7 @@ class HubApi: | |||
dataset_name: str, | |||
namespace: str, | |||
revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
f'ststoken?Revision={revision}' | |||
return self.datahub_remote_call(datahub_url) | |||
@@ -385,23 +558,39 @@ class HubApi: | |||
namespace: str, | |||
revision: Optional[str] = DEFAULT_DATASET_REVISION): | |||
datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ | |||
f'ststoken?Revision={revision}' | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
r = requests.get(url=datahub_url, cookies=cookies) | |||
r = requests.get( | |||
url=datahub_url, cookies=cookies, headers=self.headers) | |||
resp = r.json() | |||
raise_on_error(resp) | |||
return resp['Data'] | |||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, | |||
is_recursive, is_filter_dir, revision): | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ | |||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' | |||
cookies = ModelScopeConfig.get_cookies() | |||
if cookies: | |||
cookies = requests.utils.dict_from_cookiejar(cookies) | |||
resp = requests.get(url=url, cookies=cookies) | |||
resp = resp.json() | |||
raise_on_error(resp) | |||
resp = resp['Data'] | |||
return resp | |||
def on_dataset_download(self, dataset_name: str, namespace: str) -> None: | |||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' | |||
r = requests.post(url) | |||
r.raise_for_status() | |||
r = requests.post(url, headers=self.headers) | |||
raise_for_http_status(r) | |||
@staticmethod | |||
def datahub_remote_call(url): | |||
r = requests.get(url) | |||
r = requests.get(url, headers={'user-agent': ModelScopeConfig.get_user_agent()}) | |||
resp = r.json() | |||
datahub_raise_on_error(url, resp) | |||
return resp['Data'] | |||
@@ -415,6 +604,7 @@ class ModelScopeConfig: | |||
COOKIES_FILE_NAME = 'cookies' | |||
GIT_TOKEN_FILE_NAME = 'git_token' | |||
USER_INFO_FILE_NAME = 'user' | |||
USER_SESSION_ID_FILE_NAME = 'session' | |||
@staticmethod | |||
def make_sure_credential_path_exist(): | |||
@@ -443,6 +633,23 @@ class ModelScopeConfig: | |||
return cookies | |||
return None | |||
@staticmethod | |||
def get_user_session_id(): | |||
session_path = os.path.join(ModelScopeConfig.path_credential, | |||
ModelScopeConfig.USER_SESSION_ID_FILE_NAME) | |||
        session_id = ''
        if os.path.exists(session_path):
            with open(session_path, 'rb') as f:
                session_id = str(f.readline().strip(), encoding='utf-8')
        # regenerate the session id if it is missing or malformed
        if session_id == '' or len(session_id) != 32:
            session_id = str(uuid.uuid4().hex)
            ModelScopeConfig.make_sure_credential_path_exist()
            with open(session_path, 'w+') as wf:
                wf.write(session_id)
        return session_id
@staticmethod | |||
def save_token(token: str): | |||
ModelScopeConfig.make_sure_credential_path_exist() | |||
@@ -491,3 +698,32 @@ class ModelScopeConfig: | |||
except FileNotFoundError: | |||
pass | |||
return token | |||
@staticmethod | |||
    def get_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""Formats a user-agent string with basic info about a request. | |||
Args: | |||
user_agent (`str`, `dict`, *optional*): | |||
The user agent info in the form of a dictionary or a single string. | |||
Returns: | |||
The formatted user-agent string. | |||
""" | |||
env = 'custom' | |||
if MODELSCOPE_ENVIRONMENT in os.environ: | |||
env = os.environ[MODELSCOPE_ENVIRONMENT] | |||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s' % ( | |||
__version__, | |||
platform.python_version(), | |||
ModelScopeConfig.get_user_session_id(), | |||
platform.platform(), | |||
platform.processor(), | |||
env, | |||
) | |||
        if isinstance(user_agent, dict):
            ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
        elif isinstance(user_agent, str):
            ua += '; ' + user_agent
return ua |
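# Illustrative sketch (not part of the diff): the resulting header looks roughly
# like the following (all values are examples only):
#   modelscope/1.0.0; python/3.7.13; session_id/<32-hex-chars>;
#   platform/Linux-5.4.0; processor/x86_64; env/custom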
@@ -16,6 +16,9 @@ API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' | |||
API_RESPONSE_FIELD_USERNAME = 'Username' | |||
API_RESPONSE_FIELD_EMAIL = 'Email' | |||
API_RESPONSE_FIELD_MESSAGE = 'Message' | |||
MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' | |||
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' | |||
ONE_YEAR_SECONDS = 365 * 24 * 60 * 60
class Licenses(object): | |||
@@ -0,0 +1,339 @@ | |||
import urllib | |||
from abc import ABC | |||
from http import HTTPStatus | |||
from typing import Optional | |||
import json | |||
import requests | |||
from attrs import asdict, define, field, validators | |||
from modelscope.hub.api import ModelScopeConfig | |||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
API_RESPONSE_FIELD_MESSAGE) | |||
from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||
RequestError, handle_http_response, is_ok, | |||
raise_for_http_status) | |||
from modelscope.hub.utils.utils import get_endpoint | |||
from modelscope.utils.logger import get_logger | |||
# yapf: enable | |||
logger = get_logger() | |||
class Accelerator(object): | |||
CPU = 'cpu' | |||
GPU = 'gpu' | |||
class Vendor(object): | |||
EAS = 'eas' | |||
class EASRegion(object): | |||
beijing = 'cn-beijing' | |||
hangzhou = 'cn-hangzhou' | |||
class EASCpuInstanceType(object): | |||
"""EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
""" | |||
tiny = 'ecs.c6.2xlarge' | |||
small = 'ecs.c6.4xlarge' | |||
medium = 'ecs.c6.6xlarge' | |||
large = 'ecs.c6.8xlarge' | |||
class EASGpuInstanceType(object): | |||
"""EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
""" | |||
tiny = 'ecs.gn5-c28g1.7xlarge' | |||
small = 'ecs.gn5-c8g1.4xlarge' | |||
medium = 'ecs.gn6i-c24g1.12xlarge' | |||
large = 'ecs.gn6e-c12g1.3xlarge' | |||
def min_smaller_than_max(instance, attribute, value): | |||
if value > instance.max_replica: | |||
raise ValueError( | |||
"'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" | |||
% (value, instance.max_replica)) | |||
@define | |||
class ServiceScalingConfig(object): | |||
"""Resource scaling config | |||
Currently we ignore max_replica | |||
Args: | |||
max_replica: maximum replica | |||
min_replica: minimum replica | |||
""" | |||
max_replica: int = field(default=1, validator=validators.ge(1)) | |||
min_replica: int = field( | |||
default=1, validator=[validators.ge(1), min_smaller_than_max]) | |||
@define | |||
class ServiceResourceConfig(object): | |||
"""Eas Resource request. | |||
Args: | |||
accelerator: the accelerator(cpu|gpu) | |||
instance_type: the instance type. | |||
scaling: The instance scaling config. | |||
""" | |||
instance_type: str | |||
scaling: ServiceScalingConfig | |||
accelerator: str = field( | |||
default=Accelerator.CPU, | |||
validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) | |||
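# Illustrative sketch (not part of the diff): a single-replica GPU resource
# request built from the presets defined above.
resource = ServiceResourceConfig(
    instance_type=EASGpuInstanceType.small,
    scaling=ServiceScalingConfig(min_replica=1, max_replica=1),
    accelerator=Accelerator.GPU)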
@define | |||
class ServiceProviderParameters(ABC): | |||
pass | |||
@define | |||
class EASDeployParameters(ServiceProviderParameters): | |||
"""Parameters for EAS Deployment. | |||
Args: | |||
resource_group: the resource group to deploy, current default. | |||
region: The eas instance region(eg: cn-hangzhou). | |||
access_key_id: The eas account access key id. | |||
access_key_secret: The eas account access key secret. | |||
vendor: must be 'eas' | |||
""" | |||
region: str | |||
access_key_id: str | |||
access_key_secret: str | |||
resource_group: Optional[str] = None | |||
vendor: str = field( | |||
default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
@define | |||
class EASListParameters(ServiceProviderParameters): | |||
"""EAS instance list parameters. | |||
Args: | |||
resource_group: the resource group to deploy, current default. | |||
region: The eas instance region(eg: cn-hangzhou). | |||
access_key_id: The eas account access key id. | |||
access_key_secret: The eas account access key secret. | |||
vendor: must be 'eas' | |||
""" | |||
access_key_id: str | |||
access_key_secret: str | |||
region: str = None | |||
resource_group: str = None | |||
vendor: str = field( | |||
default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
@define | |||
class DeployServiceParameters(object): | |||
"""Deploy service parameters | |||
Args: | |||
instance_name: the name of the service. | |||
model_id: the modelscope model_id | |||
revision: the modelscope model revision | |||
resource: the resource requirement. | |||
provider: the cloud service provider. | |||
""" | |||
instance_name: str | |||
model_id: str | |||
revision: str | |||
resource: ServiceResourceConfig | |||
provider: ServiceProviderParameters | |||
class AttrsToQueryString(ABC): | |||
"""Convert the attrs class to json string. | |||
Args: | |||
""" | |||
def to_query_str(self): | |||
self_dict = asdict( | |||
self.provider, filter=lambda attr, value: value is not None) | |||
        json_str = json.dumps(self_dict)
        safe_str = urllib.parse.quote_plus(json_str)
        query_param = 'provider=%s' % safe_str
        return query_param
@define | |||
class ListServiceParameters(AttrsToQueryString): | |||
provider: ServiceProviderParameters | |||
skip: int = 0 | |||
limit: int = 100 | |||
@define | |||
class GetServiceParameters(AttrsToQueryString): | |||
provider: ServiceProviderParameters | |||
@define | |||
class DeleteServiceParameters(AttrsToQueryString): | |||
provider: ServiceProviderParameters | |||
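# Illustrative sketch (not part of the diff): the query string produced by
# to_query_str for an EAS provider. The keys are placeholders; the output is
# URL-encoded JSON.
params = GetServiceParameters(
    provider=EASListParameters(
        access_key_id='<ak-id>',
        access_key_secret='<ak-secret>',
        region=EASRegion.hangzhou))
query = params.to_query_str()  # e.g. provider=%7B%22access_key_id%22...%7D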
class ServiceDeployer(object): | |||
def __init__(self, endpoint=None): | |||
self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} | |||
self.cookies = ModelScopeConfig.get_cookies() | |||
if self.cookies is None: | |||
raise NotLoginException( | |||
'Token does not exist, please login with HubApi first.') | |||
# deploy_model | |||
def create(self, model_id: str, revision: str, instance_name: str, | |||
resource: ServiceResourceConfig, | |||
provider: ServiceProviderParameters): | |||
"""Deploy model to cloud, current we only support PAI EAS, this is an async API , | |||
and the deployment could take a while to finish remotely. Please check deploy instance | |||
status separately via checking the status. | |||
Args: | |||
model_id (str): The deployed model id | |||
revision (str): The model revision | |||
instance_name (str): The deployed model instance name. | |||
resource (ServiceResourceConfig): The service resource information. | |||
provider (ServiceProviderParameters): The service provider parameter | |||
Raises: | |||
NotLoginException: To use this api, you need login first. | |||
NotSupportError: Not supported platform. | |||
RequestError: The server return error. | |||
Returns: | |||
ServiceInstanceInfo: The information of the deployed service instance. | |||
""" | |||
if provider.vendor != Vendor.EAS: | |||
            raise NotSupportError(
                'Unsupported vendor: %s, currently only EAS is supported.' %
                provider.vendor)
create_params = DeployServiceParameters( | |||
instance_name=instance_name, | |||
model_id=model_id, | |||
revision=revision, | |||
resource=resource, | |||
provider=provider) | |||
path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
body = asdict(create_params) | |||
r = requests.post( | |||
path, json=body, cookies=self.cookies, headers=self.headers) | |||
handle_http_response(r, logger, self.cookies, 'create_service') | |||
if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
if is_ok(r.json()): | |||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||
return data | |||
else: | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
raise_for_http_status(r) | |||
return None | |||
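# Illustrative sketch (not part of the diff): deploying a model to EAS with the
# create() API above. Keys, region, model id and names are placeholders.
deployer = ServiceDeployer()
service = deployer.create(
    model_id='my-group/my-model',
    revision='v1.0.0',
    instance_name='my-service',
    resource=ServiceResourceConfig(
        instance_type=EASCpuInstanceType.tiny,
        scaling=ServiceScalingConfig(min_replica=1, max_replica=1),
        accelerator=Accelerator.CPU),
    provider=EASDeployParameters(
        region=EASRegion.hangzhou,
        access_key_id='<ak-id>',
        access_key_secret='<ak-secret>'))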
def get(self, instance_name: str, provider: ServiceProviderParameters): | |||
"""Query the specified instance information. | |||
Args: | |||
instance_name (str): The deployed instance name. | |||
            provider (ServiceProviderParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret are required.
Raises: | |||
NotLoginException: To use this api, you need login first. | |||
RequestError: The request is failed from server. | |||
Returns: | |||
Dict: The information of the requested service instance. | |||
""" | |||
params = GetServiceParameters(provider=provider) | |||
path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
self.endpoint, instance_name, params.to_query_str()) | |||
r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
handle_http_response(r, logger, self.cookies, 'get_service') | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||
return data | |||
else: | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
raise_for_http_status(r) | |||
return None | |||
def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||
"""Delete deployed model, this api send delete command and return, it will take | |||
some to delete, please check through the cloud console. | |||
Args: | |||
instance_name (str): The instance name you want to delete. | |||
            provider (ServiceProviderParameters): The cloud provider information; for EAS,
                region (e.g. cn-hangzhou), access_key_id and access_key_secret are required.
Raises: | |||
NotLoginException: To call this api, you need login first. | |||
RequestError: The request is failed. | |||
Returns: | |||
Dict: The deleted instance information. | |||
""" | |||
params = DeleteServiceParameters(provider=provider) | |||
path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
self.endpoint, instance_name, params.to_query_str()) | |||
r = requests.delete(path, cookies=self.cookies, headers=self.headers) | |||
handle_http_response(r, logger, self.cookies, 'delete_service') | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||
return data | |||
else: | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
raise_for_http_status(r) | |||
return None | |||
def list(self, | |||
provider: ServiceProviderParameters, | |||
skip: int = 0, | |||
limit: int = 100): | |||
"""List deployed model instances. | |||
Args: | |||
            provider (ServiceProviderParameters): The cloud service provider parameters;
                for EAS, access_key_id and access_key_secret are required.
            skip (int): Start offset of the list; currently not supported.
            limit (int): Maximum number of instances to return; currently not supported.
        Raises:
            NotLoginException: To use this API, you must log in first.
            RequestError: The request failed on the server side.
        Returns:
            List: A list of instance information.
""" | |||
params = ListServiceParameters( | |||
provider=provider, skip=skip, limit=limit) | |||
path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
params.to_query_str()) | |||
r = requests.get(path, cookies=self.cookies, headers=self.headers) | |||
handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||
if r.status_code == HTTPStatus.OK: | |||
if is_ok(r.json()): | |||
data = r.json()[API_RESPONSE_FIELD_DATA] | |||
return data | |||
else: | |||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
else: | |||
raise_for_http_status(r) | |||
return None |
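
# A follow-up sketch for the query/list/delete methods above, reusing the
# hypothetical `deployer` and `provider` objects from the deployment sketch.
all_instances = deployer.list(provider=provider)   # list deployed instances
details = deployer.get('my-instance', provider)    # query one instance
deployer.delete('my-instance', provider)           # async; poll the cloud console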
@@ -4,6 +4,18 @@ from http import HTTPStatus | |||
from requests.exceptions import HTTPError | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
class NotSupportError(Exception): | |||
pass | |||
class NoValidRevisionError(Exception): | |||
pass | |||
class NotExistError(Exception): | |||
pass | |||
@@ -45,15 +57,25 @@ def is_ok(rsp): | |||
return rsp['Code'] == HTTPStatus.OK and rsp['Success'] | |||
def handle_http_post_error(response, url, request_body): | |||
try: | |||
response.raise_for_status() | |||
except HTTPError as error: | |||
        logger.error('Request to %s with body: %s failed' %
                     (url, request_body))
raise error | |||
def handle_http_response(response, logger, cookies, model_id): | |||
try: | |||
response.raise_for_status() | |||
except HTTPError: | |||
except HTTPError as error: | |||
        if cookies is None:
            logger.error(
                f'Authentication token does not exist, failed to access model {model_id} '
                'which may not exist or may be private. Please log in first.')
raise | |||
logger.error('Response details: %s' % response.content) | |||
raise error | |||
def raise_on_error(rsp): | |||
@@ -81,3 +103,33 @@ def datahub_raise_on_error(url, rsp): | |||
raise RequestError( | |||
f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" | |||
) | |||
def raise_for_http_status(rsp): | |||
""" | |||
Attempt to decode utf-8 first since some servers | |||
localize reason strings, for invalid utf-8, fall back | |||
to decoding with iso-8859-1. | |||
""" | |||
http_error_msg = '' | |||
if isinstance(rsp.reason, bytes): | |||
try: | |||
reason = rsp.reason.decode('utf-8') | |||
except UnicodeDecodeError: | |||
reason = rsp.reason.decode('iso-8859-1') | |||
else: | |||
reason = rsp.reason | |||
if 400 <= rsp.status_code < 500: | |||
http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, | |||
reason, rsp.url) | |||
elif 500 <= rsp.status_code < 600: | |||
http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, | |||
reason, rsp.url) | |||
if http_error_msg: | |||
req = rsp.request | |||
if req.method == 'POST': | |||
http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) | |||
raise HTTPError(http_error_msg, response=rsp) |
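
# A small sketch of the error text produced above, using a stand-in object with
# the same attributes the function reads (reason/status_code/url/request).
from types import SimpleNamespace

fake_rsp = SimpleNamespace(
    reason=b'Not Found',  # bytes, to exercise the utf-8 decode path
    status_code=404,
    url='https://example.com/api/v1/deployer/endpoint/foo',
    request=SimpleNamespace(method='GET', body=None))
try:
    raise_for_http_status(fake_rsp)
except HTTPError as e:
    print(e)  # 404 Client Error: Not Found for url: https://example.com/...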
@@ -2,29 +2,25 @@ | |||
import copy | |||
import os | |||
import sys | |||
import tempfile | |||
from functools import partial | |||
from http.cookiejar import CookieJar | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from uuid import uuid4 | |||
import requests | |||
from filelock import FileLock | |||
from tqdm import tqdm | |||
from modelscope import __version__ | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import FILE_HASH | |||
from .errors import FileDownloadError, NotExistError | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
get_endpoint, model_id_to_group_owner_name) | |||
SESSION_ID = uuid4().hex | |||
logger = get_logger() | |||
@@ -35,6 +31,7 @@ def model_file_download( | |||
cache_dir: Optional[str] = None, | |||
user_agent: Union[Dict, str, None] = None, | |||
local_files_only: Optional[bool] = False, | |||
cookies: Optional[CookieJar] = None, | |||
) -> Optional[str]: # pragma: no cover | |||
""" | |||
Download from a given URL and cache it if it's not already present in the | |||
@@ -105,54 +102,47 @@ def model_file_download( | |||
" online, set 'local_files_only' to False.") | |||
_api = HubApi() | |||
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
cookies = ModelScopeConfig.get_cookies() | |||
branches, tags = _api.get_model_branches_and_tags( | |||
model_id, use_cookies=False if cookies is None else cookies) | |||
headers = { | |||
'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
} | |||
if cookies is None: | |||
cookies = ModelScopeConfig.get_cookies() | |||
revision = _api.get_valid_revision( | |||
model_id, revision=revision, cookies=cookies) | |||
file_to_download_info = None | |||
is_commit_id = False | |||
if revision in branches or revision in tags: # The revision is version or tag, | |||
        # we need to confirm the version is up to date
        # we need to get the file list to check if the latest version is cached, if so return, otherwise download
model_files = _api.get_model_files( | |||
model_id=model_id, | |||
revision=revision, | |||
recursive=True, | |||
use_cookies=False if cookies is None else cookies) | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
continue | |||
if model_file['Path'] == file_path: | |||
if cache.exists(model_file): | |||
return cache.get_file_by_info(model_file) | |||
else: | |||
file_to_download_info = model_file | |||
break | |||
if file_to_download_info is None: | |||
            raise NotExistError('The file path: %s does not exist in: %s' %
                                (file_path, model_id))
else: # the revision is commit id. | |||
cached_file_path = cache.get_file_by_path_and_commit_id( | |||
file_path, revision) | |||
if cached_file_path is not None: | |||
file_name = os.path.basename(cached_file_path) | |||
logger.info( | |||
f'File {file_name} already in cache, skip downloading!') | |||
return cached_file_path # the file is in cache. | |||
is_commit_id = True | |||
# we need to confirm the version is up-to-date | |||
# we need to get the file list to check if the latest version is cached, if so return, otherwise download | |||
model_files = _api.get_model_files( | |||
model_id=model_id, | |||
revision=revision, | |||
recursive=True, | |||
use_cookies=False if cookies is None else cookies) | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
continue | |||
if model_file['Path'] == file_path: | |||
if cache.exists(model_file): | |||
logger.info( | |||
f'File {model_file["Name"]} already in cache, skip downloading!' | |||
) | |||
return cache.get_file_by_info(model_file) | |||
else: | |||
file_to_download_info = model_file | |||
break | |||
if file_to_download_info is None: | |||
        raise NotExistError('The file path: %s does not exist in: %s' %
                            (file_path, model_id))
# we need to download again | |||
url_to_download = get_file_download_url(model_id, file_path, revision) | |||
file_to_download_info = { | |||
'Path': | |||
file_path, | |||
'Revision': | |||
revision if is_commit_id else file_to_download_info['Revision'], | |||
FILE_HASH: | |||
None if (is_commit_id or FILE_HASH not in file_to_download_info) else | |||
file_to_download_info[FILE_HASH] | |||
'Path': file_path, | |||
'Revision': file_to_download_info['Revision'], | |||
FILE_HASH: file_to_download_info[FILE_HASH] | |||
} | |||
temp_file_name = next(tempfile._get_candidate_names()) | |||
@@ -171,25 +161,6 @@ def model_file_download( | |||
os.path.join(temporary_cache_dir, temp_file_name)) | |||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: | |||
"""Formats a user-agent string with basic info about a request. | |||
Args: | |||
user_agent (`str`, `dict`, *optional*): | |||
The user agent info in the form of a dictionary or a single string. | |||
Returns: | |||
The formatted user-agent string. | |||
""" | |||
ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}' | |||
if isinstance(user_agent, dict): | |||
ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) | |||
elif isinstance(user_agent, str): | |||
ua = user_agent | |||
return ua | |||
def get_file_download_url(model_id: str, file_path: str, revision: str): | |||
""" | |||
Format file download url according to `model_id`, `revision` and `file_path`. | |||
@@ -3,10 +3,9 @@ | |||
import os | |||
import subprocess | |||
from typing import List | |||
from xmlrpc.client import Boolean | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from ..utils.constant import MASTER_MODEL_BRANCH | |||
from .errors import GitError | |||
logger = get_logger() | |||
@@ -131,6 +130,7 @@ class GitCommandWrapper(metaclass=Singleton): | |||
return response | |||
def add_user_info(self, repo_base_dir, repo_name): | |||
from modelscope.hub.api import ModelScopeConfig | |||
user_name, user_email = ModelScopeConfig.get_user_info() | |||
if user_name and user_email: | |||
# config user.name and user.email if exist | |||
@@ -138,8 +138,8 @@ class GitCommandWrapper(metaclass=Singleton): | |||
repo_base_dir, repo_name, user_name) | |||
response = self._run_git_command(*config_user_name_args.split(' ')) | |||
logger.debug(response.stdout.decode('utf8')) | |||
config_user_email_args = '-C %s/%s config user.name %s' % ( | |||
repo_base_dir, repo_name, user_name) | |||
config_user_email_args = '-C %s/%s config user.email %s' % ( | |||
repo_base_dir, repo_name, user_email) | |||
response = self._run_git_command( | |||
*config_user_email_args.split(' ')) | |||
logger.debug(response.stdout.decode('utf8')) | |||
@@ -177,6 +177,18 @@ class GitCommandWrapper(metaclass=Singleton): | |||
cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] | |||
return self._run_git_command(*cmds) | |||
def get_remote_branches(self, repo_dir: str): | |||
cmds = ['-C', '%s' % repo_dir, 'branch', '-r'] | |||
rsp = self._run_git_command(*cmds) | |||
info = [ | |||
line.strip() | |||
for line in rsp.stdout.decode('utf8').strip().split(os.linesep) | |||
] | |||
if len(info) == 1: | |||
return ['/'.join(info[0].split('/')[1:])] | |||
else: | |||
return ['/'.join(line.split('/')[1:]) for line in info[1:]] | |||
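
    # A quick sketch of the parsing above: `git branch -r` prints one remote
    # branch per line, e.g. "origin/master", with an extra
    # "origin/HEAD -> origin/master" pointer line first when several branches
    # exist; stripping the leading remote name yields the branch names.
    #
    #   lines = ['origin/HEAD -> origin/master', 'origin/master', 'origin/dev']
    #   ['/'.join(line.split('/')[1:]) for line in lines[1:]]
    #   -> ['master', 'dev']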
def pull(self, repo_dir: str): | |||
cmds = ['-C', repo_dir, 'pull'] | |||
return self._run_git_command(*cmds) | |||
@@ -216,3 +228,22 @@ class GitCommandWrapper(metaclass=Singleton): | |||
files.append(line.split(' ')[-1]) | |||
return files | |||
def tag(self, | |||
repo_dir: str, | |||
tag_name: str, | |||
message: str, | |||
ref: str = MASTER_MODEL_BRANCH): | |||
cmd_args = [ | |||
'-C', repo_dir, 'tag', tag_name, '-m', | |||
'"%s"' % message, ref | |||
] | |||
rsp = self._run_git_command(*cmd_args) | |||
logger.debug(rsp.stdout.decode('utf8')) | |||
return rsp | |||
def push_tag(self, repo_dir: str, tag_name): | |||
cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] | |||
rsp = self._run_git_command(*cmd_args) | |||
logger.debug(rsp.stdout.decode('utf8')) | |||
return rsp |
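
# A minimal sketch of tagging with the wrapper above; the repository path and
# tag values are placeholders, and the default (no-argument) constructor is an
# assumption.
git = GitCommandWrapper()
git.tag('/path/to/local_repo', tag_name='v1.0.0', message='release v1.0.0')
git.push_tag('/path/to/local_repo', tag_name='v1.0.0')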
@@ -5,9 +5,9 @@ from typing import Optional | |||
from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException | |||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, | |||
DEFAULT_MODEL_REVISION) | |||
DEFAULT_REPOSITORY_REVISION, | |||
MASTER_MODEL_BRANCH) | |||
from modelscope.utils.logger import get_logger | |||
from .api import ModelScopeConfig | |||
from .git import GitCommandWrapper | |||
from .utils.utils import get_endpoint | |||
@@ -21,7 +21,7 @@ class Repository: | |||
def __init__(self, | |||
model_dir: str, | |||
clone_from: str, | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
auth_token: Optional[str] = None, | |||
git_path: Optional[str] = None): | |||
""" | |||
@@ -47,6 +47,7 @@ class Repository: | |||
err_msg = 'a non-default value of revision cannot be empty.' | |||
raise InvalidParameter(err_msg) | |||
from modelscope.hub.api import ModelScopeConfig | |||
if auth_token: | |||
self.auth_token = auth_token | |||
else: | |||
@@ -89,7 +90,8 @@ class Repository: | |||
def push(self, | |||
commit_message: str, | |||
branch: Optional[str] = DEFAULT_MODEL_REVISION, | |||
local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, | |||
force: bool = False): | |||
"""Push local files to remote, this method will do. | |||
git pull | |||
@@ -116,14 +118,48 @@ class Repository: | |||
url = self.git_wrapper.get_repo_remote_url(self.model_dir) | |||
self.git_wrapper.pull(self.model_dir) | |||
self.git_wrapper.add(self.model_dir, all_files=True) | |||
self.git_wrapper.commit(self.model_dir, commit_message) | |||
self.git_wrapper.push( | |||
repo_dir=self.model_dir, | |||
token=self.auth_token, | |||
url=url, | |||
local_branch=branch, | |||
remote_branch=branch) | |||
local_branch=local_branch, | |||
remote_branch=remote_branch) | |||
def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): | |||
"""Create a new tag. | |||
Args: | |||
tag_name (str): The name of the tag | |||
message (str): The tag message. | |||
ref (str): The tag reference, can be commit id or branch. | |||
""" | |||
if tag_name is None or tag_name == '': | |||
msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' | |||
raise InvalidParameter(msg) | |||
        if message is None or message == '':
            msg = 'We use annotated tags, therefore message cannot be None or empty.'
            raise InvalidParameter(msg)
self.git_wrapper.tag( | |||
repo_dir=self.model_dir, | |||
tag_name=tag_name, | |||
message=message, | |||
ref=ref) | |||
def tag_and_push(self, | |||
tag_name: str, | |||
message: str, | |||
ref: str = MASTER_MODEL_BRANCH): | |||
"""Create tag and push to remote | |||
Args: | |||
tag_name (str): The name of the tag | |||
message (str): The tag message. | |||
ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. | |||
""" | |||
self.tag(tag_name, message, ref) | |||
self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) | |||
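
# A minimal end-to-end sketch for the Repository class above; the model id,
# local directory, and token are placeholders.
repo = Repository(
    model_dir='/tmp/my_model',
    clone_from='owner_or_group/my-model',  # hypothetical model id
    auth_token='<SDK_TOKEN>')
repo.push('add new weights')
repo.tag_and_push('v1.0.0', 'release v1.0.0')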
class DatasetRepository: | |||
@@ -166,7 +202,7 @@ class DatasetRepository: | |||
err_msg = 'a non-default value of revision cannot be empty.' | |||
raise InvalidParameter(err_msg) | |||
self.revision = revision | |||
from modelscope.hub.api import ModelScopeConfig | |||
if auth_token: | |||
self.auth_token = auth_token | |||
else: | |||
@@ -2,16 +2,15 @@ | |||
import os | |||
import tempfile | |||
from http.cookiejar import CookieJar | |||
from pathlib import Path | |||
from typing import Dict, Optional, Union | |||
from modelscope.hub.api import HubApi, ModelScopeConfig | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.logger import get_logger | |||
from .api import HubApi, ModelScopeConfig | |||
from .constants import FILE_HASH | |||
from .errors import NotExistError | |||
from .file_download import (get_file_download_url, http_get_file, | |||
http_user_agent) | |||
from .file_download import get_file_download_url, http_get_file | |||
from .utils.caching import ModelFileSystemCache | |||
from .utils.utils import (file_integrity_validation, get_cache_dir, | |||
model_id_to_group_owner_name) | |||
@@ -23,7 +22,8 @@ def snapshot_download(model_id: str, | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
cache_dir: Union[str, Path, None] = None, | |||
user_agent: Optional[Union[Dict, str]] = None, | |||
local_files_only: Optional[bool] = False) -> str: | |||
local_files_only: Optional[bool] = False, | |||
cookies: Optional[CookieJar] = None) -> str: | |||
"""Download all files of a repo. | |||
Downloads a whole snapshot of a repo's files at the specified revision. This | |||
is useful when you want all files from a repo, because you don't know which | |||
@@ -81,15 +81,15 @@ def snapshot_download(model_id: str, | |||
        )  # we cannot confirm the cached file is for snapshot 'revision'
else: | |||
# make headers | |||
headers = {'user-agent': http_user_agent(user_agent=user_agent, )} | |||
headers = { | |||
'user-agent': | |||
ModelScopeConfig.get_user_agent(user_agent=user_agent, ) | |||
} | |||
_api = HubApi() | |||
cookies = ModelScopeConfig.get_cookies() | |||
# get file list from model repo | |||
branches, tags = _api.get_model_branches_and_tags( | |||
model_id, use_cookies=False if cookies is None else cookies) | |||
if revision not in branches and revision not in tags: | |||
raise NotExistError('The specified branch or tag : %s not exist!' | |||
% revision) | |||
if cookies is None: | |||
cookies = ModelScopeConfig.get_cookies() | |||
revision = _api.get_valid_revision( | |||
model_id, revision=revision, cookies=cookies) | |||
snapshot_header = headers if 'CI_TEST' in os.environ else { | |||
**headers, | |||
@@ -110,7 +110,7 @@ def snapshot_download(model_id: str, | |||
for model_file in model_files: | |||
if model_file['Type'] == 'tree': | |||
continue | |||
# check model_file is exist in cache, if exist, skip download, otherwise download | |||
        # check whether model_file exists in the cache; if so, skip download, otherwise download
if cache.exists(model_file): | |||
file_name = os.path.basename(model_file['Name']) | |||
logger.info( | |||
@@ -2,12 +2,12 @@ | |||
import hashlib | |||
import os | |||
from datetime import datetime | |||
from typing import Optional | |||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | |||
DEFAULT_MODELSCOPE_DOMAIN, | |||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, | |||
DEFAULT_MODELSCOPE_GROUP, | |||
MODEL_ID_SEPARATOR, | |||
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, | |||
MODELSCOPE_URL_SCHEME) | |||
from modelscope.hub.errors import FileIntegrityError | |||
from modelscope.utils.file_utils import get_default_cache_dir | |||
@@ -38,17 +38,24 @@ def get_cache_dir(model_id: Optional[str] = None): | |||
base_path, model_id + '/') | |||
def get_release_datetime(): | |||
if MODELSCOPE_SDK_DEBUG in os.environ: | |||
rt = int(round(datetime.now().timestamp())) | |||
else: | |||
from modelscope import version | |||
rt = int( | |||
round( | |||
datetime.strptime(version.__release_datetime__, | |||
'%Y-%m-%d %H:%M:%S').timestamp())) | |||
return rt | |||
def get_endpoint(): | |||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', | |||
DEFAULT_MODELSCOPE_DOMAIN) | |||
return MODELSCOPE_URL_SCHEME + modelscope_domain | |||
def get_dataset_hub_endpoint(): | |||
return os.environ.get('HUB_DATASET_ENDPOINT', | |||
DEFAULT_MODELSCOPE_DATA_ENDPOINT) | |||
def compute_hash(file_path): | |||
BUFFER_SIZE = 1024 * 64 # 64k buffer size | |||
sha256_hash = hashlib.sha256() | |||
@@ -9,11 +9,14 @@ class Models(object): | |||
Model name should only contain model info but not task info. | |||
""" | |||
# tinynas models | |||
tinynas_detection = 'tinynas-detection' | |||
tinynas_damoyolo = 'tinynas-damoyolo' | |||
# vision models | |||
detection = 'detection' | |||
realtime_object_detection = 'realtime-object-detection' | |||
realtime_video_object_detection = 'realtime-video-object-detection' | |||
scrfd = 'scrfd' | |||
classification_model = 'ClassificationModel' | |||
nafnet = 'nafnet' | |||
@@ -27,11 +30,13 @@ class Models(object): | |||
face_2d_keypoints = 'face-2d-keypoints' | |||
panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
image_reid_person = 'passvitb' | |||
image_inpainting = 'FFTInpainting' | |||
video_summarization = 'pgl-video-summarization' | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
resnet50_bert = 'resnet50-bert' | |||
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||
fer = 'fer' | |||
retinaface = 'retinaface' | |||
shop_segmentation = 'shop-segmentation' | |||
@@ -39,14 +44,18 @@ class Models(object): | |||
mtcnn = 'mtcnn' | |||
ulfd = 'ulfd' | |||
video_inpainting = 'video-inpainting' | |||
human_wholebody_keypoint = 'human-wholebody-keypoint' | |||
hand_static = 'hand-static' | |||
face_human_hand_detection = 'face-human-hand-detection' | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
image_body_reshaping = 'image-body-reshaping' | |||
# EasyCV models | |||
yolox = 'YOLOX' | |||
segformer = 'Segformer' | |||
hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' | |||
image_object_detection_auto = 'image-object-detection-auto' | |||
# nlp models | |||
bert = 'bert' | |||
@@ -58,18 +67,22 @@ class Models(object): | |||
space_dst = 'space-dst' | |||
space_intent = 'space-intent' | |||
space_modeling = 'space-modeling' | |||
star = 'star' | |||
star3 = 'star3' | |||
space_T_en = 'space-T-en' | |||
space_T_cn = 'space-T-cn' | |||
tcrf = 'transformer-crf' | |||
tcrf_wseg = 'transformer-crf-for-word-segmentation' | |||
transformer_softmax = 'transformer-softmax' | |||
lcrf = 'lstm-crf' | |||
lcrf_wseg = 'lstm-crf-for-word-segmentation' | |||
gcnncrf = 'gcnn-crf' | |||
bart = 'bart' | |||
gpt3 = 'gpt3' | |||
gpt_neo = 'gpt-neo' | |||
plug = 'plug' | |||
bert_for_ds = 'bert-for-document-segmentation' | |||
ponet = 'ponet' | |||
T5 = 'T5' | |||
bloom = 'bloom' | |||
# audio models | |||
sambert_hifigan = 'sambert-hifigan' | |||
@@ -88,6 +101,10 @@ class Models(object): | |||
team = 'team-multi-modal-similarity' | |||
video_clip = 'video-clip-multi-modal-embedding' | |||
# science models | |||
unifold = 'unifold' | |||
unifold_symmetry = 'unifold-symmetry' | |||
class TaskModels(object): | |||
# nlp task | |||
@@ -96,6 +113,7 @@ class TaskModels(object): | |||
information_extraction = 'information-extraction' | |||
fill_mask = 'fill-mask' | |||
feature_extraction = 'feature-extraction' | |||
text_generation = 'text-generation' | |||
class Heads(object): | |||
@@ -111,6 +129,8 @@ class Heads(object): | |||
token_classification = 'token-classification' | |||
# extraction | |||
information_extraction = 'information-extraction' | |||
# text gen | |||
text_generation = 'text-generation' | |||
class Pipelines(object): | |||
@@ -144,6 +164,7 @@ class Pipelines(object): | |||
salient_detection = 'u2net-salient-detection' | |||
image_classification = 'image-classification' | |||
face_detection = 'resnet-face-detection-scrfd10gkps' | |||
card_detection = 'resnet-card-detection-scrfd34gkps' | |||
ulfd_face_detection = 'manual-face-detection-ulfd' | |||
facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' | |||
retina_face_detection = 'resnet50-face-detection-retinaface' | |||
@@ -160,6 +181,7 @@ class Pipelines(object): | |||
face_image_generation = 'gan-face-image-generation' | |||
product_retrieval_embedding = 'resnet50-product-retrieval-embedding' | |||
realtime_object_detection = 'cspnet_realtime-object-detection_yolox' | |||
realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo' | |||
face_recognition = 'ir101-face-recognition-cfglint' | |||
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' | |||
image2image_translation = 'image-to-image-translation' | |||
@@ -168,6 +190,7 @@ class Pipelines(object): | |||
ocr_recognition = 'convnextTiny-ocr-recognition' | |||
image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
image_to_image_generation = 'image-to-image-generation' | |||
image_object_detection_auto = 'yolox_image-object-detection-auto' | |||
skin_retouching = 'unet-skin-retouching' | |||
tinynas_classification = 'tinynas-classification' | |||
tinynas_detection = 'tinynas-detection' | |||
@@ -178,21 +201,32 @@ class Pipelines(object): | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
image_semantic_segmentation = 'image-semantic-segmentation' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
image_inpainting = 'fft-inpainting' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
shop_segmentation = 'shop-segmentation' | |||
video_inpainting = 'video-inpainting' | |||
human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' | |||
pst_action_recognition = 'patchshift-action-recognition' | |||
hand_static = 'hand-static' | |||
face_human_hand_detection = 'face-human-hand-detection' | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
image_body_reshaping = 'flow-based-body-reshaping' | |||
referring_video_object_segmentation = 'referring-video-object-segmentation' | |||
# nlp tasks | |||
automatic_post_editing = 'automatic-post-editing' | |||
translation_quality_estimation = 'translation-quality-estimation' | |||
domain_classification = 'domain-classification' | |||
sentence_similarity = 'sentence-similarity' | |||
word_segmentation = 'word-segmentation' | |||
multilingual_word_segmentation = 'multilingual-word-segmentation' | |||
word_segmentation_thai = 'word-segmentation-thai' | |||
part_of_speech = 'part-of-speech' | |||
named_entity_recognition = 'named-entity-recognition' | |||
named_entity_recognition_thai = 'named-entity-recognition-thai' | |||
named_entity_recognition_viet = 'named-entity-recognition-viet' | |||
text_generation = 'text-generation' | |||
text2text_generation = 'text2text-generation' | |||
sentiment_analysis = 'sentiment-analysis' | |||
@@ -208,14 +242,18 @@ class Pipelines(object): | |||
zero_shot_classification = 'zero-shot-classification' | |||
text_error_correction = 'text-error-correction' | |||
plug_generation = 'plug-generation' | |||
gpt3_generation = 'gpt3-generation' | |||
faq_question_answering = 'faq-question-answering' | |||
conversational_text_to_sql = 'conversational-text-to-sql' | |||
table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
sentence_embedding = 'sentence-embedding' | |||
passage_ranking = 'passage-ranking' | |||
text_ranking = 'text-ranking' | |||
relation_extraction = 'relation-extraction' | |||
document_segmentation = 'document-segmentation' | |||
feature_extraction = 'feature-extraction' | |||
translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
# audio tasks | |||
sambert_hifigan_tts = 'sambert-hifigan-tts' | |||
@@ -236,6 +274,10 @@ class Pipelines(object): | |||
text_to_image_synthesis = 'text-to-image-synthesis' | |||
video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
image_text_retrieval = 'image-text-retrieval' | |||
ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
# science tasks | |||
protein_structure = 'unifold-protein-structure' | |||
class Trainers(object): | |||
@@ -253,12 +295,16 @@ class Trainers(object): | |||
# multi-modal trainers | |||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | |||
ofa = 'ofa' | |||
# cv trainers | |||
image_instance_segmentation = 'image-instance-segmentation' | |||
image_portrait_enhancement = 'image-portrait-enhancement' | |||
video_summarization = 'video-summarization' | |||
movie_scene_segmentation = 'movie-scene-segmentation' | |||
face_detection_scrfd = 'face-detection-scrfd' | |||
card_detection_scrfd = 'card-detection-scrfd' | |||
image_inpainting = 'image-inpainting' | |||
# nlp trainers | |||
bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
@@ -266,10 +312,11 @@ class Trainers(object): | |||
dialog_intent_trainer = 'dialog-intent-trainer' | |||
nlp_base_trainer = 'nlp-base-trainer' | |||
nlp_veco_trainer = 'nlp-veco-trainer' | |||
nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||
# audio trainers | |||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
class Preprocessors(object): | |||
@@ -298,8 +345,12 @@ class Preprocessors(object): | |||
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
text_gen_tokenizer = 'text-gen-tokenizer' | |||
text2text_gen_preprocessor = 'text2text-gen-preprocessor' | |||
text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer' | |||
text2text_translate_preprocessor = 'text2text-translate-preprocessor' | |||
token_cls_tokenizer = 'token-cls-tokenizer' | |||
ner_tokenizer = 'ner-tokenizer' | |||
thai_ner_tokenizer = 'thai-ner-tokenizer' | |||
viet_ner_tokenizer = 'viet-ner-tokenizer' | |||
nli_tokenizer = 'nli-tokenizer' | |||
sen_cls_tokenizer = 'sen-cls-tokenizer' | |||
dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||
@@ -309,9 +360,10 @@ class Preprocessors(object): | |||
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
text_error_correction = 'text-error-correction' | |||
sentence_embedding = 'sentence-embedding' | |||
passage_ranking = 'passage-ranking' | |||
text_ranking = 'text-ranking' | |||
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
thai_wseg_tokenizer = 'thai-wseg-tokenizer' | |||
fill_mask = 'fill-mask' | |||
fill_mask_ponet = 'fill-mask-ponet' | |||
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | |||
@@ -320,6 +372,7 @@ class Preprocessors(object): | |||
re_tokenizer = 're-tokenizer' | |||
document_segmentation = 'document-segmentation' | |||
feature_extraction = 'feature-extraction' | |||
sentence_piece = 'sentence-piece' | |||
# audio preprocessor | |||
linear_aec_fbank = 'linear-aec-fbank' | |||
@@ -331,6 +384,9 @@ class Preprocessors(object): | |||
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
# science preprocessor | |||
unifold_preprocessor = 'unifold-preprocessor' | |||
class Metrics(object): | |||
""" Names for different metrics. | |||
@@ -340,6 +396,9 @@ class Metrics(object): | |||
accuracy = 'accuracy' | |||
audio_noise_metric = 'audio-noise-metric' | |||
# text gen | |||
BLEU = 'bleu' | |||
# metrics for image denoise task | |||
image_denoise_metric = 'image-denoise-metric' | |||
@@ -358,6 +417,10 @@ class Metrics(object): | |||
video_summarization_metric = 'video-summarization-metric' | |||
# metric for movie-scene-segmentation task | |||
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' | |||
# metric for inpainting task | |||
image_inpainting_metric = 'image-inpainting-metric' | |||
# metric for ocr | |||
NED = 'ned' | |||
class Optimizers(object): | |||
@@ -399,6 +462,9 @@ class Hooks(object): | |||
IterTimerHook = 'IterTimerHook' | |||
EvaluationHook = 'EvaluationHook' | |||
# Compression | |||
SparsityHook = 'SparsityHook' | |||
class LR_Schedulers(object): | |||
"""learning rate scheduler is defined here | |||
@@ -413,7 +479,10 @@ class Datasets(object): | |||
""" Names for different datasets. | |||
""" | |||
ClsDataset = 'ClsDataset' | |||
Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||
Face2dKeypointsDataset = 'FaceKeypointDataset' | |||
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | |||
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||
SegDataset = 'SegDataset' | |||
DetDataset = 'DetDataset' | |||
DetImagesMixDataset = 'DetImagesMixDataset' | |||
PairedDataset = 'PairedDataset' |
@@ -17,6 +17,9 @@ if TYPE_CHECKING: | |||
from .token_classification_metric import TokenClassificationMetric | |||
from .video_summarization_metric import VideoSummarizationMetric | |||
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | |||
from .accuracy_metric import AccuracyMetric | |||
from .bleu_metric import BleuMetric | |||
from .image_inpainting_metric import ImageInpaintingMetric | |||
else: | |||
_import_structure = { | |||
@@ -34,6 +37,9 @@ else: | |||
'token_classification_metric': ['TokenClassificationMetric'], | |||
'video_summarization_metric': ['VideoSummarizationMetric'], | |||
'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | |||
'image_inpainting_metric': ['ImageInpaintingMetric'], | |||
'accuracy_metric': ['AccuracyMetric'], | |||
'bleu_metric': ['BleuMetric'], | |||
} | |||
import sys | |||
@@ -0,0 +1,46 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
import numpy as np | |||
from modelscope.metainfo import Metrics | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | |||
class AccuracyMetric(Metric): | |||
"""The metric computation class for classification classes. | |||
This metric class calculates accuracy for the whole input batches. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.preds = [] | |||
self.labels = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS | |||
ground_truths = inputs[label_name] | |||
eval_results = outputs[label_name] | |||
assert type(ground_truths) == type(eval_results) | |||
if isinstance(ground_truths, list): | |||
self.preds.extend(eval_results) | |||
self.labels.extend(ground_truths) | |||
elif isinstance(ground_truths, np.ndarray): | |||
self.preds.extend(eval_results.tolist()) | |||
self.labels.extend(ground_truths.tolist()) | |||
else: | |||
            raise TypeError('only list or np.ndarray is supported')
def evaluate(self): | |||
assert len(self.preds) == len(self.labels) | |||
return { | |||
MetricKeys.ACCURACY: (np.asarray([ | |||
pred == ref for pred, ref in zip(self.preds, self.labels) | |||
])).mean().item() | |||
} |
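
# A minimal sketch of the metric above: three predictions against three labels,
# two of which match.
m = AccuracyMetric()
m.add(outputs={OutputKeys.LABELS: [1, 0, 1]},
      inputs={OutputKeys.LABELS: [1, 1, 1]})
print(m.evaluate())  # {'accuracy': 0.666...}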
@@ -35,6 +35,8 @@ class AudioNoiseMetric(Metric): | |||
total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr | |||
return { | |||
'total_loss': total_loss.item(), | |||
'avg_sisnr': avg_sisnr.item(), | |||
        # the model uses the negative of SiSNR as a calculation shortcut;
        # revert it in the evaluation result.
'avg_sisnr': -avg_sisnr.item(), | |||
MetricKeys.AVERAGE_LOSS: avg_loss.item() | |||
} |
@@ -10,8 +10,8 @@ class Metric(ABC): | |||
complex metrics for a specific task with or without other Metric subclasses. | |||
""" | |||
def __init__(self, trainer=None, *args, **kwargs): | |||
self.trainer = trainer | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
@abstractmethod | |||
def add(self, outputs: Dict, inputs: Dict): | |||
@@ -0,0 +1,42 @@ | |||
from itertools import zip_longest | |||
from typing import Dict | |||
import sacrebleu | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
EVAL_BLEU_ORDER = 4 | |||
@METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) | |||
class BleuMetric(Metric): | |||
"""The metric computation bleu for text generation classes. | |||
This metric class calculates accuracy for the whole input batches. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) | |||
self.hyp_name = kwargs.get('hyp_name', 'hyp') | |||
self.ref_name = kwargs.get('ref_name', 'ref') | |||
self.refs = list() | |||
self.hyps = list() | |||
def add(self, outputs: Dict, inputs: Dict): | |||
self.refs.extend(inputs[self.ref_name]) | |||
self.hyps.extend(outputs[self.hyp_name]) | |||
def evaluate(self): | |||
if self.eval_tokenized_bleu: | |||
bleu = sacrebleu.corpus_bleu( | |||
self.hyps, list(zip_longest(*self.refs)), tokenize='none') | |||
else: | |||
bleu = sacrebleu.corpus_bleu(self.hyps, | |||
list(zip_longest(*self.refs))) | |||
return { | |||
MetricKeys.BLEU_4: bleu.score, | |||
} |
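
# A minimal sketch of the BLEU metric above. Each entry in inputs['ref'] is a
# tuple of references for one hypothesis; zip_longest(*refs) transposes them
# into the per-position reference streams sacrebleu expects.
m = BleuMetric()
m.add(outputs={'hyp': ['the cat sat on the mat']},
      inputs={'ref': [('the cat sat on the mat',)]})
print(m.evaluate())  # {'bleu-4': 100.0}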
@@ -18,10 +18,12 @@ class MetricKeys(object): | |||
SSIM = 'ssim' | |||
AVERAGE_LOSS = 'avg_loss' | |||
FScore = 'fscore' | |||
FID = 'fid' | |||
BLEU_1 = 'bleu-1' | |||
BLEU_4 = 'bleu-4' | |||
ROUGE_1 = 'rouge-1' | |||
ROUGE_L = 'rouge-l' | |||
NED = 'ned' # ocr metric | |||
task_default_metrics = { | |||
@@ -31,6 +33,7 @@ task_default_metrics = { | |||
Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
Tasks.token_classification: [Metrics.token_cls_metric], | |||
Tasks.text_generation: [Metrics.text_gen_metric], | |||
Tasks.text_classification: [Metrics.seq_cls_metric], | |||
Tasks.image_denoising: [Metrics.image_denoise_metric], | |||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||
Tasks.image_portrait_enhancement: | |||
@@ -39,6 +42,7 @@ task_default_metrics = { | |||
Tasks.image_captioning: [Metrics.text_gen_metric], | |||
Tasks.visual_question_answering: [Metrics.text_gen_metric], | |||
Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], | |||
Tasks.image_inpainting: [Metrics.image_inpainting_metric], | |||
} | |||
@@ -0,0 +1 @@ | |||
__author__ = 'tylin' |
@@ -0,0 +1,57 @@ | |||
# Filename: ciderD.py | |||
# | |||
# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric | |||
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) | |||
# | |||
# Creation Date: Sun Feb 8 14:16:54 2015 | |||
# | |||
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu> | |||
from __future__ import absolute_import, division, print_function | |||
from .ciderD_scorer import CiderScorer | |||
class CiderD: | |||
""" | |||
Main Class to compute the CIDEr metric | |||
""" | |||
def __init__(self, n=4, sigma=6.0, df='corpus'): | |||
# set cider to sum over 1 to 4-grams | |||
self._n = n | |||
# set the standard deviation parameter for gaussian penalty | |||
self._sigma = sigma | |||
        # set where to compute document frequencies from
self._df = df | |||
self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) | |||
def compute_score(self, gts, res): | |||
""" | |||
Main function to compute CIDEr score | |||
:param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence> | |||
ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence> | |||
:return: cider (float) : computed CIDEr score for the corpus | |||
""" # noqa | |||
# clear all the previous hypos and refs | |||
tmp_cider_scorer = self.cider_scorer.copy_empty() | |||
tmp_cider_scorer.clear() | |||
for res_id in res: | |||
hypo = res_id['caption'] | |||
ref = gts[res_id['image_id']] | |||
# Sanity check. | |||
assert (type(hypo) is list) | |||
assert (len(hypo) == 1) | |||
assert (type(ref) is list) | |||
assert (len(ref) > 0) | |||
tmp_cider_scorer += (hypo[0], ref) | |||
(score, scores) = tmp_cider_scorer.compute_score() | |||
return score, scores | |||
def method(self): | |||
return 'CIDEr-D' |
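
# A minimal sketch of the scorer above in 'corpus' df mode; image ids and
# captions are toy values shaped like the expected inputs.
gts = {0: ['a dog runs in the park'], 1: ['a cat sleeps on the sofa']}
res = [{'image_id': 0, 'caption': ['a dog runs in the park']},
       {'image_id': 1, 'caption': ['a cat sleeps on a sofa']}]
scorer = CiderD(df='corpus')
corpus_score, per_image_scores = scorer.compute_score(gts, res)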
@@ -0,0 +1,233 @@ | |||
#!/usr/bin/env python | |||
# Tsung-Yi Lin <tl483@cornell.edu> | |||
# Ramakrishna Vedantam <vrama91@vt.edu> | |||
from __future__ import absolute_import, division, print_function | |||
import copy | |||
import math | |||
import os | |||
import pdb | |||
from collections import defaultdict | |||
import numpy as np | |||
import six | |||
from six.moves import cPickle | |||
def precook(s, n=4, out=False): | |||
""" | |||
Takes a string as input and returns an object that can be given to | |||
either cook_refs or cook_test. This is optional: cook_refs and cook_test | |||
can take string arguments as well. | |||
:param s: string : sentence to be converted into ngrams | |||
:param n: int : number of ngrams for which representation is calculated | |||
    :return: term frequency vector for occurring ngrams
""" | |||
words = s.split() | |||
counts = defaultdict(int) | |||
for k in range(1, n + 1): | |||
for i in range(len(words) - k + 1): | |||
ngram = tuple(words[i:i + k]) | |||
counts[ngram] += 1 | |||
return counts | |||
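
# A quick check of the n-gram counting above: unigram and bigram counts for a
# three-word sentence.
print(dict(precook('the cat sat', n=2)))
# -> {('the',): 1, ('cat',): 1, ('sat',): 1, ('the', 'cat'): 1, ('cat', 'sat'): 1}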
def cook_refs(refs, n=4): # lhuang: oracle will call with "average" | |||
'''Takes a list of reference sentences for a single segment | |||
and returns an object that encapsulates everything that BLEU | |||
needs to know about them. | |||
:param refs: list of string : reference sentences for some image | |||
:param n: int : number of ngrams for which (ngram) representation is calculated | |||
:return: result (list of dict) | |||
''' | |||
return [precook(ref, n) for ref in refs] | |||
def cook_test(test, n=4): | |||
'''Takes a test sentence and returns an object that | |||
encapsulates everything that BLEU needs to know about it. | |||
:param test: list of string : hypothesis sentence for some image | |||
:param n: int : number of ngrams for which (ngram) representation is calculated | |||
:return: result (dict) | |||
''' | |||
return precook(test, n, True) | |||
class CiderScorer(object): | |||
"""CIDEr scorer. | |||
""" | |||
def copy(self): | |||
''' copy the refs.''' | |||
new = CiderScorer(n=self.n) | |||
new.ctest = copy.copy(self.ctest) | |||
new.crefs = copy.copy(self.crefs) | |||
return new | |||
def copy_empty(self): | |||
new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) | |||
new.df_mode = self.df_mode | |||
new.ref_len = self.ref_len | |||
new.document_frequency = self.document_frequency | |||
return new | |||
def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): | |||
''' singular instance ''' | |||
self.n = n | |||
self.sigma = sigma | |||
self.crefs = [] | |||
self.ctest = [] | |||
self.df_mode = df_mode | |||
self.ref_len = None | |||
if self.df_mode != 'corpus': | |||
pkl_file = cPickle.load( | |||
open(df_mode, 'rb'), | |||
**(dict(encoding='latin1') if six.PY3 else {})) | |||
self.ref_len = np.log(float(pkl_file['ref_len'])) | |||
self.document_frequency = pkl_file['document_frequency'] | |||
else: | |||
self.document_frequency = None | |||
self.cook_append(test, refs) | |||
def clear(self): | |||
self.crefs = [] | |||
self.ctest = [] | |||
def cook_append(self, test, refs): | |||
'''called by constructor and __iadd__ to avoid creating new instances.''' | |||
if refs is not None: | |||
self.crefs.append(cook_refs(refs)) | |||
if test is not None: | |||
self.ctest.append(cook_test(test)) # N.B.: -1 | |||
else: | |||
self.ctest.append( | |||
None) # lens of crefs and ctest have to match | |||
def size(self): | |||
assert len(self.crefs) == len( | |||
self.ctest), 'refs/test mismatch! %d<>%d' % (len( | |||
self.crefs), len(self.ctest)) | |||
return len(self.crefs) | |||
def __iadd__(self, other): | |||
'''add an instance (e.g., from another sentence).''' | |||
if type(other) is tuple: | |||
# avoid creating new CiderScorer instances | |||
self.cook_append(other[0], other[1]) | |||
else: | |||
self.ctest.extend(other.ctest) | |||
self.crefs.extend(other.crefs) | |||
return self | |||
def compute_doc_freq(self): | |||
""" | |||
Compute term frequency for reference data. | |||
        This will be used to compute idf (inverse document frequency) later.
The term frequency is stored in the object | |||
:return: None | |||
""" | |||
for refs in self.crefs: | |||
# refs, k ref captions of one image | |||
for ngram in set([ | |||
ngram for ref in refs for (ngram, count) in ref.items() | |||
]): # noqa | |||
self.document_frequency[ngram] += 1 | |||
def compute_cider(self): | |||
def counts2vec(cnts): | |||
""" | |||
Function maps counts of ngram to vector of tfidf weights. | |||
            The function returns vec, an array of dictionaries that stores the mapping of n-grams to tf-idf weights.
The n-th entry of array denotes length of n-grams. | |||
:param cnts: | |||
:return: vec (array of dict), norm (array of float), length (int) | |||
""" | |||
vec = [defaultdict(float) for _ in range(self.n)] | |||
length = 0 | |||
norm = [0.0 for _ in range(self.n)] | |||
for (ngram, term_freq) in cnts.items(): | |||
# give word count 1 if it doesn't appear in reference corpus | |||
df = np.log(max(1.0, self.document_frequency[ngram])) | |||
# ngram index | |||
n = len(ngram) - 1 | |||
# tf (term_freq) * idf (precomputed idf) for n-grams | |||
vec[n][ngram] = float(term_freq) * (self.ref_len - df) | |||
# compute norm for the vector. the norm will be used for computing similarity | |||
norm[n] += pow(vec[n][ngram], 2) | |||
if n == 1: | |||
length += term_freq | |||
norm = [np.sqrt(n) for n in norm] | |||
return vec, norm, length | |||
def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): | |||
''' | |||
Compute the cosine similarity of two vectors. | |||
:param vec_hyp: array of dictionary for vector corresponding to hypothesis | |||
:param vec_ref: array of dictionary for vector corresponding to reference | |||
:param norm_hyp: array of float for vector corresponding to hypothesis | |||
:param norm_ref: array of float for vector corresponding to reference | |||
:param length_hyp: int containing length of hypothesis | |||
:param length_ref: int containing length of reference | |||
:return: array of score for each n-grams cosine similarity | |||
''' | |||
delta = float(length_hyp - length_ref) | |||
            # measure cosine similarity
val = np.array([0.0 for _ in range(self.n)]) | |||
for n in range(self.n): | |||
# ngram | |||
for (ngram, count) in vec_hyp[n].items(): | |||
# vrama91 : added clipping | |||
val[n] += min(vec_hyp[n][ngram], | |||
vec_ref[n][ngram]) * vec_ref[n][ngram] | |||
if (norm_hyp[n] != 0) and (norm_ref[n] != 0): | |||
val[n] /= (norm_hyp[n] * norm_ref[n]) | |||
assert (not math.isnan(val[n])) | |||
# vrama91: added a length based gaussian penalty | |||
val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) | |||
return val | |||
# compute log reference length | |||
if self.df_mode == 'corpus': | |||
self.ref_len = np.log(float(len(self.crefs))) | |||
# elif self.df_mode == "coco-val-df": | |||
# if coco option selected, use length of coco-val set | |||
# self.ref_len = np.log(float(40504)) | |||
scores = [] | |||
for test, refs in zip(self.ctest, self.crefs): | |||
# compute vector for test captions | |||
vec, norm, length = counts2vec(test) | |||
# compute vector for ref captions | |||
score = np.array([0.0 for _ in range(self.n)]) | |||
for ref in refs: | |||
vec_ref, norm_ref, length_ref = counts2vec(ref) | |||
score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) | |||
# change by vrama91 - mean of ngram scores, instead of sum | |||
score_avg = np.mean(score) | |||
# divide by number of references | |||
score_avg /= len(refs) | |||
# multiply score by 10 | |||
score_avg *= 10.0 | |||
# append score of an image to the score list | |||
scores.append(score_avg) | |||
return scores | |||
def compute_score(self, option=None, verbose=0): | |||
# compute idf | |||
if self.df_mode == 'corpus': | |||
self.document_frequency = defaultdict(float) | |||
self.compute_doc_freq() | |||
# assert to check document frequency | |||
assert (len(self.ctest) >= max(self.document_frequency.values())) | |||
# import json for now and write the corresponding files | |||
# compute cider score | |||
score = self.compute_cider() | |||
# debug | |||
# print score | |||
return np.mean(np.array(score)), np.array(score) |
@@ -1,12 +1,16 @@ | |||
# ------------------------------------------------------------------------ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
# ------------------------------------------------------------------------ | |||
# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py | |||
# ------------------------------------------------------------------------ | |||
from typing import Dict | |||
import cv2 | |||
import numpy as np | |||
from skimage.metrics import peak_signal_noise_ratio, structural_similarity | |||
import torch | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.registry import default_group | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
@@ -20,26 +24,249 @@ class ImageDenoiseMetric(Metric): | |||
label_name = 'target' | |||
def __init__(self): | |||
super(ImageDenoiseMetric, self).__init__() | |||
self.preds = [] | |||
self.labels = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
ground_truths = outputs[ImageDenoiseMetric.label_name] | |||
eval_results = outputs[ImageDenoiseMetric.pred_name] | |||
self.preds.append( | |||
torch_nested_numpify(torch_nested_detach(eval_results))) | |||
self.labels.append( | |||
torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
self.preds.append(eval_results) | |||
self.labels.append(ground_truths) | |||
def evaluate(self): | |||
psnr_list, ssim_list = [], [] | |||
for (pred, label) in zip(self.preds, self.labels): | |||
psnr_list.append( | |||
peak_signal_noise_ratio(label[0], pred[0], data_range=255)) | |||
ssim_list.append( | |||
structural_similarity( | |||
label[0], pred[0], multichannel=True, data_range=255)) | |||
psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) | |||
ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) | |||
return { | |||
MetricKeys.PSNR: np.mean(psnr_list), | |||
MetricKeys.SSIM: np.mean(ssim_list) | |||
} | |||
def reorder_image(img, input_order='HWC'): | |||
"""Reorder images to 'HWC' order. | |||
If the input_order is (h, w), return (h, w, 1); | |||
If the input_order is (c, h, w), return (h, w, c); | |||
If the input_order is (h, w, c), return as it is. | |||
Args: | |||
img (ndarray): Input image. | |||
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order has no
            effect. Default: 'HWC'.
Returns: | |||
ndarray: reordered image. | |||
""" | |||
if input_order not in ['HWC', 'CHW']: | |||
raise ValueError( | |||
f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" | |||
) | |||
if len(img.shape) == 2: | |||
img = img[..., None] | |||
if input_order == 'CHW': | |||
img = img.transpose(1, 2, 0) | |||
return img | |||
def calculate_psnr(img1, img2, crop_border, input_order='HWC'): | |||
"""Calculate PSNR (Peak Signal-to-Noise Ratio). | |||
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | |||
Args: | |||
img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||
img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. | |||
crop_border (int): Cropped pixels in each edge of an image. These | |||
pixels are not involved in the PSNR calculation. | |||
input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
Default: 'HWC'. | |||
Returns: | |||
float: psnr result. | |||
""" | |||
assert img1.shape == img2.shape, ( | |||
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']: | |||
raise ValueError( | |||
f'Wrong input_order {input_order}. Supported input_orders are ' | |||
'"HWC" and "CHW"') | |||
if type(img1) == torch.Tensor: | |||
if len(img1.shape) == 4: | |||
img1 = img1.squeeze(0) | |||
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||
if type(img2) == torch.Tensor: | |||
if len(img2.shape) == 4: | |||
img2 = img2.squeeze(0) | |||
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||
img1 = reorder_image(img1, input_order=input_order) | |||
img2 = reorder_image(img2, input_order=input_order) | |||
img1 = img1.astype(np.float64) | |||
img2 = img2.astype(np.float64) | |||
if crop_border != 0: | |||
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
def _psnr(img1, img2): | |||
mse = np.mean((img1 - img2)**2) | |||
if mse == 0: | |||
return float('inf') | |||
max_value = 1. if img1.max() <= 1 else 255. | |||
return 20. * np.log10(max_value / np.sqrt(mse)) | |||
return _psnr(img1, img2) | |||
def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): | |||
"""Calculate SSIM (structural similarity). | |||
Ref: | |||
Image quality assessment: From error visibility to structural similarity | |||
The results are the same as that of the official released MATLAB code in | |||
https://ece.uwaterloo.ca/~z70wang/research/ssim/. | |||
For three-channel images, SSIM is calculated for each channel and then | |||
averaged. | |||
Args: | |||
img1 (ndarray): Images with range [0, 255]. | |||
img2 (ndarray): Images with range [0, 255]. | |||
crop_border (int): Cropped pixels in each edge of an image. These | |||
pixels are not involved in the SSIM calculation. | |||
input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
Default: 'HWC'. | |||
        ssim3d (bool): Whether to compute SSIM with a 3-D Gaussian window
            across channels. Default: True.
Returns: | |||
float: ssim result. | |||
""" | |||
assert img1.shape == img2.shape, ( | |||
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']: | |||
raise ValueError( | |||
f'Wrong input_order {input_order}. Supported input_orders are ' | |||
'"HWC" and "CHW"') | |||
if type(img1) == torch.Tensor: | |||
if len(img1.shape) == 4: | |||
img1 = img1.squeeze(0) | |||
img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) | |||
if type(img2) == torch.Tensor: | |||
if len(img2.shape) == 4: | |||
img2 = img2.squeeze(0) | |||
img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) | |||
img1 = reorder_image(img1, input_order=input_order) | |||
img2 = reorder_image(img2, input_order=input_order) | |||
img1 = img1.astype(np.float64) | |||
img2 = img2.astype(np.float64) | |||
if crop_border != 0: | |||
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
def _cal_ssim(img1, img2): | |||
ssims = [] | |||
max_value = 1 if img1.max() <= 1 else 255 | |||
with torch.no_grad(): | |||
final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( | |||
img1, img2, max_value) | |||
ssims.append(final_ssim) | |||
return np.array(ssims).mean() | |||
return _cal_ssim(img1, img2) | |||
def _ssim(img, img2, max_value): | |||
"""Calculate SSIM (structural similarity) for one channel images. | |||
It is called by func:`calculate_ssim`. | |||
Args: | |||
img (ndarray): Images with range [0, 255] with order 'HWC'. | |||
img2 (ndarray): Images with range [0, 255] with order 'HWC'. | |||
Returns: | |||
float: SSIM result. | |||
""" | |||
c1 = (0.01 * max_value)**2 | |||
c2 = (0.03 * max_value)**2 | |||
img = img.astype(np.float64) | |||
img2 = img2.astype(np.float64) | |||
kernel = cv2.getGaussianKernel(11, 1.5) | |||
window = np.outer(kernel, kernel.transpose()) | |||
mu1 = cv2.filter2D(img, -1, window)[5:-5, | |||
5:-5] # valid mode for window size 11 | |||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] | |||
mu1_sq = mu1**2 | |||
mu2_sq = mu2**2 | |||
mu1_mu2 = mu1 * mu2 | |||
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq | |||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq | |||
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 | |||
tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) | |||
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) | |||
ssim_map = tmp1 / tmp2 | |||
return ssim_map.mean() | |||
def _3d_gaussian_calculator(img, conv3d): | |||
out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) | |||
return out | |||
def _generate_3d_gaussian_kernel(): | |||
kernel = cv2.getGaussianKernel(11, 1.5) | |||
window = np.outer(kernel, kernel.transpose()) | |||
kernel_3 = cv2.getGaussianKernel(11, 1.5) | |||
kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) | |||
conv3d = torch.nn.Conv3d( | |||
1, | |||
1, (11, 11, 11), | |||
stride=1, | |||
padding=(5, 5, 5), | |||
bias=False, | |||
padding_mode='replicate') | |||
conv3d.weight.requires_grad = False | |||
conv3d.weight[0, 0, :, :, :] = kernel | |||
return conv3d | |||
def _ssim_3d(img1, img2, max_value): | |||
    """Calculate 3D SSIM (structural similarity) for one channel images.
    It is called by :func:`calculate_ssim`.
    Args:
        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
        max_value (float): Dynamic range of the images (1 or 255).
    Returns:
        float: SSIM result.
    """
    assert len(img1.shape) == 3 and len(img2.shape) == 3
C1 = (0.01 * max_value)**2 | |||
C2 = (0.03 * max_value)**2 | |||
img1 = img1.astype(np.float64) | |||
img2 = img2.astype(np.float64) | |||
kernel = _generate_3d_gaussian_kernel().cuda() | |||
img1 = torch.tensor(img1).float().cuda() | |||
img2 = torch.tensor(img2).float().cuda() | |||
mu1 = _3d_gaussian_calculator(img1, kernel) | |||
mu2 = _3d_gaussian_calculator(img2, kernel) | |||
mu1_sq = mu1**2 | |||
mu2_sq = mu2**2 | |||
mu1_mu2 = mu1 * mu2 | |||
sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq | |||
sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq | |||
sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 | |||
tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) | |||
tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) | |||
ssim_map = tmp1 / tmp2 | |||
return float(ssim_map.mean()) |
@@ -0,0 +1,210 @@ | |||
""" | |||
Part of the implementation is borrowed and modified from LaMa, publicly available at | |||
https://github.com/saic-mdal/lama | |||
""" | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from scipy import linalg | |||
from modelscope.metainfo import Metrics | |||
from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 | |||
from modelscope.utils.registry import default_group | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
from .base import Metric | |||
from .builder import METRICS, MetricKeys | |||
def fid_calculate_activation_statistics(act): | |||
mu = np.mean(act, axis=0) | |||
sigma = np.cov(act, rowvar=False) | |||
return mu, sigma | |||
def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |||
mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |||
mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |||
diff = mu1 - mu2 | |||
# Product might be almost singular | |||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||
if not np.isfinite(covmean).all(): | |||
offset = np.eye(sigma1.shape[0]) * eps | |||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||
# Numerical error might give slight imaginary component | |||
if np.iscomplexobj(covmean): | |||
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |||
m = np.max(np.abs(covmean.imag)) | |||
raise ValueError('Imaginary component {}'.format(m)) | |||
covmean = covmean.real | |||
tr_covmean = np.trace(covmean) | |||
return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) | |||
- 2 * tr_covmean) | |||
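# Editor's sketch: the Frechet distance between a set of activations and
# itself is ~0; shifting every mean by 1 over d dims adds d to the distance.
def _fid_smoke_test(d=8):
    acts = np.random.default_rng(0).normal(size=(256, d))
    assert calculate_frechet_distance(acts, acts) < 1e-6
    shifted = calculate_frechet_distance(acts, acts + 1.0)
    assert abs(shifted - d) < 1e-6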
class FIDScore(torch.nn.Module): | |||
def __init__(self, dims=2048, eps=1e-6): | |||
super().__init__() | |||
if getattr(FIDScore, '_MODEL', None) is None: | |||
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |||
FIDScore._MODEL = InceptionV3([block_idx]).eval() | |||
self.model = FIDScore._MODEL | |||
self.eps = eps | |||
self.reset() | |||
def forward(self, pred_batch, target_batch, mask=None): | |||
activations_pred = self._get_activations(pred_batch) | |||
activations_target = self._get_activations(target_batch) | |||
self.activations_pred.append(activations_pred.detach().cpu()) | |||
self.activations_target.append(activations_target.detach().cpu()) | |||
def get_value(self): | |||
activations_pred, activations_target = (self.activations_pred, | |||
self.activations_target) | |||
activations_pred = torch.cat(activations_pred).cpu().numpy() | |||
activations_target = torch.cat(activations_target).cpu().numpy() | |||
total_distance = calculate_frechet_distance( | |||
activations_pred, activations_target, eps=self.eps) | |||
self.reset() | |||
return total_distance | |||
def reset(self): | |||
self.activations_pred = [] | |||
self.activations_target = [] | |||
def _get_activations(self, batch): | |||
activations = self.model(batch)[0] | |||
if activations.shape[2] != 1 or activations.shape[3] != 1: | |||
assert False, \ | |||
'We should not have got here, because Inception always scales inputs to 299x299' | |||
activations = activations.squeeze(-1).squeeze(-1) | |||
return activations | |||
class SSIM(torch.nn.Module): | |||
"""SSIM. Modified from: | |||
https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py | |||
""" | |||
def __init__(self, window_size=11, size_average=True): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.size_average = size_average | |||
self.channel = 1 | |||
self.register_buffer('window', | |||
self._create_window(window_size, self.channel)) | |||
def forward(self, img1, img2): | |||
assert len(img1.shape) == 4 | |||
channel = img1.size()[1] | |||
if channel == self.channel and self.window.data.type( | |||
) == img1.data.type(): | |||
window = self.window | |||
else: | |||
window = self._create_window(self.window_size, channel) | |||
window = window.type_as(img1) | |||
self.window = window | |||
self.channel = channel | |||
return self._ssim(img1, img2, window, self.window_size, channel, | |||
self.size_average) | |||
def _gaussian(self, window_size, sigma): | |||
gauss = torch.Tensor([ | |||
np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) | |||
for x in range(window_size) | |||
]) | |||
return gauss / gauss.sum() | |||
def _create_window(self, window_size, channel): | |||
_1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) | |||
_2D_window = _1D_window.mm( | |||
_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |||
return _2D_window.expand(channel, 1, window_size, | |||
window_size).contiguous() | |||
def _ssim(self, | |||
img1, | |||
img2, | |||
window, | |||
window_size, | |||
channel, | |||
size_average=True): | |||
mu1 = F.conv2d( | |||
img1, window, padding=(window_size // 2), groups=channel) | |||
mu2 = F.conv2d( | |||
img2, window, padding=(window_size // 2), groups=channel) | |||
mu1_sq = mu1.pow(2) | |||
mu2_sq = mu2.pow(2) | |||
mu1_mu2 = mu1 * mu2 | |||
sigma1_sq = F.conv2d( | |||
img1 * img1, window, padding=(window_size // 2), | |||
groups=channel) - mu1_sq | |||
sigma2_sq = F.conv2d( | |||
img2 * img2, window, padding=(window_size // 2), | |||
groups=channel) - mu2_sq | |||
sigma12 = F.conv2d( | |||
img1 * img2, window, padding=(window_size // 2), | |||
groups=channel) - mu1_mu2 | |||
C1 = 0.01**2 | |||
C2 = 0.03**2 | |||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ | |||
((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |||
if size_average: | |||
return ssim_map.mean() | |||
return ssim_map.mean(1).mean(1).mean(1) | |||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |||
missing_keys, unexpected_keys, error_msgs): | |||
return | |||
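# Editor's sketch of how the SSIM module above behaves: NCHW float tensors,
# with C1/C2 hard-coded for a [0, 1] dynamic range; identical inputs give a
# per-image score of ~1.0 when size_average=False.
def _ssim_module_smoke_test():
    ssim = SSIM(window_size=11, size_average=False).eval()
    x = torch.rand(2, 3, 64, 64)
    scores = ssim(x, x)  # shape (2,): one score per image
    assert torch.allclose(scores, torch.ones(2), atol=1e-4)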
@METRICS.register_module( | |||
group_key=default_group, module_name=Metrics.image_inpainting_metric) | |||
class ImageInpaintingMetric(Metric): | |||
"""The metric computation class for image inpainting classes. | |||
""" | |||
def __init__(self): | |||
self.preds = [] | |||
self.targets = [] | |||
self.SSIM = SSIM(window_size=11, size_average=False).eval() | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
self.FID = FIDScore().to(device) | |||
def add(self, outputs: Dict, inputs: Dict): | |||
pred = outputs['inpainted'] | |||
target = inputs['image'] | |||
self.preds.append(torch_nested_detach(pred)) | |||
self.targets.append(torch_nested_detach(target)) | |||
def evaluate(self): | |||
ssim_list = [] | |||
for (pred, target) in zip(self.preds, self.targets): | |||
ssim_list.append(self.SSIM(pred, target)) | |||
self.FID(pred, target) | |||
ssim_list = torch_nested_numpify(ssim_list) | |||
fid = self.FID.get_value() | |||
return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} |
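# Editor's sketch of the intended usage (shapes and key values illustrative;
# FIDScore downloads/loads InceptionV3 on first construction): the metric
# accumulates NCHW prediction/target pairs, then reports mean SSIM plus
# corpus-level FID.
#
#   metric = ImageInpaintingMetric()
#   metric.add({'inpainted': pred_batch}, {'image': target_batch})
#   scores = metric.evaluate()  # {MetricKeys.SSIM: ..., MetricKeys.FID: ...}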
@@ -1,5 +1,8 @@ | |||
# Part of the implementation is borrowed and modified from BasicSR, publicly available at | |||
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | |||
from typing import Dict | |||
import cv2 | |||
import numpy as np | |||
from modelscope.metainfo import Metrics | |||
@@ -35,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): | |||
def add(self, outputs: Dict, inputs: Dict): | |||
ground_truths = outputs['target'] | |||
eval_results = outputs['pred'] | |||
self.preds.extend(eval_results) | |||
self.targets.extend(ground_truths) | |||
@@ -34,17 +34,24 @@ class TokenClassificationMetric(Metric): | |||
self.labels.append( | |||
torch_nested_numpify(torch_nested_detach(ground_truths))) | |||
def __init__(self, return_entity_level_metrics=False, *args, **kwargs): | |||
def __init__(self, | |||
return_entity_level_metrics=False, | |||
label2id=None, | |||
*args, | |||
**kwargs): | |||
super().__init__(*args, **kwargs) | |||
self.return_entity_level_metrics = return_entity_level_metrics | |||
self.preds = [] | |||
self.labels = [] | |||
self.label2id = label2id | |||
def evaluate(self): | |||
self.id2label = { | |||
id: label | |||
for label, id in self.trainer.label2id.items() | |||
} | |||
label2id = self.label2id | |||
if label2id is None: | |||
assert hasattr(self, 'trainer') | |||
label2id = self.trainer.label2id | |||
self.id2label = {id: label for label, id in label2id.items()} | |||
self.preds = np.concatenate(self.preds, axis=0) | |||
self.labels = np.concatenate(self.labels, axis=0) | |||
predictions = np.argmax(self.preds, axis=-1) | |||
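# Editor's note (sketch): with the new label2id argument the metric no longer
# depends on an attached trainer; the id2label inversion above is simply:
def _id2label_example():
    label2id = {'O': 0, 'B-PER': 1, 'I-PER': 2}  # hypothetical label map
    id2label = {i: label for label, i in label2id.items()}
    assert id2label[1] == 'B-PER'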
@@ -1,3 +1,6 @@ | |||
# Part of the implementation is borrowed and modified from PGL-SUM, | |||
# publicly available at https://github.com/e-apostolidis/PGL-SUM | |||
from typing import Dict | |||
import numpy as np | |||
@@ -1,3 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict | |||
@@ -1,15 +1,14 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Dict | |||
import torch | |||
from typing import Dict, Optional | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.audio.audio_utils import update_conf | |||
from modelscope.utils.constant import Tasks | |||
from .fsmn_sele_v2 import FSMNSeleNetV2 | |||
@@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
MODEL_TXT = 'model.txt' | |||
SC_CONFIG = 'sound_connect.conf' | |||
SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
def __init__(self, | |||
model_dir: str, | |||
training: Optional[bool] = False, | |||
*args, | |||
**kwargs): | |||
"""initialize the dfsmn model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
model_bin_file = os.path.join(model_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
self._model = None | |||
if os.path.exists(model_bin_file): | |||
kwargs.pop('device') | |||
self._model = FSMNSeleNetV2(*args, **kwargs) | |||
checkpoint = torch.load(model_bin_file) | |||
self._model.load_state_dict(checkpoint, strict=False) | |||
self._sc = None | |||
if os.path.exists(model_txt_file): | |||
with open(sc_config_file) as f: | |||
lines = f.readlines() | |||
with open(sc_config_file, 'w') as f: | |||
for line in lines: | |||
if self.SC_CONF_ITEM_KWS_MODEL in line: | |||
line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||
model_txt_file) | |||
f.write(line) | |||
import py_sound_connect | |||
self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
self.size_in = self._sc.bytesPerBlockIn() | |||
self.size_out = self._sc.bytesPerBlockOut() | |||
if self._model is None and self._sc is None: | |||
raise Exception( | |||
f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||
) | |||
if training: | |||
self.model = FSMNSeleNetV2(*args, **kwargs) | |||
else: | |||
sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
self._sc = None | |||
if os.path.exists(model_txt_file): | |||
conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||
update_conf(sc_config_file, sc_config_file, conf_dict) | |||
import py_sound_connect | |||
self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
self.size_in = self._sc.bytesPerBlockIn() | |||
self.size_out = self._sc.bytesPerBlockOut() | |||
else: | |||
raise Exception( | |||
f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||
) | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
... | |||
return self.model.forward(input) | |||
def forward_decode(self, data: bytes): | |||
result = {'pcm': self._sc.process(data, self.size_out)} | |||
@@ -1,3 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict | |||
@@ -2,6 +2,7 @@ | |||
import os | |||
import pickle as pkl | |||
from threading import Lock | |||
import json | |||
import numpy as np | |||
@@ -27,6 +28,7 @@ class Voice: | |||
self.__am_config = AttrDict(**am_config) | |||
self.__voc_config = AttrDict(**voc_config) | |||
self.__model_loaded = False | |||
self.__lock = Lock() | |||
if 'am' not in self.__am_config: | |||
raise TtsModelConfigurationException( | |||
'modelscope error: am configuration invalid') | |||
@@ -71,34 +73,35 @@ class Voice: | |||
self.__generator.remove_weight_norm() | |||
def __am_forward(self, symbol_seq): | |||
with torch.no_grad(): | |||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
symbol_seq) | |||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
self.__device) | |||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
self.__device) | |||
inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to( | |||
self.__device) | |||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
self.__device) | |||
inputs_ling = torch.stack( | |||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
dim=-1).unsqueeze(0) | |||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_len = torch.zeros(1).to(self.__device).long( | |||
) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
inputs_spk[:, :-1], inputs_len) | |||
postnet_outputs = res['postnet_outputs'] | |||
LR_length_rounded = res['LR_length_rounded'] | |||
valid_length = int(LR_length_rounded[0].item()) | |||
postnet_outputs = postnet_outputs[ | |||
0, :valid_length, :].cpu().numpy() | |||
return postnet_outputs | |||
with self.__lock: | |||
with torch.no_grad(): | |||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( | |||
symbol_seq) | |||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( | |||
self.__device) | |||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( | |||
self.__device) | |||
inputs_syllable = torch.from_numpy( | |||
inputs_feat_lst[2]).long().to(self.__device) | |||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( | |||
self.__device) | |||
inputs_ling = torch.stack( | |||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws], | |||
dim=-1).unsqueeze(0) | |||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( | |||
self.__device).unsqueeze(0) | |||
inputs_len = torch.zeros(1).to(self.__device).long( | |||
) + inputs_emo.size(1) - 1 # minus 1 for "~" | |||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], | |||
inputs_spk[:, :-1], inputs_len) | |||
postnet_outputs = res['postnet_outputs'] | |||
LR_length_rounded = res['LR_length_rounded'] | |||
valid_length = int(LR_length_rounded[0].item()) | |||
postnet_outputs = postnet_outputs[ | |||
0, :valid_length, :].cpu().numpy() | |||
return postnet_outputs | |||
def __vocoder_forward(self, melspec): | |||
dim0 = list(melspec.shape)[-1] | |||
@@ -118,14 +121,15 @@ class Voice: | |||
return audio | |||
def forward(self, symbol_seq): | |||
if not self.__model_loaded: | |||
torch.manual_seed(self.__am_config.seed) | |||
if torch.cuda.is_available(): | |||
with self.__lock: | |||
if not self.__model_loaded: | |||
torch.manual_seed(self.__am_config.seed) | |||
self.__device = torch.device('cuda') | |||
else: | |||
self.__device = torch.device('cpu') | |||
self.__load_am() | |||
self.__load_vocoder() | |||
self.__model_loaded = True | |||
if torch.cuda.is_available(): | |||
torch.manual_seed(self.__am_config.seed) | |||
self.__device = torch.device('cuda') | |||
else: | |||
self.__device = torch.device('cpu') | |||
self.__load_am() | |||
self.__load_vocoder() | |||
self.__model_loaded = True | |||
return self.__vocoder_forward(self.__am_forward(symbol_seq)) |
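# Editor's sketch: the new __lock makes lazy model loading and AM inference
# safe under concurrent callers. Without it, two threads could both observe
# __model_loaded == False and initialize the models twice, or interleave
# __am_forward calls on the same module. Hypothetical stress test (Voice
# construction and symbol_seq elided):
#
#   from concurrent.futures import ThreadPoolExecutor
#   with ThreadPoolExecutor(max_workers=4) as pool:
#       wavs = list(pool.map(voice.forward, [symbol_seq] * 8))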
@@ -5,11 +5,11 @@ from abc import ABC, abstractmethod | |||
from typing import Any, Callable, Dict, List, Optional, Union | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models.builder import build_model | |||
from modelscope.utils.checkpoint import save_pretrained | |||
from modelscope.models.builder import MODELS, build_model | |||
from modelscope.utils.checkpoint import save_checkpoint, save_pretrained | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
from modelscope.utils.device import device_placement, verify_device | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks | |||
from modelscope.utils.device import verify_device | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@@ -66,7 +66,6 @@ class Model(ABC): | |||
revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
cfg_dict: Config = None, | |||
device: str = None, | |||
*model_args, | |||
**kwargs): | |||
""" Instantiate a model from local directory or remote model repo. Note | |||
that when loading from remote, the model revision can be specified. | |||
@@ -90,11 +89,11 @@ class Model(ABC): | |||
cfg = Config.from_file( | |||
osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||
task_name = cfg.task | |||
if 'task' in kwargs: | |||
task_name = kwargs.pop('task') | |||
model_cfg = cfg.model | |||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
model_cfg.type = model_cfg.model_type | |||
model_cfg.model_dir = local_model_dir | |||
for k, v in kwargs.items(): | |||
model_cfg[k] = v | |||
@@ -109,15 +108,19 @@ class Model(ABC): | |||
# dynamically add pipeline info to model for pipeline inference | |||
if hasattr(cfg, 'pipeline'): | |||
model.pipeline = cfg.pipeline | |||
if not hasattr(model, 'cfg'): | |||
model.cfg = cfg | |||
return model | |||
def save_pretrained(self, | |||
target_folder: Union[str, os.PathLike], | |||
save_checkpoint_names: Union[str, List[str]] = None, | |||
save_function: Callable = None, | |||
save_function: Callable = save_checkpoint, | |||
config: Optional[dict] = None, | |||
**kwargs): | |||
"""save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||
"""save the pretrained model, its configuration and other related files to a directory, | |||
so that it can be re-loaded | |||
Args: | |||
target_folder (Union[str, os.PathLike]): | |||
@@ -133,5 +136,10 @@ class Model(ABC): | |||
The config for the configuration.json, might not be identical with model.config | |||
""" | |||
if config is None and hasattr(self, 'cfg'): | |||
config = self.cfg | |||
assert config is not None, 'Cannot save the model because the model config is empty.' | |||
if isinstance(config, Config): | |||
config = config.to_dict() | |||
save_pretrained(self, target_folder, save_checkpoint_names, | |||
save_function, config, **kwargs) |
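# Editor's sketch of the intended call pattern (model id and folder names are
# hypothetical): save_function now defaults to save_checkpoint, and config
# falls back to self.cfg when omitted.
#
#   model = Model.from_pretrained('damo/some-model-id')
#   model.save_pretrained('./exported')  # uses save_checkpoint + model.cfg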
@@ -1,12 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.config import ConfigDict | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule | |||
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg | |||
MODELS = Registry('models') | |||
BACKBONES = Registry('backbones') | |||
BACKBONES = MODELS | |||
HEADS = Registry('heads') | |||
modules = LazyImportModule.AST_INDEX[INDEX_KEY] | |||
for module_index in list(modules.keys()): | |||
if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': | |||
modules[(MODELS.name.upper(), module_index[1], | |||
module_index[2])] = modules[module_index] | |||
def build_model(cfg: ConfigDict, | |||
task_name: str = None, | |||
@@ -23,30 +31,27 @@ def build_model(cfg: ConfigDict, | |||
cfg, MODELS, group_key=task_name, default_args=default_args) | |||
def build_backbone(cfg: ConfigDict, | |||
field: str = None, | |||
default_args: dict = None): | |||
def build_backbone(cfg: ConfigDict, default_args: dict = None): | |||
""" build backbone given backbone config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for backbone object. | |||
field (str, optional): field, such as CV, NLP's backbone | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, BACKBONES, group_key=field, default_args=default_args) | |||
cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) | |||
def build_head(cfg: ConfigDict, | |||
group_key: str = None, | |||
task_name: str = None, | |||
default_args: dict = None): | |||
""" build head given config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for head object. | |||
task_name (str, optional): task name, refer to | |||
:obj:`Tasks` for more details | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
if group_key is None: | |||
group_key = cfg[TYPE_NAME] | |||
return build_from_cfg( | |||
cfg, HEADS, group_key=group_key, default_args=default_args) | |||
cfg, HEADS, group_key=task_name, default_args=default_args) |
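# Editor's sketch: after this change, backbones are built under the fixed
# Tasks.backbone group while heads are grouped by task name. Hypothetical
# registration and build (names are illustrative only):
#
#   @HEADS.register_module(group_key=Tasks.text_classification,
#                          module_name='my-head')
#   class MyHead: ...
#
#   head = build_head(dict(type='my-head'),
#                     task_name=Tasks.text_classification)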
@@ -4,14 +4,16 @@ | |||
from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
body_3d_keypoints, cartoon, cmdssl_video_embedding, | |||
crowd_counting, face_2d_keypoints, face_detection, | |||
face_generation, image_classification, image_color_enhance, | |||
image_colorization, image_denoise, image_instance_segmentation, | |||
face_generation, human_wholebody_keypoint, image_classification, | |||
image_color_enhance, image_colorization, image_denoise, | |||
image_inpainting, image_instance_segmentation, | |||
image_panoptic_segmentation, image_portrait_enhancement, | |||
image_reid_person, image_semantic_segmentation, | |||
image_to_image_generation, image_to_image_translation, | |||
movie_scene_segmentation, object_detection, | |||
product_retrieval_embedding, realtime_object_detection, | |||
salient_detection, shop_segmentation, super_resolution, | |||
referring_video_object_segmentation, salient_detection, | |||
shop_segmentation, super_resolution, | |||
video_single_object_tracking, video_summarization, virual_tryon) | |||
# yapf: enable |
@@ -4,6 +4,7 @@ import os | |||
import os.path as osp | |||
import shutil | |||
import subprocess | |||
import uuid | |||
import cv2 | |||
import numpy as np | |||
@@ -84,7 +85,9 @@ class ActionDetONNX(Model): | |||
def forward_video(self, video_name, scale): | |||
min_size, max_size = self._get_sizes(scale) | |||
tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4]) | |||
tmp_dir = osp.join( | |||
self.tmp_dir, | |||
str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4]) | |||
if osp.exists(tmp_dir): | |||
shutil.rmtree(tmp_dir) | |||
os.makedirs(tmp_dir) | |||
@@ -110,6 +113,7 @@ class ActionDetONNX(Model): | |||
len(frame_names) * self.temporal_stride, | |||
self.temporal_stride)) | |||
batch_imgs = [self.parse_frames(names) for names in frame_names] | |||
shutil.rmtree(tmp_dir) | |||
N, _, T, H, W = batch_imgs[0].shape | |||
scale_min = min_size / min(H, W) | |||
@@ -128,7 +132,6 @@ class ActionDetONNX(Model): | |||
'timestamp': t, | |||
'actions': res | |||
} for t, res in zip(timestamp, results)] | |||
shutil.rmtree(tmp_dir) | |||
return results | |||
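# Editor's note: prefixing tmp_dir with uuid1() lets concurrent forward_video
# calls on the same video name use distinct scratch directories, and moving
# shutil.rmtree up to right after parse_frames releases the extracted frames
# as soon as they are in memory instead of after inference completes.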
def forward(self, video_name): | |||
@@ -1,3 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict, Optional, Union | |||
@@ -1,10 +1,10 @@ | |||
# ------------------------------------------------------------------------------ | |||
# Copyright (c) Microsoft | |||
# Licensed under the MIT License. | |||
# Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
# Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
# ------------------------------------------------------------------------------ | |||
""" | |||
Copyright (c) Microsoft | |||
Licensed under the MIT License. | |||
Written by Bin Xiao (Bin.Xiao@microsoft.com) | |||
Modified by Ke Sun (sunk@mail.ustc.edu.cn) | |||
https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py | |||
""" | |||
import functools | |||
import logging | |||
@@ -8,12 +8,14 @@ if TYPE_CHECKING: | |||
from .mtcnn import MtcnnFaceDetector | |||
from .retinaface import RetinaFaceDetection | |||
from .ulfd_slim import UlfdFaceDetector | |||
from .scrfd import ScrfdDetect | |||
else: | |||
_import_structure = { | |||
'ulfd_slim': ['UlfdFaceDetector'], | |||
'retinaface': ['RetinaFaceDetection'], | |||
'mtcnn': ['MtcnnFaceDetector'], | |||
'mogface': ['MogFaceDetector'] | |||
'mogface': ['MogFaceDetector'], | |||
'scrfd': ['ScrfdDetect'] | |||
} | |||
import sys | |||
@@ -1,189 +0,0 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
""" | |||
import numpy as np | |||
from mmdet.datasets.builder import PIPELINES | |||
from numpy import random | |||
@PIPELINES.register_module() | |||
class RandomSquareCrop(object): | |||
"""Random crop the image & bboxes, the cropped patches have minimum IoU | |||
requirement with original image & bboxes, the IoU threshold is randomly | |||
selected from min_ious. | |||
Args: | |||
min_ious (tuple): minimum IoU threshold for all intersections with | |||
bounding boxes | |||
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
where a >= min_crop_size). | |||
Note: | |||
The keys for bboxes, labels and masks should be paired. That is, \ | |||
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
""" | |||
def __init__(self, | |||
crop_ratio_range=None, | |||
crop_choice=None, | |||
bbox_clip_border=True): | |||
self.crop_ratio_range = crop_ratio_range | |||
self.crop_choice = crop_choice | |||
self.bbox_clip_border = bbox_clip_border | |||
assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
if self.crop_ratio_range is not None: | |||
self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
self.bbox2label = { | |||
'gt_bboxes': 'gt_labels', | |||
'gt_bboxes_ignore': 'gt_labels_ignore' | |||
} | |||
self.bbox2mask = { | |||
'gt_bboxes': 'gt_masks', | |||
'gt_bboxes_ignore': 'gt_masks_ignore' | |||
} | |||
def __call__(self, results): | |||
"""Call function to crop images and bounding boxes with minimum IoU | |||
constraint. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Result dict with images and bounding boxes cropped, \ | |||
'img_shape' key is updated. | |||
""" | |||
if 'img_fields' in results: | |||
assert results['img_fields'] == ['img'], \ | |||
'Only single img_fields is allowed' | |||
img = results['img'] | |||
assert 'bbox_fields' in results | |||
assert 'gt_bboxes' in results | |||
boxes = results['gt_bboxes'] | |||
h, w, c = img.shape | |||
scale_retry = 0 | |||
if self.crop_ratio_range is not None: | |||
max_scale = self.crop_ratio_max | |||
else: | |||
max_scale = np.amax(self.crop_choice) | |||
while True: | |||
scale_retry += 1 | |||
if scale_retry == 1 or max_scale > 1.0: | |||
if self.crop_ratio_range is not None: | |||
scale = np.random.uniform(self.crop_ratio_min, | |||
self.crop_ratio_max) | |||
elif self.crop_choice is not None: | |||
scale = np.random.choice(self.crop_choice) | |||
else: | |||
scale = scale * 1.2 | |||
for i in range(250): | |||
short_side = min(w, h) | |||
cw = int(scale * short_side) | |||
ch = cw | |||
# TODO +1 | |||
if w == cw: | |||
left = 0 | |||
elif w > cw: | |||
left = random.randint(0, w - cw) | |||
else: | |||
left = random.randint(w - cw, 0) | |||
if h == ch: | |||
top = 0 | |||
elif h > ch: | |||
top = random.randint(0, h - ch) | |||
else: | |||
top = random.randint(h - ch, 0) | |||
patch = np.array( | |||
(int(left), int(top), int(left + cw), int(top + ch)), | |||
dtype=np.int) | |||
# center of boxes should inside the crop img | |||
# only adjust boxes and instance masks when the gt is not empty | |||
# adjust boxes | |||
def is_center_of_bboxes_in_patch(boxes, patch): | |||
# TODO >= | |||
center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
mask = \ | |||
((center[:, 0] > patch[0]) | |||
* (center[:, 1] > patch[1]) | |||
* (center[:, 0] < patch[2]) | |||
* (center[:, 1] < patch[3])) | |||
return mask | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
if not mask.any(): | |||
continue | |||
for key in results.get('bbox_fields', []): | |||
boxes = results[key].copy() | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
boxes = boxes[mask] | |||
if self.bbox_clip_border: | |||
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
boxes -= np.tile(patch[:2], 2) | |||
results[key] = boxes | |||
# labels | |||
label_key = self.bbox2label.get(key) | |||
if label_key in results: | |||
results[label_key] = results[label_key][mask] | |||
# keypoints field | |||
if key == 'gt_bboxes': | |||
for kps_key in results.get('keypoints_fields', []): | |||
keypointss = results[kps_key].copy() | |||
keypointss = keypointss[mask, :, :] | |||
if self.bbox_clip_border: | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
max=patch[2:]) | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
min=patch[:2]) | |||
keypointss[:, :, 0] -= patch[0] | |||
keypointss[:, :, 1] -= patch[1] | |||
results[kps_key] = keypointss | |||
# mask fields | |||
mask_key = self.bbox2mask.get(key) | |||
if mask_key in results: | |||
results[mask_key] = results[mask_key][mask.nonzero() | |||
[0]].crop(patch) | |||
# adjust the img no matter whether the gt is empty before crop | |||
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
patch_from = patch.copy() | |||
patch_from[0] = max(0, patch_from[0]) | |||
patch_from[1] = max(0, patch_from[1]) | |||
patch_from[2] = min(img.shape[1], patch_from[2]) | |||
patch_from[3] = min(img.shape[0], patch_from[3]) | |||
patch_to = patch.copy() | |||
patch_to[0] = max(0, patch_to[0] * -1) | |||
patch_to[1] = max(0, patch_to[1] * -1) | |||
patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
rimg[patch_to[1]:patch_to[3], | |||
patch_to[0]:patch_to[2], :] = img[ | |||
patch_from[1]:patch_from[3], | |||
patch_from[0]:patch_from[2], :] | |||
img = rimg | |||
results['img'] = img | |||
results['img_shape'] = img.shape | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(min_ious={self.min_iou}, ' | |||
repr_str += f'crop_size={self.crop_size})' | |||
return repr_str |
@@ -1,3 +1,5 @@ | |||
# The implementation is based on MogFace, available at | |||
# https://github.com/damo-cv/MogFace | |||
import os | |||
import cv2 | |||
@@ -0,0 +1,2 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .scrfd_detect import ScrfdDetect |
@@ -6,7 +6,7 @@ import numpy as np | |||
import torch | |||
def bbox2result(bboxes, labels, num_classes, kps=None): | |||
def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): | |||
"""Convert detection results to a list of numpy arrays. | |||
Args: | |||
@@ -17,7 +17,7 @@ def bbox2result(bboxes, labels, num_classes, kps=None): | |||
Returns: | |||
list(ndarray): bbox results of each class | |||
""" | |||
bbox_len = 5 if kps is None else 5 + 10 # if has kps, add 10 kps into bbox | |||
    bbox_len = 5 if kps is None else 5 + num_kps * 2  # with kps, append num_kps*2 values per bbox
if bboxes.shape[0] == 0: | |||
return [ | |||
np.zeros((0, bbox_len), dtype=np.float32) |
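# Editor's note: the per-box vector is [x1, y1, x2, y2, score] plus two
# coordinates per keypoint, so e.g.:
#   num_kps=5  -> bbox_len = 5 + 10 = 15 (the old hard-coded face layout)
#   num_kps=17 -> bbox_len = 5 + 34 = 39 (e.g. a COCO-style body skeleton)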
@@ -17,6 +17,7 @@ def multiclass_nms(multi_bboxes, | |||
Args: | |||
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | |||
multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) | |||
multi_scores (Tensor): shape (n, #class), where the last column | |||
contains scores of the background class, but this will be ignored. | |||
score_thr (float): bbox threshold, bboxes with scores lower than it | |||
@@ -36,16 +37,18 @@ def multiclass_nms(multi_bboxes, | |||
num_classes = multi_scores.size(1) - 1 | |||
# exclude background category | |||
kps = None | |||
if multi_kps is not None: | |||
num_kps = int((multi_kps.shape[1] / num_classes) / 2) | |||
if multi_bboxes.shape[1] > 4: | |||
bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) | |||
if multi_kps is not None: | |||
kps = multi_kps.view(multi_scores.size(0), -1, 10) | |||
kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) | |||
else: | |||
bboxes = multi_bboxes[:, None].expand( | |||
multi_scores.size(0), num_classes, 4) | |||
if multi_kps is not None: | |||
kps = multi_kps[:, None].expand( | |||
multi_scores.size(0), num_classes, 10) | |||
multi_scores.size(0), num_classes, num_kps * 2) | |||
scores = multi_scores[:, :-1] | |||
if score_factors is not None: | |||
@@ -56,7 +59,7 @@ def multiclass_nms(multi_bboxes, | |||
bboxes = bboxes.reshape(-1, 4) | |||
if kps is not None: | |||
kps = kps.reshape(-1, 10) | |||
kps = kps.reshape(-1, num_kps * 2) | |||
scores = scores.reshape(-1) | |||
labels = labels.reshape(-1) | |||
@@ -2,6 +2,12 @@ | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines | |||
""" | |||
from .auto_augment import RotateV2 | |||
from .formating import DefaultFormatBundleV2 | |||
from .loading import LoadAnnotationsV2 | |||
from .transforms import RandomSquareCrop | |||
__all__ = ['RandomSquareCrop'] | |||
__all__ = [ | |||
'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', | |||
'DefaultFormatBundleV2' | |||
] |
@@ -0,0 +1,271 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py | |||
""" | |||
import copy | |||
import cv2 | |||
import mmcv | |||
import numpy as np | |||
from mmdet.datasets.builder import PIPELINES | |||
_MAX_LEVEL = 10 | |||
def level_to_value(level, max_value): | |||
"""Map from level to values based on max_value.""" | |||
return (level / _MAX_LEVEL) * max_value | |||
def random_negative(value, random_negative_prob): | |||
"""Randomly negate value based on random_negative_prob.""" | |||
return -value if np.random.rand() < random_negative_prob else value | |||
def bbox2fields(): | |||
"""The key correspondence from bboxes to labels, masks and | |||
segmentations.""" | |||
bbox2label = { | |||
'gt_bboxes': 'gt_labels', | |||
'gt_bboxes_ignore': 'gt_labels_ignore' | |||
} | |||
bbox2mask = { | |||
'gt_bboxes': 'gt_masks', | |||
'gt_bboxes_ignore': 'gt_masks_ignore' | |||
} | |||
bbox2seg = { | |||
'gt_bboxes': 'gt_semantic_seg', | |||
} | |||
return bbox2label, bbox2mask, bbox2seg | |||
@PIPELINES.register_module() | |||
class RotateV2(object): | |||
"""Apply Rotate Transformation to image (and its corresponding bbox, mask, | |||
segmentation). | |||
Args: | |||
level (int | float): The level should be in range (0,_MAX_LEVEL]. | |||
scale (int | float): Isotropic scale factor. Same in | |||
``mmcv.imrotate``. | |||
center (int | float | tuple[float]): Center point (w, h) of the | |||
rotation in the source image. If None, the center of the | |||
image will be used. Same in ``mmcv.imrotate``. | |||
        img_fill_val (int | float | tuple): The fill value for image border.
            If float, the same value will be used for all three channels of
            the image. If tuple, it should have 3 elements (i.e. equal to
            the number of image channels).
        seg_ignore_label (int): The fill value used for segmentation map.
            Note this value must equal ``ignore_label`` in ``semantic_head``
            of the corresponding config. Default 255.
        prob (float): The probability of performing the transformation,
            in range [0, 1].
        max_rotate_angle (int | float): The maximum angle for the rotate
            transformation.
        random_negative_prob (float): The probability of negating the
            rotation angle.
""" | |||
def __init__(self, | |||
level, | |||
scale=1, | |||
center=None, | |||
img_fill_val=128, | |||
seg_ignore_label=255, | |||
prob=0.5, | |||
max_rotate_angle=30, | |||
random_negative_prob=0.5): | |||
assert isinstance(level, (int, float)), \ | |||
f'The level must be type int or float. got {type(level)}.' | |||
assert 0 <= level <= _MAX_LEVEL, \ | |||
f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.' | |||
assert isinstance(scale, (int, float)), \ | |||
f'The scale must be type int or float. got type {type(scale)}.' | |||
if isinstance(center, (int, float)): | |||
center = (center, center) | |||
elif isinstance(center, tuple): | |||
assert len(center) == 2, 'center with type tuple must have '\ | |||
f'2 elements. got {len(center)} elements.' | |||
else: | |||
assert center is None, 'center must be None or type int, '\ | |||
f'float or tuple, got type {type(center)}.' | |||
if isinstance(img_fill_val, (float, int)): | |||
img_fill_val = tuple([float(img_fill_val)] * 3) | |||
elif isinstance(img_fill_val, tuple): | |||
assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ | |||
f'have 3 elements. got {len(img_fill_val)}.' | |||
img_fill_val = tuple([float(val) for val in img_fill_val]) | |||
else: | |||
raise ValueError( | |||
'img_fill_val must be float or tuple with 3 elements.') | |||
assert np.all([0 <= val <= 255 for val in img_fill_val]), \ | |||
'all elements of img_fill_val should between range [0,255]. '\ | |||
f'got {img_fill_val}.' | |||
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ | |||
f'got {prob}.' | |||
assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ | |||
f'should be type int or float. got type {type(max_rotate_angle)}.' | |||
self.level = level | |||
self.scale = scale | |||
# Rotation angle in degrees. Positive values mean | |||
# clockwise rotation. | |||
self.angle = level_to_value(level, max_rotate_angle) | |||
self.center = center | |||
self.img_fill_val = img_fill_val | |||
self.seg_ignore_label = seg_ignore_label | |||
self.prob = prob | |||
self.max_rotate_angle = max_rotate_angle | |||
self.random_negative_prob = random_negative_prob | |||
def _rotate_img(self, results, angle, center=None, scale=1.0): | |||
"""Rotate the image. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
angle (float): Rotation angle in degrees, positive values | |||
mean clockwise rotation. Same in ``mmcv.imrotate``. | |||
center (tuple[float], optional): Center point (w, h) of the | |||
rotation. Same in ``mmcv.imrotate``. | |||
scale (int | float): Isotropic scale factor. Same in | |||
``mmcv.imrotate``. | |||
""" | |||
for key in results.get('img_fields', ['img']): | |||
img = results[key].copy() | |||
img_rotated = mmcv.imrotate( | |||
img, angle, center, scale, border_value=self.img_fill_val) | |||
results[key] = img_rotated.astype(img.dtype) | |||
results['img_shape'] = results[key].shape | |||
def _rotate_bboxes(self, results, rotate_matrix): | |||
"""Rotate the bboxes.""" | |||
h, w, c = results['img_shape'] | |||
for key in results.get('bbox_fields', []): | |||
min_x, min_y, max_x, max_y = np.split( | |||
results[key], results[key].shape[-1], axis=-1) | |||
coordinates = np.stack([[min_x, min_y], [max_x, min_y], | |||
[min_x, max_y], | |||
[max_x, max_y]]) # [4, 2, nb_bbox, 1] | |||
# pad 1 to convert from format [x, y] to homogeneous | |||
# coordinates format [x, y, 1] | |||
coordinates = np.concatenate( | |||
(coordinates, | |||
np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), | |||
axis=1) # [4, 3, nb_bbox, 1] | |||
coordinates = coordinates.transpose( | |||
(2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] | |||
rotated_coords = np.matmul(rotate_matrix, | |||
coordinates) # [nb_bbox, 4, 2, 1] | |||
rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] | |||
min_x, min_y = np.min( | |||
rotated_coords[:, :, 0], axis=1), np.min( | |||
rotated_coords[:, :, 1], axis=1) | |||
max_x, max_y = np.max( | |||
rotated_coords[:, :, 0], axis=1), np.max( | |||
rotated_coords[:, :, 1], axis=1) | |||
results[key] = np.stack([min_x, min_y, max_x, max_y], | |||
axis=-1).astype(results[key].dtype) | |||
def _rotate_keypoints90(self, results, angle): | |||
"""Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" | |||
if angle not in [-90, 90, 180, -180 | |||
] or self.scale != 1 or self.center is not None: | |||
return | |||
for key in results.get('keypoints_fields', []): | |||
k = results[key] | |||
if angle == 90: | |||
w, h, c = results['img'].shape | |||
new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) | |||
elif angle == -90: | |||
w, h, c = results['img'].shape | |||
new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) | |||
else: | |||
h, w, c = results['img'].shape | |||
new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], | |||
axis=-1) | |||
            # a keypoint is invalid if its third value is -1
kps_invalid = new[..., -1][:, -1] == -1 | |||
new[kps_invalid] = np.zeros(new.shape[1:]) - 1 | |||
results[key] = new | |||
def _rotate_masks(self, | |||
results, | |||
angle, | |||
center=None, | |||
scale=1.0, | |||
fill_val=0): | |||
"""Rotate the masks.""" | |||
h, w, c = results['img_shape'] | |||
for key in results.get('mask_fields', []): | |||
masks = results[key] | |||
results[key] = masks.rotate((h, w), angle, center, scale, fill_val) | |||
def _rotate_seg(self, | |||
results, | |||
angle, | |||
center=None, | |||
scale=1.0, | |||
fill_val=255): | |||
"""Rotate the segmentation map.""" | |||
for key in results.get('seg_fields', []): | |||
seg = results[key].copy() | |||
results[key] = mmcv.imrotate( | |||
seg, angle, center, scale, | |||
border_value=fill_val).astype(seg.dtype) | |||
def _filter_invalid(self, results, min_bbox_size=0): | |||
"""Filter bboxes and corresponding masks too small after rotate | |||
augmentation.""" | |||
bbox2label, bbox2mask, _ = bbox2fields() | |||
for key in results.get('bbox_fields', []): | |||
bbox_w = results[key][:, 2] - results[key][:, 0] | |||
bbox_h = results[key][:, 3] - results[key][:, 1] | |||
valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) | |||
valid_inds = np.nonzero(valid_inds)[0] | |||
results[key] = results[key][valid_inds] | |||
# label fields. e.g. gt_labels and gt_labels_ignore | |||
label_key = bbox2label.get(key) | |||
if label_key in results: | |||
results[label_key] = results[label_key][valid_inds] | |||
# mask fields, e.g. gt_masks and gt_masks_ignore | |||
mask_key = bbox2mask.get(key) | |||
if mask_key in results: | |||
results[mask_key] = results[mask_key][valid_inds] | |||
def __call__(self, results): | |||
"""Call function to rotate images, bounding boxes, masks and semantic | |||
segmentation maps. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Rotated results. | |||
""" | |||
if np.random.rand() > self.prob: | |||
return results | |||
h, w = results['img'].shape[:2] | |||
center = self.center | |||
if center is None: | |||
center = ((w - 1) * 0.5, (h - 1) * 0.5) | |||
angle = random_negative(self.angle, self.random_negative_prob) | |||
self._rotate_img(results, angle, center, self.scale) | |||
rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) | |||
self._rotate_bboxes(results, rotate_matrix) | |||
self._rotate_keypoints90(results, angle) | |||
self._rotate_masks(results, angle, center, self.scale, fill_val=0) | |||
self._rotate_seg( | |||
results, angle, center, self.scale, fill_val=self.seg_ignore_label) | |||
self._filter_invalid(results) | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(level={self.level}, ' | |||
repr_str += f'scale={self.scale}, ' | |||
repr_str += f'center={self.center}, ' | |||
repr_str += f'img_fill_val={self.img_fill_val}, ' | |||
repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' | |||
repr_str += f'prob={self.prob}, ' | |||
repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' | |||
repr_str += f'random_negative_prob={self.random_negative_prob})' | |||
return repr_str |
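# Editor's sketch of plugging RotateV2 into an mmdet train pipeline (values
# are illustrative): level=6 with max_rotate_angle=30 samples +/-18 degree
# rotations (6 / 10 * 30), applied with probability 0.5.
#
#   train_pipeline = [
#       ...,
#       dict(type='RotateV2', level=6, img_fill_val=128, prob=0.5,
#            max_rotate_angle=30, random_negative_prob=0.5),
#       ...,
#   ]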
@@ -0,0 +1,113 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py | |||
""" | |||
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
def to_tensor(data): | |||
"""Convert objects of various python types to :obj:`torch.Tensor`. | |||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, | |||
:class:`Sequence`, :class:`int` and :class:`float`. | |||
Args: | |||
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to | |||
be converted. | |||
""" | |||
if isinstance(data, torch.Tensor): | |||
return data | |||
elif isinstance(data, np.ndarray): | |||
return torch.from_numpy(data) | |||
elif isinstance(data, Sequence) and not mmcv.is_str(data): | |||
return torch.tensor(data) | |||
elif isinstance(data, int): | |||
return torch.LongTensor([data]) | |||
elif isinstance(data, float): | |||
return torch.FloatTensor([data]) | |||
else: | |||
raise TypeError(f'type {type(data)} cannot be converted to tensor.') | |||
@PIPELINES.register_module() | |||
class DefaultFormatBundleV2(object): | |||
"""Default formatting bundle. | |||
It simplifies the pipeline of formatting common fields, including "img", | |||
"proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". | |||
These fields are formatted as follows. | |||
- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) | |||
- proposals: (1)to tensor, (2)to DataContainer | |||
- gt_bboxes: (1)to tensor, (2)to DataContainer | |||
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer | |||
- gt_labels: (1)to tensor, (2)to DataContainer | |||
- gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) | |||
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ | |||
(3)to DataContainer (stack=True) | |||
""" | |||
def __call__(self, results): | |||
"""Call function to transform and format common fields in results. | |||
Args: | |||
results (dict): Result dict contains the data to convert. | |||
Returns: | |||
dict: The result dict contains the data that is formatted with \ | |||
default bundle. | |||
""" | |||
if 'img' in results: | |||
img = results['img'] | |||
# add default meta keys | |||
results = self._add_default_meta_keys(results) | |||
if len(img.shape) < 3: | |||
img = np.expand_dims(img, -1) | |||
img = np.ascontiguousarray(img.transpose(2, 0, 1)) | |||
results['img'] = DC(to_tensor(img), stack=True) | |||
for key in [ | |||
'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', | |||
'gt_labels' | |||
]: | |||
if key not in results: | |||
continue | |||
results[key] = DC(to_tensor(results[key])) | |||
if 'gt_masks' in results: | |||
results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) | |||
if 'gt_semantic_seg' in results: | |||
results['gt_semantic_seg'] = DC( | |||
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) | |||
return results | |||
def _add_default_meta_keys(self, results): | |||
"""Add default meta keys. | |||
We set default meta keys including `pad_shape`, `scale_factor` and | |||
`img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and | |||
`Pad` are implemented during the whole pipeline. | |||
Args: | |||
results (dict): Result dict contains the data to convert. | |||
Returns: | |||
results (dict): Updated result dict contains the data to convert. | |||
""" | |||
img = results['img'] | |||
results.setdefault('pad_shape', img.shape) | |||
results.setdefault('scale_factor', 1.0) | |||
num_channels = 1 if len(img.shape) < 3 else img.shape[2] | |||
results.setdefault( | |||
'img_norm_cfg', | |||
dict( | |||
mean=np.zeros(num_channels, dtype=np.float32), | |||
std=np.ones(num_channels, dtype=np.float32), | |||
to_rgb=False)) | |||
return results | |||
def __repr__(self): | |||
return self.__class__.__name__ |
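# Editor's sketch: the bundle turns an HWC uint8 image into a CHW tensor
# wrapped in a DataContainer and fills in default meta keys; illustrative use:
#
#   bundle = DefaultFormatBundleV2()
#   results = bundle(dict(img=np.zeros((480, 640, 3), dtype=np.uint8)))
#   results['img'].data.shape  # -> torch.Size([3, 480, 640])
#   sorted(results)            # now includes 'pad_shape', 'scale_factor'
#                              # and 'img_norm_cfg'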
@@ -0,0 +1,225 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py | |||
""" | |||
import os.path as osp | |||
import numpy as np | |||
import pycocotools.mask as maskUtils | |||
from mmdet.core import BitmapMasks, PolygonMasks | |||
from mmdet.datasets.builder import PIPELINES | |||
@PIPELINES.register_module() | |||
class LoadAnnotationsV2(object): | |||
"""Load mutiple types of annotations. | |||
Args: | |||
with_bbox (bool): Whether to parse and load the bbox annotation. | |||
Default: True. | |||
with_label (bool): Whether to parse and load the label annotation. | |||
Default: True. | |||
with_keypoints (bool): Whether to parse and load the keypoints annotation. | |||
Default: False. | |||
with_mask (bool): Whether to parse and load the mask annotation. | |||
Default: False. | |||
with_seg (bool): Whether to parse and load the semantic segmentation | |||
annotation. Default: False. | |||
poly2mask (bool): Whether to convert the instance masks from polygons | |||
to bitmaps. Default: True. | |||
file_client_args (dict): Arguments to instantiate a FileClient. | |||
See :class:`mmcv.fileio.FileClient` for details. | |||
Defaults to ``dict(backend='disk')``. | |||
""" | |||
def __init__(self, | |||
with_bbox=True, | |||
with_label=True, | |||
with_keypoints=False, | |||
with_mask=False, | |||
with_seg=False, | |||
poly2mask=True, | |||
file_client_args=dict(backend='disk')): | |||
self.with_bbox = with_bbox | |||
self.with_label = with_label | |||
self.with_keypoints = with_keypoints | |||
self.with_mask = with_mask | |||
self.with_seg = with_seg | |||
self.poly2mask = poly2mask | |||
self.file_client_args = file_client_args.copy() | |||
self.file_client = None | |||
def _load_bboxes(self, results): | |||
"""Private function to load bounding box annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
Returns: | |||
dict: The dict contains loaded bounding box annotations. | |||
""" | |||
ann_info = results['ann_info'] | |||
results['gt_bboxes'] = ann_info['bboxes'].copy() | |||
gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) | |||
if gt_bboxes_ignore is not None: | |||
results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() | |||
results['bbox_fields'].append('gt_bboxes_ignore') | |||
results['bbox_fields'].append('gt_bboxes') | |||
return results | |||
def _load_keypoints(self, results): | |||
"""Private function to load bounding box annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
Returns: | |||
dict: The dict contains loaded bounding box annotations. | |||
""" | |||
ann_info = results['ann_info'] | |||
results['gt_keypointss'] = ann_info['keypointss'].copy() | |||
results['keypoints_fields'] = ['gt_keypointss'] | |||
return results | |||
def _load_labels(self, results): | |||
"""Private function to load label annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
Returns: | |||
dict: The dict contains loaded label annotations. | |||
""" | |||
results['gt_labels'] = results['ann_info']['labels'].copy() | |||
return results | |||
def _poly2mask(self, mask_ann, img_h, img_w): | |||
"""Private function to convert masks represented with polygon to | |||
bitmaps. | |||
Args: | |||
mask_ann (list | dict): Polygon mask annotation input. | |||
img_h (int): The height of output mask. | |||
img_w (int): The width of output mask. | |||
Returns: | |||
numpy.ndarray: The decode bitmap mask of shape (img_h, img_w). | |||
""" | |||
if isinstance(mask_ann, list): | |||
# polygon -- a single object might consist of multiple parts | |||
# we merge all parts into one mask rle code | |||
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||
rle = maskUtils.merge(rles) | |||
elif isinstance(mask_ann['counts'], list): | |||
# uncompressed RLE | |||
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) | |||
else: | |||
# rle | |||
rle = mask_ann | |||
mask = maskUtils.decode(rle) | |||
return mask | |||
def process_polygons(self, polygons): | |||
"""Convert polygons to list of ndarray and filter invalid polygons. | |||
Args: | |||
polygons (list[list]): Polygons of one instance. | |||
Returns: | |||
list[numpy.ndarray]: Processed polygons. | |||
""" | |||
polygons = [np.array(p) for p in polygons] | |||
valid_polygons = [] | |||
for polygon in polygons: | |||
if len(polygon) % 2 == 0 and len(polygon) >= 6: | |||
valid_polygons.append(polygon) | |||
return valid_polygons | |||
def _load_masks(self, results): | |||
"""Private function to load mask annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
Returns: | |||
dict: The dict contains loaded mask annotations. | |||
                If ``self.poly2mask`` is set ``True``, `gt_masks` will contain
                :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
""" | |||
h, w = results['img_info']['height'], results['img_info']['width'] | |||
gt_masks = results['ann_info']['masks'] | |||
if self.poly2mask: | |||
gt_masks = BitmapMasks( | |||
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w) | |||
else: | |||
gt_masks = PolygonMasks( | |||
[self.process_polygons(polygons) for polygons in gt_masks], h, | |||
w) | |||
results['gt_masks'] = gt_masks | |||
results['mask_fields'].append('gt_masks') | |||
return results | |||
def _load_semantic_seg(self, results): | |||
"""Private function to load semantic segmentation annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`dataset`. | |||
Returns: | |||
dict: The dict contains loaded semantic segmentation annotations. | |||
""" | |||
import mmcv | |||
if self.file_client is None: | |||
self.file_client = mmcv.FileClient(**self.file_client_args) | |||
filename = osp.join(results['seg_prefix'], | |||
results['ann_info']['seg_map']) | |||
img_bytes = self.file_client.get(filename) | |||
results['gt_semantic_seg'] = mmcv.imfrombytes( | |||
img_bytes, flag='unchanged').squeeze() | |||
results['seg_fields'].append('gt_semantic_seg') | |||
return results | |||
def __call__(self, results): | |||
"""Call function to load multiple types annotations. | |||
Args: | |||
results (dict): Result dict from :obj:`mmdet.CustomDataset`. | |||
Returns: | |||
dict: The dict contains loaded bounding box, label, mask and | |||
semantic segmentation annotations. | |||
""" | |||
if self.with_bbox: | |||
results = self._load_bboxes(results) | |||
if results is None: | |||
return None | |||
if self.with_label: | |||
results = self._load_labels(results) | |||
if self.with_keypoints: | |||
results = self._load_keypoints(results) | |||
if self.with_mask: | |||
results = self._load_masks(results) | |||
if self.with_seg: | |||
results = self._load_semantic_seg(results) | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(with_bbox={self.with_bbox}, ' | |||
repr_str += f'with_label={self.with_label}, ' | |||
repr_str += f'with_keypoints={self.with_keypoints}, ' | |||
repr_str += f'with_mask={self.with_mask}, ' | |||
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f'file_client_args={self.file_client_args})'
return repr_str |
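# Illustrative sketch (not part of the diff): how _poly2mask above turns a
# polygon annotation into a bitmap via pycocotools; the toy sizes and the
# triangle polygon are assumptions.
from pycocotools import mask as maskUtils

img_h, img_w = 8, 8
# a single instance made of one triangular part: [x0, y0, x1, y1, x2, y2]
mask_ann = [[1.0, 1.0, 6.0, 1.0, 6.0, 6.0]]
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)  # merge all parts into one RLE
bitmap = maskUtils.decode(rle)  # uint8 bitmap of shape (img_h, img_w)
assert bitmap.shape == (img_h, img_w)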
@@ -0,0 +1,737 @@ | |||
""" | |||
The implementation here is modified based on insightface, originally MIT license and publicly available at
https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py | |||
""" | |||
import mmcv | |||
import numpy as np | |||
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps | |||
from mmdet.datasets.builder import PIPELINES | |||
from numpy import random | |||
@PIPELINES.register_module() | |||
class ResizeV2(object): | |||
"""Resize images & bbox & mask &kps. | |||
This transform resizes the input image to some scale. Bboxes and masks are | |||
then resized with the same scale factor. If the input dict contains the key | |||
"scale", then the scale in the input dict is used, otherwise the specified | |||
scale in the init method is used. If the input dict contains the key | |||
"scale_factor" (if MultiScaleFlipAug does not give img_scale but | |||
scale_factor), the actual scale will be computed by image shape and | |||
scale_factor. | |||
`img_scale` can either be a tuple (single-scale) or a list of tuple | |||
(multi-scale). There are 3 multiscale modes: | |||
- ``ratio_range is not None``: randomly sample a ratio from the ratio \ | |||
range and multiply it with the image scale. | |||
- ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ | |||
sample a scale from the multiscale range. | |||
- ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ | |||
sample a scale from multiple scales. | |||
Args: | |||
        img_scale (tuple or list[tuple]): Image scales for resizing.
multiscale_mode (str): Either "range" or "value". | |||
ratio_range (tuple[float]): (min_ratio, max_ratio) | |||
keep_ratio (bool): Whether to keep the aspect ratio when resizing the | |||
image. | |||
bbox_clip_border (bool, optional): Whether clip the objects outside | |||
the border of the image. Defaults to True. | |||
backend (str): Image resize backend, choices are 'cv2' and 'pillow'. | |||
            These two backends generate slightly different results. Defaults
to 'cv2'. | |||
override (bool, optional): Whether to override `scale` and | |||
`scale_factor` so as to call resize twice. Default False. If True, | |||
            after the first resizing, the existing `scale` and `scale_factor`
will be ignored so the second resizing can be allowed. | |||
This option is a work-around for multiple times of resize in DETR. | |||
Defaults to False. | |||
""" | |||
def __init__(self, | |||
img_scale=None, | |||
multiscale_mode='range', | |||
ratio_range=None, | |||
keep_ratio=True, | |||
bbox_clip_border=True, | |||
backend='cv2', | |||
override=False): | |||
if img_scale is None: | |||
self.img_scale = None | |||
else: | |||
if isinstance(img_scale, list): | |||
self.img_scale = img_scale | |||
else: | |||
self.img_scale = [img_scale] | |||
assert mmcv.is_list_of(self.img_scale, tuple) | |||
if ratio_range is not None: | |||
# mode 1: given a scale and a range of image ratio | |||
assert len(self.img_scale) == 1 | |||
else: | |||
# mode 2: given multiple scales or a range of scales | |||
assert multiscale_mode in ['value', 'range'] | |||
self.backend = backend | |||
self.multiscale_mode = multiscale_mode | |||
self.ratio_range = ratio_range | |||
self.keep_ratio = keep_ratio | |||
# TODO: refactor the override option in Resize | |||
self.override = override | |||
self.bbox_clip_border = bbox_clip_border | |||
@staticmethod | |||
def random_select(img_scales): | |||
"""Randomly select an img_scale from given candidates. | |||
Args: | |||
            img_scales (list[tuple]): Image scales for selection.
Returns: | |||
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
where ``img_scale`` is the selected image scale and \ | |||
``scale_idx`` is the selected index in the given candidates. | |||
""" | |||
assert mmcv.is_list_of(img_scales, tuple) | |||
scale_idx = np.random.randint(len(img_scales)) | |||
img_scale = img_scales[scale_idx] | |||
return img_scale, scale_idx | |||
@staticmethod | |||
def random_sample(img_scales): | |||
"""Randomly sample an img_scale when ``multiscale_mode=='range'``. | |||
Args: | |||
            img_scales (list[tuple]): Image scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.
Returns: | |||
(tuple, None): Returns a tuple ``(img_scale, None)``, where \ | |||
``img_scale`` is sampled scale and None is just a placeholder \ | |||
to be consistent with :func:`random_select`. | |||
""" | |||
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 | |||
img_scale_long = [max(s) for s in img_scales] | |||
img_scale_short = [min(s) for s in img_scales] | |||
long_edge = np.random.randint( | |||
min(img_scale_long), | |||
max(img_scale_long) + 1) | |||
short_edge = np.random.randint( | |||
min(img_scale_short), | |||
max(img_scale_short) + 1) | |||
img_scale = (long_edge, short_edge) | |||
return img_scale, None | |||
@staticmethod | |||
def random_sample_ratio(img_scale, ratio_range): | |||
"""Randomly sample an img_scale when ``ratio_range`` is specified. | |||
A ratio will be randomly sampled from the range specified by | |||
``ratio_range``. Then it would be multiplied with ``img_scale`` to | |||
        generate the sampled scale.
Args: | |||
img_scale (tuple): Images scale base to multiply with ratio. | |||
ratio_range (tuple[float]): The minimum and maximum ratio to scale | |||
the ``img_scale``. | |||
Returns: | |||
(tuple, None): Returns a tuple ``(scale, None)``, where \ | |||
``scale`` is sampled ratio multiplied with ``img_scale`` and \ | |||
None is just a placeholder to be consistent with \ | |||
:func:`random_select`. | |||
""" | |||
assert isinstance(img_scale, tuple) and len(img_scale) == 2 | |||
min_ratio, max_ratio = ratio_range | |||
assert min_ratio <= max_ratio | |||
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio | |||
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) | |||
return scale, None | |||
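    # Illustrative note (added): with img_scale=(640, 640) and
    # ratio_range=(0.5, 1.5), the sampled scale lands anywhere between
    # (320, 320) and (960, 960); the values here are example assumptions.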
def _random_scale(self, results): | |||
"""Randomly sample an img_scale according to ``ratio_range`` and | |||
``multiscale_mode``. | |||
If ``ratio_range`` is specified, a ratio will be sampled and be | |||
multiplied with ``img_scale``. | |||
If multiple scales are specified by ``img_scale``, a scale will be | |||
sampled according to ``multiscale_mode``. | |||
Otherwise, single scale will be used. | |||
Args: | |||
results (dict): Result dict from :obj:`dataset`. | |||
Returns: | |||
            dict: Two new keys ``scale`` and ``scale_idx`` are added into \
``results``, which would be used by subsequent pipelines. | |||
""" | |||
if self.ratio_range is not None: | |||
scale, scale_idx = self.random_sample_ratio( | |||
self.img_scale[0], self.ratio_range) | |||
elif len(self.img_scale) == 1: | |||
scale, scale_idx = self.img_scale[0], 0 | |||
elif self.multiscale_mode == 'range': | |||
scale, scale_idx = self.random_sample(self.img_scale) | |||
elif self.multiscale_mode == 'value': | |||
scale, scale_idx = self.random_select(self.img_scale) | |||
else: | |||
raise NotImplementedError | |||
results['scale'] = scale | |||
results['scale_idx'] = scale_idx | |||
def _resize_img(self, results): | |||
"""Resize images with ``results['scale']``.""" | |||
for key in results.get('img_fields', ['img']): | |||
if self.keep_ratio: | |||
img, scale_factor = mmcv.imrescale( | |||
results[key], | |||
results['scale'], | |||
return_scale=True, | |||
backend=self.backend) | |||
                # the w_scale and h_scale have a minor difference
# a real fix should be done in the mmcv.imrescale in the future | |||
new_h, new_w = img.shape[:2] | |||
h, w = results[key].shape[:2] | |||
w_scale = new_w / w | |||
h_scale = new_h / h | |||
else: | |||
img, w_scale, h_scale = mmcv.imresize( | |||
results[key], | |||
results['scale'], | |||
return_scale=True, | |||
backend=self.backend) | |||
results[key] = img | |||
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], | |||
dtype=np.float32) | |||
results['img_shape'] = img.shape | |||
# in case that there is no padding | |||
results['pad_shape'] = img.shape | |||
results['scale_factor'] = scale_factor | |||
results['keep_ratio'] = self.keep_ratio | |||
def _resize_bboxes(self, results): | |||
"""Resize bounding boxes with ``results['scale_factor']``.""" | |||
for key in results.get('bbox_fields', []): | |||
bboxes = results[key] * results['scale_factor'] | |||
if self.bbox_clip_border: | |||
img_shape = results['img_shape'] | |||
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) | |||
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) | |||
results[key] = bboxes | |||
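    # Illustrative note (added): scale_factor is laid out as
    # [w_scale, h_scale, w_scale, h_scale], so an (x1, y1, x2, y2) box is
    # rescaled with the single elementwise multiply above.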
def _resize_keypoints(self, results): | |||
"""Resize keypoints with ``results['scale_factor']``.""" | |||
for key in results.get('keypoints_fields', []): | |||
keypointss = results[key].copy() | |||
factors = results['scale_factor'] | |||
assert factors[0] == factors[2] | |||
assert factors[1] == factors[3] | |||
keypointss[:, :, 0] *= factors[0] | |||
keypointss[:, :, 1] *= factors[1] | |||
if self.bbox_clip_border: | |||
img_shape = results['img_shape'] | |||
keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, | |||
img_shape[1]) | |||
keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, | |||
img_shape[0]) | |||
results[key] = keypointss | |||
def _resize_masks(self, results): | |||
"""Resize masks with ``results['scale']``""" | |||
for key in results.get('mask_fields', []): | |||
if results[key] is None: | |||
continue | |||
if self.keep_ratio: | |||
results[key] = results[key].rescale(results['scale']) | |||
else: | |||
results[key] = results[key].resize(results['img_shape'][:2]) | |||
def _resize_seg(self, results): | |||
"""Resize semantic segmentation map with ``results['scale']``.""" | |||
for key in results.get('seg_fields', []): | |||
if self.keep_ratio: | |||
gt_seg = mmcv.imrescale( | |||
results[key], | |||
results['scale'], | |||
interpolation='nearest', | |||
backend=self.backend) | |||
else: | |||
gt_seg = mmcv.imresize( | |||
results[key], | |||
results['scale'], | |||
interpolation='nearest', | |||
backend=self.backend) | |||
results['gt_semantic_seg'] = gt_seg | |||
def __call__(self, results): | |||
"""Call function to resize images, bounding boxes, masks, semantic | |||
segmentation map. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ | |||
'keep_ratio' keys are added into result dict. | |||
""" | |||
if 'scale' not in results: | |||
if 'scale_factor' in results: | |||
img_shape = results['img'].shape[:2] | |||
scale_factor = results['scale_factor'] | |||
assert isinstance(scale_factor, float) | |||
results['scale'] = tuple( | |||
[int(x * scale_factor) for x in img_shape][::-1]) | |||
else: | |||
self._random_scale(results) | |||
else: | |||
if not self.override: | |||
assert 'scale_factor' not in results, ( | |||
'scale and scale_factor cannot be both set.') | |||
else: | |||
results.pop('scale') | |||
if 'scale_factor' in results: | |||
results.pop('scale_factor') | |||
self._random_scale(results) | |||
self._resize_img(results) | |||
self._resize_bboxes(results) | |||
self._resize_keypoints(results) | |||
self._resize_masks(results) | |||
self._resize_seg(results) | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
repr_str += f'(img_scale={self.img_scale}, ' | |||
repr_str += f'multiscale_mode={self.multiscale_mode}, ' | |||
repr_str += f'ratio_range={self.ratio_range}, ' | |||
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
return repr_str | |||
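# Illustrative pipeline entry (added; assumes mmdet-style config conventions):
#   dict(type='ResizeV2', img_scale=(640, 640), keep_ratio=True)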
@PIPELINES.register_module() | |||
class RandomFlipV2(object): | |||
"""Flip the image & bbox & mask & kps. | |||
If the input dict contains the key "flip", then the flag will be used, | |||
otherwise it will be randomly decided by a ratio specified in the init | |||
method. | |||
When random flip is enabled, ``flip_ratio``/``direction`` can either be a | |||
float/string or tuple of float/string. There are 3 flip modes: | |||
- ``flip_ratio`` is float, ``direction`` is string: the image will be | |||
      ``direction``ly flipped with probability of ``flip_ratio``.
E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, | |||
then image will be horizontally flipped with probability of 0.5. | |||
    - ``flip_ratio`` is float, ``direction`` is list of string: the image will
be ``direction[i]``ly flipped with probability of | |||
``flip_ratio/len(direction)``. | |||
E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, | |||
then image will be horizontally flipped with probability of 0.25, | |||
vertically with probability of 0.25. | |||
- ``flip_ratio`` is list of float, ``direction`` is list of string: | |||
      given ``len(flip_ratio) == len(direction)``, the image will
be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. | |||
E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', | |||
'vertical']``, then image will be horizontally flipped with probability | |||
      of 0.3, vertically with probability of 0.5.
Args: | |||
flip_ratio (float | list[float], optional): The flipping probability. | |||
Default: None. | |||
direction(str | list[str], optional): The flipping direction. Options | |||
are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. | |||
If input is a list, the length must equal ``flip_ratio``. Each | |||
element in ``flip_ratio`` indicates the flip probability of | |||
corresponding direction. | |||
""" | |||
def __init__(self, flip_ratio=None, direction='horizontal'): | |||
if isinstance(flip_ratio, list): | |||
assert mmcv.is_list_of(flip_ratio, float) | |||
assert 0 <= sum(flip_ratio) <= 1 | |||
elif isinstance(flip_ratio, float): | |||
assert 0 <= flip_ratio <= 1 | |||
elif flip_ratio is None: | |||
pass | |||
else: | |||
            raise ValueError('flip_ratio must be None, float, '
'or list of float') | |||
self.flip_ratio = flip_ratio | |||
valid_directions = ['horizontal', 'vertical', 'diagonal'] | |||
if isinstance(direction, str): | |||
assert direction in valid_directions | |||
elif isinstance(direction, list): | |||
assert mmcv.is_list_of(direction, str) | |||
assert set(direction).issubset(set(valid_directions)) | |||
else: | |||
raise ValueError('direction must be either str or list of str') | |||
self.direction = direction | |||
if isinstance(flip_ratio, list): | |||
assert len(self.flip_ratio) == len(self.direction) | |||
self.count = 0 | |||
def bbox_flip(self, bboxes, img_shape, direction): | |||
"""Flip bboxes horizontally. | |||
Args: | |||
bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) | |||
img_shape (tuple[int]): Image shape (height, width) | |||
direction (str): Flip direction. Options are 'horizontal', | |||
'vertical'. | |||
Returns: | |||
numpy.ndarray: Flipped bounding boxes. | |||
""" | |||
assert bboxes.shape[-1] % 4 == 0 | |||
flipped = bboxes.copy() | |||
if direction == 'horizontal': | |||
w = img_shape[1] | |||
flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
elif direction == 'vertical': | |||
h = img_shape[0] | |||
flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
elif direction == 'diagonal': | |||
w = img_shape[1] | |||
h = img_shape[0] | |||
flipped[..., 0::4] = w - bboxes[..., 2::4] | |||
flipped[..., 1::4] = h - bboxes[..., 3::4] | |||
flipped[..., 2::4] = w - bboxes[..., 0::4] | |||
flipped[..., 3::4] = h - bboxes[..., 1::4] | |||
else: | |||
raise ValueError(f"Invalid flipping direction '{direction}'") | |||
return flipped | |||
def keypoints_flip(self, keypointss, img_shape, direction): | |||
"""Flip keypoints horizontally.""" | |||
assert direction == 'horizontal' | |||
assert keypointss.shape[-1] == 3 | |||
num_kps = keypointss.shape[1] | |||
        assert num_kps in [4, 5], f'Only support num_kps=4 or 5, got: {num_kps}'
assert keypointss.ndim == 3 | |||
flipped = keypointss.copy() | |||
if num_kps == 5: | |||
flip_order = [1, 0, 2, 4, 3] | |||
elif num_kps == 4: | |||
flip_order = [3, 2, 1, 0] | |||
for idx, a in enumerate(flip_order): | |||
flipped[:, idx, :] = keypointss[:, a, :] | |||
w = img_shape[1] | |||
flipped[..., 0] = w - flipped[..., 0] | |||
return flipped | |||
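    # Illustrative note (added): flip_order mirrors left/right landmarks,
    # assuming the common 5-point order (left eye, right eye, nose,
    # left mouth corner, right mouth corner); x is then reflected as w - x.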
def __call__(self, results): | |||
"""Call function to flip bounding boxes, masks, semantic segmentation | |||
maps. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Flipped results, 'flip', 'flip_direction' keys are added \ | |||
into result dict. | |||
""" | |||
if 'flip' not in results: | |||
if isinstance(self.direction, list): | |||
# None means non-flip | |||
direction_list = self.direction + [None] | |||
else: | |||
# None means non-flip | |||
direction_list = [self.direction, None] | |||
if isinstance(self.flip_ratio, list): | |||
non_flip_ratio = 1 - sum(self.flip_ratio) | |||
flip_ratio_list = self.flip_ratio + [non_flip_ratio] | |||
else: | |||
non_flip_ratio = 1 - self.flip_ratio | |||
# exclude non-flip | |||
single_ratio = self.flip_ratio / (len(direction_list) - 1) | |||
flip_ratio_list = [single_ratio] * (len(direction_list) | |||
- 1) + [non_flip_ratio] | |||
cur_dir = np.random.choice(direction_list, p=flip_ratio_list) | |||
results['flip'] = cur_dir is not None | |||
if 'flip_direction' not in results: | |||
results['flip_direction'] = cur_dir | |||
if results['flip']: | |||
# flip image | |||
for key in results.get('img_fields', ['img']): | |||
results[key] = mmcv.imflip( | |||
results[key], direction=results['flip_direction']) | |||
# flip bboxes | |||
for key in results.get('bbox_fields', []): | |||
results[key] = self.bbox_flip(results[key], | |||
results['img_shape'], | |||
results['flip_direction']) | |||
# flip kps | |||
for key in results.get('keypoints_fields', []): | |||
results[key] = self.keypoints_flip(results[key], | |||
results['img_shape'], | |||
results['flip_direction']) | |||
# flip masks | |||
for key in results.get('mask_fields', []): | |||
results[key] = results[key].flip(results['flip_direction']) | |||
# flip segs | |||
for key in results.get('seg_fields', []): | |||
results[key] = mmcv.imflip( | |||
results[key], direction=results['flip_direction']) | |||
return results | |||
def __repr__(self): | |||
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' | |||
@PIPELINES.register_module() | |||
class RandomSquareCrop(object): | |||
"""Random crop the image & bboxes, the cropped patches have minimum IoU | |||
requirement with original image & bboxes, the IoU threshold is randomly | |||
selected from min_ious. | |||
Args: | |||
min_ious (tuple): minimum IoU threshold for all intersections with | |||
bounding boxes | |||
min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, | |||
where a >= min_crop_size). | |||
Note: | |||
The keys for bboxes, labels and masks should be paired. That is, \ | |||
`gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ | |||
`gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. | |||
""" | |||
def __init__(self, | |||
crop_ratio_range=None, | |||
crop_choice=None, | |||
bbox_clip_border=True, | |||
big_face_ratio=0, | |||
big_face_crop_choice=None): | |||
self.crop_ratio_range = crop_ratio_range | |||
self.crop_choice = crop_choice | |||
self.big_face_crop_choice = big_face_crop_choice | |||
self.bbox_clip_border = bbox_clip_border | |||
assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) | |||
if self.crop_ratio_range is not None: | |||
self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range | |||
self.bbox2label = { | |||
'gt_bboxes': 'gt_labels', | |||
'gt_bboxes_ignore': 'gt_labels_ignore' | |||
} | |||
self.bbox2mask = { | |||
'gt_bboxes': 'gt_masks', | |||
'gt_bboxes_ignore': 'gt_masks_ignore' | |||
} | |||
assert big_face_ratio >= 0 and big_face_ratio <= 1.0 | |||
self.big_face_ratio = big_face_ratio | |||
def __call__(self, results): | |||
"""Call function to crop images and bounding boxes with minimum IoU | |||
constraint. | |||
Args: | |||
results (dict): Result dict from loading pipeline. | |||
Returns: | |||
dict: Result dict with images and bounding boxes cropped, \ | |||
'img_shape' key is updated. | |||
""" | |||
if 'img_fields' in results: | |||
assert results['img_fields'] == ['img'], \ | |||
'Only single img_fields is allowed' | |||
img = results['img'] | |||
assert 'bbox_fields' in results | |||
assert 'gt_bboxes' in results | |||
# try augment big face images | |||
find_bigface = False | |||
if np.random.random() < self.big_face_ratio: | |||
min_size = 100 # h and w | |||
            expand_ratio = 0.3  # expand ratio of the cropped face along both w and h
bbox = results['gt_bboxes'].copy() | |||
lmks = results['gt_keypointss'].copy() | |||
label = results['gt_labels'].copy() | |||
# filter small faces | |||
size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( | |||
(bbox[:, 3] - bbox[:, 1]) > min_size) | |||
bbox = bbox[size_mask] | |||
lmks = lmks[size_mask] | |||
label = label[size_mask] | |||
# randomly choose a face that has no overlap with others | |||
if len(bbox) > 0: | |||
overlaps = bbox_overlaps(bbox, bbox) | |||
overlaps -= np.eye(overlaps.shape[0]) | |||
iou_mask = np.sum(overlaps, axis=1) == 0 | |||
bbox = bbox[iou_mask] | |||
lmks = lmks[iou_mask] | |||
label = label[iou_mask] | |||
if len(bbox) > 0: | |||
choice = np.random.randint(len(bbox)) | |||
bbox = bbox[choice] | |||
lmks = lmks[choice] | |||
label = [label[choice]] | |||
w = bbox[2] - bbox[0] | |||
h = bbox[3] - bbox[1] | |||
x1 = bbox[0] - w * expand_ratio | |||
x2 = bbox[2] + w * expand_ratio | |||
y1 = bbox[1] - h * expand_ratio | |||
y2 = bbox[3] + h * expand_ratio | |||
x1, x2 = np.clip([x1, x2], 0, img.shape[1]) | |||
y1, y2 = np.clip([y1, y2], 0, img.shape[0]) | |||
bbox -= np.tile([x1, y1], 2) | |||
lmks -= (x1, y1, 0) | |||
find_bigface = True | |||
img = img[int(y1):int(y2), int(x1):int(x2), :] | |||
results['gt_bboxes'] = np.expand_dims(bbox, axis=0) | |||
results['gt_keypointss'] = np.expand_dims(lmks, axis=0) | |||
results['gt_labels'] = np.array(label) | |||
results['img'] = img | |||
boxes = results['gt_bboxes'] | |||
h, w, c = img.shape | |||
if self.crop_ratio_range is not None: | |||
max_scale = self.crop_ratio_max | |||
else: | |||
max_scale = np.amax(self.crop_choice) | |||
scale_retry = 0 | |||
while True: | |||
scale_retry += 1 | |||
if scale_retry == 1 or max_scale > 1.0: | |||
if self.crop_ratio_range is not None: | |||
scale = np.random.uniform(self.crop_ratio_min, | |||
self.crop_ratio_max) | |||
elif self.crop_choice is not None: | |||
scale = np.random.choice(self.crop_choice) | |||
else: | |||
scale = scale * 1.2 | |||
if find_bigface: | |||
# select a scale from big_face_crop_choice if in big_face mode | |||
scale = np.random.choice(self.big_face_crop_choice) | |||
for i in range(250): | |||
long_side = max(w, h) | |||
cw = int(scale * long_side) | |||
ch = cw | |||
# TODO +1 | |||
if w == cw: | |||
left = 0 | |||
elif w > cw: | |||
left = random.randint(0, w - cw) | |||
else: | |||
left = random.randint(w - cw, 0) | |||
if h == ch: | |||
top = 0 | |||
elif h > ch: | |||
top = random.randint(0, h - ch) | |||
else: | |||
top = random.randint(h - ch, 0) | |||
patch = np.array( | |||
(int(left), int(top), int(left + cw), int(top + ch)), | |||
dtype=np.int32) | |||
# center of boxes should inside the crop img | |||
# only adjust boxes and instance masks when the gt is not empty | |||
# adjust boxes | |||
def is_center_of_bboxes_in_patch(boxes, patch): | |||
# TODO >= | |||
center = (boxes[:, :2] + boxes[:, 2:]) / 2 | |||
mask = \ | |||
((center[:, 0] > patch[0]) | |||
* (center[:, 1] > patch[1]) | |||
* (center[:, 0] < patch[2]) | |||
* (center[:, 1] < patch[3])) | |||
return mask | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
if not mask.any(): | |||
continue | |||
for key in results.get('bbox_fields', []): | |||
boxes = results[key].copy() | |||
mask = is_center_of_bboxes_in_patch(boxes, patch) | |||
boxes = boxes[mask] | |||
if self.bbox_clip_border: | |||
boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) | |||
boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) | |||
boxes -= np.tile(patch[:2], 2) | |||
results[key] = boxes | |||
# labels | |||
label_key = self.bbox2label.get(key) | |||
if label_key in results: | |||
results[label_key] = results[label_key][mask] | |||
# keypoints field | |||
if key == 'gt_bboxes': | |||
for kps_key in results.get('keypoints_fields', []): | |||
keypointss = results[kps_key].copy() | |||
keypointss = keypointss[mask, :, :] | |||
if self.bbox_clip_border: | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
max=patch[2:]) | |||
keypointss[:, :, : | |||
2] = keypointss[:, :, :2].clip( | |||
min=patch[:2]) | |||
keypointss[:, :, 0] -= patch[0] | |||
keypointss[:, :, 1] -= patch[1] | |||
results[kps_key] = keypointss | |||
# mask fields | |||
mask_key = self.bbox2mask.get(key) | |||
if mask_key in results: | |||
results[mask_key] = results[mask_key][mask.nonzero() | |||
[0]].crop(patch) | |||
# adjust the img no matter whether the gt is empty before crop | |||
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 | |||
patch_from = patch.copy() | |||
patch_from[0] = max(0, patch_from[0]) | |||
patch_from[1] = max(0, patch_from[1]) | |||
patch_from[2] = min(img.shape[1], patch_from[2]) | |||
patch_from[3] = min(img.shape[0], patch_from[3]) | |||
patch_to = patch.copy() | |||
patch_to[0] = max(0, patch_to[0] * -1) | |||
patch_to[1] = max(0, patch_to[1] * -1) | |||
patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) | |||
patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) | |||
rimg[patch_to[1]:patch_to[3], | |||
patch_to[0]:patch_to[2], :] = img[ | |||
patch_from[1]:patch_from[3], | |||
patch_from[0]:patch_from[2], :] | |||
img = rimg | |||
results['img'] = img | |||
results['img_shape'] = img.shape | |||
return results | |||
def __repr__(self): | |||
repr_str = self.__class__.__name__ | |||
        repr_str += f'(crop_ratio_range={self.crop_ratio_range}, '
        repr_str += f'crop_choice={self.crop_choice}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border}, '
        repr_str += f'big_face_ratio={self.big_face_ratio})'
return repr_str |
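# Illustrative sketch (not part of the diff): the patch_from/patch_to
# arithmetic in RandomSquareCrop pastes the visible part of an out-of-bounds
# crop onto a gray (128) canvas. Toy sizes below are assumptions.
import numpy as np

img = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)
patch = np.array([-2, -2, 4, 4], dtype=np.int32)  # x1, y1, x2, y2 -> 6x6 crop
cw, ch = patch[2] - patch[0], patch[3] - patch[1]
rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128
src = [max(0, patch[0]), max(0, patch[1]),
       min(img.shape[1], patch[2]), min(img.shape[0], patch[3])]
dst_x, dst_y = max(0, -patch[0]), max(0, -patch[1])
rimg[dst_y:dst_y + (src[3] - src[1]),
     dst_x:dst_x + (src[2] - src[0]), :] = img[src[1]:src[3], src[0]:src[2], :]
assert rimg.shape == (6, 6, 3)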
@@ -13,7 +13,7 @@ class RetinaFaceDataset(CustomDataset): | |||
CLASSES = ('FG', ) | |||
def __init__(self, min_size=None, **kwargs): | |||
self.NK = 5 | |||
self.NK = kwargs.pop('num_kps', 5) | |||
self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} | |||
self.min_size = min_size | |||
self.gt_path = kwargs.get('gt_path') | |||
@@ -33,7 +33,8 @@ class RetinaFaceDataset(CustomDataset): | |||
if len(values) > 4: | |||
if len(values) > 5: | |||
kps = np.array( | |||
values[4:19], dtype=np.float32).reshape((self.NK, 3)) | |||
values[4:4 + self.NK * 3], dtype=np.float32).reshape( | |||
(self.NK, 3)) | |||
for li in range(kps.shape[0]): | |||
if (kps[li, :] == -1).all(): | |||
kps[li][2] = 0.0 # weight = 0, ignore |
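# Illustrative sketch (not part of the diff): the values[4:4 + self.NK * 3]
# slice above reads NK (x, y, visibility) triples after the 4 bbox values;
# the exact annotation line format here is an assumption.
import numpy as np

NK = 5
line = '10 10 50 50 ' + ' '.join(['20 20 1'] * NK)
values = [float(v) for v in line.split()]
bbox = np.array(values[:4], dtype=np.float32)
kps = np.array(values[4:4 + NK * 3], dtype=np.float32).reshape((NK, 3))
assert bbox.shape == (4, ) and kps.shape == (NK, 3)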
@@ -103,6 +103,7 @@ class SCRFDHead(AnchorHead): | |||
scale_mode=1, | |||
dw_conv=False, | |||
use_kps=False, | |||
num_kps=5, | |||
loss_kps=dict( | |||
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), | |||
**kwargs): | |||
@@ -116,7 +117,7 @@ class SCRFDHead(AnchorHead): | |||
self.scale_mode = scale_mode | |||
self.use_dfl = True | |||
self.dw_conv = dw_conv | |||
self.NK = 5 | |||
self.NK = num_kps | |||
self.extra_flops = 0.0 | |||
if loss_dfl is None or not loss_dfl: | |||
self.use_dfl = False | |||
@@ -323,8 +324,8 @@ class SCRFDHead(AnchorHead): | |||
batch_size, -1, self.cls_out_channels).sigmoid() | |||
bbox_pred = bbox_pred.permute(0, 2, 3, | |||
1).reshape(batch_size, -1, 4) | |||
kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 10) | |||
kps_pred = kps_pred.permute(0, 2, 3, | |||
1).reshape(batch_size, -1, self.NK * 2) | |||
return cls_score, bbox_pred, kps_pred | |||
def forward_train(self, | |||
@@ -788,7 +789,7 @@ class SCRFDHead(AnchorHead): | |||
if self.use_dfl: | |||
kps_pred = self.integral(kps_pred) * stride[0] | |||
else: | |||
kps_pred = kps_pred.reshape((-1, 10)) * stride[0] | |||
kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] | |||
nms_pre = cfg.get('nms_pre', -1) | |||
if nms_pre > 0 and scores.shape[0] > nms_pre: | |||
@@ -815,7 +816,7 @@ class SCRFDHead(AnchorHead): | |||
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) | |||
if mlvl_kps is not None: | |||
scale_factor2 = torch.tensor( | |||
[scale_factor[0], scale_factor[1]] * 5) | |||
[scale_factor[0], scale_factor[1]] * self.NK) | |||
mlvl_kps /= scale_factor2.to(mlvl_kps.device) | |||
mlvl_scores = torch.cat(mlvl_scores) |
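# Illustrative sketch (not part of the diff): the NK-generic reshape above maps
# a (B, NK * 2, H, W) keypoint map to (B, H * W * num_anchors, NK * 2); a
# single anchor per location is an assumption here.
import torch

batch_size, NK, H, W = 2, 5, 4, 4
kps_pred = torch.zeros(batch_size, NK * 2, H, W)
kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, NK * 2)
assert kps_pred.shape == (batch_size, H * W, NK * 2)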
@@ -54,7 +54,13 @@ class SCRFD(SingleStageDetector): | |||
gt_bboxes_ignore) | |||
return losses | |||
def simple_test(self, img, img_metas, rescale=False): | |||
def simple_test(self, | |||
img, | |||
img_metas, | |||
rescale=False, | |||
repeat_head=1, | |||
output_kps_var=0, | |||
output_results=1): | |||
"""Test function without test time augmentation. | |||
Args: | |||
@@ -62,6 +68,9 @@ class SCRFD(SingleStageDetector): | |||
img_metas (list[dict]): List of image information. | |||
rescale (bool, optional): Whether to rescale the results. | |||
Defaults to False. | |||
            repeat_head (int): Number of times to repeat the head inference.
            output_kps_var (int): Whether to output the keypoint variance
                across repeated head passes as a quality measure.
            output_results (int): 0: return nothing, 1: bbox only, 2: both
                bbox and kps.
Returns: | |||
list[list[np.ndarray]]: BBox results of each image and classes. | |||
@@ -69,40 +78,71 @@ class SCRFD(SingleStageDetector): | |||
corresponds to each class. | |||
""" | |||
x = self.extract_feat(img) | |||
outs = self.bbox_head(x) | |||
if torch.onnx.is_in_onnx_export(): | |||
print('single_stage.py in-onnx-export') | |||
print(outs.__class__) | |||
cls_score, bbox_pred, kps_pred = outs | |||
for c in cls_score: | |||
print(c.shape) | |||
for c in bbox_pred: | |||
print(c.shape) | |||
if self.bbox_head.use_kps: | |||
for c in kps_pred: | |||
assert repeat_head >= 1 | |||
kps_out0 = [] | |||
kps_out1 = [] | |||
kps_out2 = [] | |||
for i in range(repeat_head): | |||
outs = self.bbox_head(x) | |||
kps_out0 += [outs[2][0].detach().cpu().numpy()] | |||
kps_out1 += [outs[2][1].detach().cpu().numpy()] | |||
kps_out2 += [outs[2][2].detach().cpu().numpy()] | |||
if output_kps_var: | |||
var0 = np.var(np.vstack(kps_out0), axis=0).mean() | |||
var1 = np.var(np.vstack(kps_out1), axis=0).mean() | |||
var2 = np.var(np.vstack(kps_out2), axis=0).mean() | |||
var = np.mean([var0, var1, var2]) | |||
else: | |||
var = None | |||
if output_results > 0: | |||
if torch.onnx.is_in_onnx_export(): | |||
print('single_stage.py in-onnx-export') | |||
print(outs.__class__) | |||
cls_score, bbox_pred, kps_pred = outs | |||
for c in cls_score: | |||
print(c.shape) | |||
for c in bbox_pred: | |||
print(c.shape) | |||
return (cls_score, bbox_pred, kps_pred) | |||
else: | |||
return (cls_score, bbox_pred) | |||
bbox_list = self.bbox_head.get_bboxes( | |||
*outs, img_metas, rescale=rescale) | |||
if self.bbox_head.use_kps: | |||
for c in kps_pred: | |||
print(c.shape) | |||
return (cls_score, bbox_pred, kps_pred) | |||
else: | |||
return (cls_score, bbox_pred) | |||
bbox_list = self.bbox_head.get_bboxes( | |||
*outs, img_metas, rescale=rescale) | |||
# return kps if use_kps | |||
if len(bbox_list[0]) == 2: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) | |||
for det_bboxes, det_labels in bbox_list | |||
] | |||
elif len(bbox_list[0]) == 3: | |||
bbox_results = [ | |||
bbox2result( | |||
det_bboxes, | |||
det_labels, | |||
self.bbox_head.num_classes, | |||
kps=det_kps) | |||
for det_bboxes, det_labels, det_kps in bbox_list | |||
] | |||
return bbox_results | |||
# return kps if use_kps | |||
if len(bbox_list[0]) == 2: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, | |||
self.bbox_head.num_classes) | |||
for det_bboxes, det_labels in bbox_list | |||
] | |||
elif len(bbox_list[0]) == 3: | |||
if output_results == 2: | |||
bbox_results = [ | |||
bbox2result( | |||
det_bboxes, | |||
det_labels, | |||
self.bbox_head.num_classes, | |||
kps=det_kps, | |||
num_kps=self.bbox_head.NK) | |||
for det_bboxes, det_labels, det_kps in bbox_list | |||
] | |||
elif output_results == 1: | |||
bbox_results = [ | |||
bbox2result(det_bboxes, det_labels, | |||
self.bbox_head.num_classes) | |||
for det_bboxes, det_labels, _ in bbox_list | |||
] | |||
else: | |||
bbox_results = None | |||
if var is not None: | |||
return bbox_results, var | |||
else: | |||
return bbox_results | |||
def feature_test(self, img): | |||
x = self.extract_feat(img) |
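# Illustrative sketch (not part of the diff): how output_kps_var aggregates
# per-level variance across repeated head passes; random arrays stand in for
# one pyramid level's keypoint outputs, and the shapes are assumptions.
import numpy as np

repeat_head = 4
kps_out0 = [np.random.rand(1, 10, 8, 8) for _ in range(repeat_head)]
var0 = np.var(np.vstack(kps_out0), axis=0).mean()  # scalar for this level
print(f'level-0 keypoint variance: {var0:.4f}')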
@@ -0,0 +1,71 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
from copy import deepcopy | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['ScrfdDetect'] | |||
@MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) | |||
class ScrfdDetect(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the face detection model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from mmcv import Config | |||
from mmcv.parallel import MMDataParallel | |||
from mmcv.runner import load_checkpoint | |||
from mmdet.models import build_detector | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead | |||
from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD | |||
cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) | |||
ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) | |||
detector = build_detector(cfg.model) | |||
logger.info(f'loading model from {ckpt_path}') | |||
device = torch.device( | |||
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') | |||
load_checkpoint(detector, ckpt_path, map_location=device) | |||
detector = MMDataParallel(detector, device_ids=[0]) | |||
detector.eval() | |||
self.detector = detector | |||
logger.info('load model done') | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
result = self.detector( | |||
return_loss=False, | |||
rescale=True, | |||
img=[input['img'][0].unsqueeze(0)], | |||
img_metas=[[dict(input['img_metas'][0].data)]], | |||
output_results=2) | |||
assert result is not None | |||
result = result[0][0] | |||
bboxes = result[:, :4].tolist() | |||
kpss = result[:, 5:].tolist() | |||
scores = result[:, 4].tolist() | |||
return { | |||
OutputKeys.SCORES: scores, | |||
OutputKeys.BOXES: bboxes, | |||
OutputKeys.KEYPOINTS: kpss | |||
} | |||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||
return input |
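# Illustrative usage sketch (not part of the diff): running ScrfdDetect via the
# modelscope pipeline API; the model id and image path are assumptions.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_detection = pipeline(
    Tasks.face_detection, model='damo/cv_resnet_facedetection_scrfd10gkps')
result = face_detection('path/to/image.jpg')
print(result['scores'], result['boxes'], result['keypoints'])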