diff --git a/.dev_scripts/build_docs.sh b/.dev_scripts/build_docs.sh new file mode 100644 index 00000000..dc76e6f4 --- /dev/null +++ b/.dev_scripts/build_docs.sh @@ -0,0 +1,8 @@ +pip install -r requirements/docs.txt +cd docs +rm -rf build + +# update api rst +#rm -rf source/api/ +#sphinx-apidoc --module-first -o source/api/ ../modelscope/ +make html diff --git a/.dev_scripts/build_image.sh b/.dev_scripts/build_image.sh new file mode 100644 index 00000000..e6403aed --- /dev/null +++ b/.dev_scripts/build_image.sh @@ -0,0 +1,169 @@ +#!/bin/bash +# default values. +BASE_CPU_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04 +BASE_GPU_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.3.0-cudnn8-devel +MODELSCOPE_REPO_ADDRESS=reg.docker.alibaba-inc.com/modelscope/modelscope +python_version=3.7.13 +torch_version=1.11.0 +cudatoolkit_version=11.3 +tensorflow_version=1.15.5 +modelscope_version=None +is_ci_test=False +is_dsw=False +is_cpu=False +run_ci_test=False +function usage(){ + echo "usage: build.sh " + echo " --python=python_version set python version, default: $python_version" + echo " --torch=torch_version set pytorch version, fefault: $torch_version" + echo " --cudatoolkit=cudatoolkit_version set cudatoolkit version used for pytorch, default: $cudatoolkit_version" + echo " --tensorflow=tensorflow_version set tensorflow version, default: $tensorflow_version" + echo " --modelscope=modelscope_version set modelscope version, default: $modelscope_version" + echo " --test option for run test before push image, only push on ci test pass" + echo " --cpu option for build cpu version" + echo " --dsw option for build dsw version" + echo " --ci option for build ci version" + echo " --push option for push image to remote repo" +} +for i in "$@"; do + case $i in + --python=*) + python_version="${i#*=}" + shift + ;; + --torch=*) + torch_version="${i#*=}" + shift # pytorch version + ;; + --tensorflow=*) + tensorflow_version="${i#*=}" + shift # tensorflow version + ;; + --cudatoolkit=*) + cudatoolkit_version="${i#*=}" + shift # cudatoolkit for pytorch + ;; + --modelscope=*) + modelscope_version="${i#*=}" + shift # cudatoolkit for pytorch + ;; + --test) + run_ci_test=True + shift # will run ci test + ;; + --cpu) + is_cpu=True + shift # is cpu image + ;; + --ci) + is_ci_test=True + shift # is ci, will not install modelscope + ;; + --dsw) + is_dsw=True + shift # is dsw, will set dsw cache location + ;; + --push) + is_push=True + shift # is dsw, will set dsw cache location + ;; + --help) + usage + exit 0 + ;; + -*|--*) + echo "Unknown option $i" + usage + exit 1 + ;; + *) + ;; + esac +done + +if [ "$modelscope_version" == "None" ]; then + echo "ModelScope version must specify!" + exit 1 +fi +if [ "$is_cpu" == "True" ]; then + export BASE_IMAGE=$BASE_CPU_IMAGE + base_tag=ubuntu20.04 + export USE_GPU=False +else + export BASE_IMAGE=$BASE_GPU_IMAGE + base_tag=ubuntu20.04-cuda11.3.0 + export USE_GPU=True +fi +if [[ $python_version == 3.7* ]]; then + base_tag=$base_tag-py37 +elif [[ $python_version == z* ]]; then + base_tag=$base_tag-py38 +elif [[ $python_version == z* ]]; then + base_tag=$base_tag-py39 +else + echo "Unsupport python version: $python_version" + exit 1 +fi + +target_image_tag=$base_tag-torch$torch_version-tf$tensorflow_version +if [ "$is_ci_test" == "True" ]; then + target_image_tag=$target_image_tag-$modelscope_version-ci +else + target_image_tag=$target_image_tag-$modelscope_version-test +fi +export IMAGE_TO_BUILD=$MODELSCOPE_REPO_ADDRESS:$target_image_tag +export PYTHON_VERSION=$python_version +export TORCH_VERSION=$torch_version +export CUDATOOLKIT_VERSION=$cudatoolkit_version +export TENSORFLOW_VERSION=$tensorflow_version +echo -e "Building image with:\npython$python_version\npytorch$torch_version\ntensorflow:$tensorflow_version\ncudatoolkit:$cudatoolkit_version\ncpu:$is_cpu\nis_ci:$is_ci_test\nis_dsw:$is_dsw\n" +docker_file_content=`cat docker/Dockerfile.ubuntu` +if [ "$is_ci_test" != "True" ]; then + echo "Building ModelScope lib, will install ModelScope lib to image" + docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir modelscope==$modelscope_version -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html" +fi +echo "$is_dsw" +if [ "$is_dsw" == "False" ]; then + echo "Not DSW image" +else + echo "Building dsw image well need set ModelScope lib cache location." + docker_file_content="${docker_file_content} \nENV MODELSCOPE_CACHE=/mnt/workspace/.cache/modelscope" +fi +printf "$docker_file_content" > Dockerfile +docker build -t $IMAGE_TO_BUILD \ + --build-arg USE_GPU \ + --build-arg BASE_IMAGE \ + --build-arg PYTHON_VERSION \ + --build-arg TORCH_VERSION \ + --build-arg CUDATOOLKIT_VERSION \ + --build-arg TENSORFLOW_VERSION \ + -f Dockerfile . + +if [ $? -ne 0 ]; then + echo "Running docker build command error, please check the log!" + exit -1 +fi +if [ "$run_ci_test" == "True" ]; then + echo "Running ci case." + export MODELSCOPE_CACHE=/home/mulin.lyh/model_scope_cache + export MODELSCOPE_HOME_CACHE=/home/mulin.lyh/ci_case_home # for credential + export IMAGE_NAME=$MODELSCOPE_REPO_ADDRESS + export IMAGE_VERSION=$target_image_tag + export MODELSCOPE_DOMAIN=www.modelscope.cn + export HUB_DATASET_ENDPOINT=http://www.modelscope.cn + export CI_TEST=True + export TEST_LEVEL=1 + if [ "$is_ci_test" != "True" ]; then + echo "Testing for dsw image or MaaS-lib image" + export CI_COMMAND="python tests/run.py" + fi + bash .dev_scripts/dockerci.sh + if [ $? -ne 0 ]; then + echo "Running unittest failed, please check the log!" + exit -1 + fi +fi +if [ "$is_push" == "True" ]; then + echo "Pushing image: $IMAGE_TO_BUILD" + docker push $IMAGE_TO_BUILD +fi diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh new file mode 100644 index 00000000..4fd2778f --- /dev/null +++ b/.dev_scripts/ci_container_test.sh @@ -0,0 +1,36 @@ +echo "Testing envs" +printenv +echo "ENV END" +if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then + awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + pip install -r requirements/tests.txt + + git config --global --add safe.directory /Maas-lib + git config --global user.email tmp + git config --global user.name tmp.com + + # linter test + # use internal project for pre-commit due to the network problem + if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then + pre-commit run -c .pre-commit-config_local.yaml --all-files + if [ $? -ne 0 ]; then + echo "linter test failed, please run 'pre-commit run --all-files' to check" + exit -1 + fi + fi + # test with install + python setup.py install +else + echo "Running case in release image, run case directly!" +fi +if [ $# -eq 0 ]; then + ci_command="python tests/run.py --subprocess" +else + ci_command="$@" +fi +echo "Running case with command: $ci_command" +$ci_command diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh new file mode 100644 index 00000000..07ea947a --- /dev/null +++ b/.dev_scripts/dockerci.sh @@ -0,0 +1,80 @@ +#!/bin/bash +MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache +CODE_DIR=$PWD +CODE_DIR_IN_CONTAINER=/Maas-lib +echo "$USER" +gpus='7 6 5 4 3 2 1 0' +cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58' +cpu_sets_arr=($cpu_sets) +is_get_file_lock=false +# export RUN_CASE_COMMAND='python tests/run.py --run_config tests/run_config.yaml' +CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_BASE_COMMAND} +echo "ci command: $CI_COMMAND" +for gpu in $gpus +do + exec {lock_fd}>"/tmp/gpu$gpu" || exit 1 + flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; continue; } + echo "get gpu lock $gpu" + CONTAINER_NAME="modelscope-ci-$gpu" + let is_get_file_lock=true + + # pull image if there are update + docker pull ${IMAGE_NAME}:${IMAGE_VERSION} + if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then + docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ + --cpuset-cpus=${cpu_sets_arr[$gpu]} \ + --gpus="device=$gpu" \ + -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ + -v /home/admin/pre-commit:/home/admin/pre-commit \ + -e CI_TEST=True \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ + -e MODELSCOPE_SDK_DEBUG=True \ + -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ + -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ + -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_ENVIRONMENT='ci' \ + -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ + -e MODEL_TAG_URL=$MODEL_TAG_URL \ + --workdir=$CODE_DIR_IN_CONTAINER \ + --net host \ + ${IMAGE_NAME}:${IMAGE_VERSION} \ + $CI_COMMAND + else + docker run --rm --name $CONTAINER_NAME --shm-size=16gb \ + --cpuset-cpus=${cpu_sets_arr[$gpu]} \ + --gpus="device=$gpu" \ + -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -v $MODELSCOPE_HOME_CACHE/$gpu:/root \ + -v /home/admin/pre-commit:/home/admin/pre-commit \ + -e CI_TEST=True \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \ + -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \ + -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \ + -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \ + -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \ + -e TEST_LEVEL=$TEST_LEVEL \ + -e MODELSCOPE_ENVIRONMENT='ci' \ + -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ + -e MODEL_TAG_URL=$MODEL_TAG_URL \ + --workdir=$CODE_DIR_IN_CONTAINER \ + --net host \ + ${IMAGE_NAME}:${IMAGE_VERSION} \ + $CI_COMMAND + fi + if [ $? -ne 0 ]; then + echo "Running test case failed, please check the log!" + exit -1 + fi + break +done +if [ "$is_get_file_lock" = false ] ; then + echo 'No free GPU!' + exit 1 +fi diff --git a/.dev_scripts/linter.sh b/.dev_scripts/linter.sh new file mode 100644 index 00000000..6468e42b --- /dev/null +++ b/.dev_scripts/linter.sh @@ -0,0 +1,3 @@ +yapf -r -i modelscope/ configs/ tests/ setup.py +isort -rc modelscope/ configs/ tests/ setup.py +flake8 modelscope/ configs/ tests/ setup.py diff --git a/.dev_scripts/run_docker.sh b/.dev_scripts/run_docker.sh new file mode 100644 index 00000000..8999458a --- /dev/null +++ b/.dev_scripts/run_docker.sh @@ -0,0 +1,7 @@ +#sudo docker run --name zwm_maas -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/tensorflow-training:2.3-gpu-py36-cu101-ubuntu18.04 bash +#sudo docker run --name zwm_maas_pytorch -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 bash +CONTAINER_NAME=modelscope-dev +IMAGE_NAME=registry.cn-shanghai.aliyuncs.com/modelscope/modelscope +IMAGE_VERSION=v0.1.1-16-g62856fa-devel +MOUNT_DIR=/home/wenmeng.zwm/workspace +sudo docker run --name $CONTAINER_NAME -v $MOUNT_DIR:$MOUNT_DIR --net host -ti ${IMAGE_NAME}:${IMAGE_VERSION} bash diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..4198ecc0 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +.gitignore +tests +data +.dev_scripts +.dockerignore +.git +.gitattributes +.pre-commit-config.yaml +.pre-commit-config_local.yaml +.readthedocs.yaml +Dockfile diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..1a3015ec --- /dev/null +++ b/.gitattributes @@ -0,0 +1,9 @@ +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.JPEG filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.avi filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/citest.yaml b/.github/workflows/citest.yaml new file mode 100644 index 00000000..00c6bbbf --- /dev/null +++ b/.github/workflows/citest.yaml @@ -0,0 +1,64 @@ +name: citest + +on: + push: + branches: + - master + - "release/**" + paths-ignore: + - "setup.*" + - "requirements.txt" + - "requirements/**" + - "docs/**" + - "tools/**" + - ".dev_scripts/**" + - "README.md" + - "README_zh-CN.md" + - "NOTICE" + - ".github/workflows/lint.yaml" + - ".github/workflows/publish.yaml" + + pull_request: + paths-ignore: + - "setup.*" + - "requirements.txt" + - "requirements/**" + - "docs/**" + - "tools/**" + - ".dev_scripts/**" + - "README.md" + - "README_zh-CN.md" + - "NOTICE" + - ".github/workflows/lint.yaml" + - ".github/workflows/publish.yaml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + unittest: + # The type of runner that the job will run on + runs-on: [modelscope-self-hosted] + steps: + - name: ResetFileMode + shell: bash + run: | + # reset filemode to allow action runner to delete files + # generated by root in docker + set -e + source ~/.bashrc + sudo chown -R $USER:$USER $ACTION_RUNNER_DIR + + - name: Checkout + uses: actions/checkout@v2 + with: + lfs: 'true' + - name: Checkout LFS objects + run: git lfs checkout + - name: Run unittest + shell: bash + run: | + set -e + source /mnt/modelscope/ci_env.sh + bash .dev_scripts/dockerci.sh diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..dc4b5487 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,22 @@ +name: Lint test + +on: [push, pull_request] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + - name: Install pre-commit hook + run: | + pip install pre-commit + - name: Linting + run: pre-commit run --all-files diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..de086eea --- /dev/null +++ b/.gitignore @@ -0,0 +1,128 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +/package +/temp +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.vscode +.idea + +# custom +*.pkl +*.pkl.json +*.log.json +*.whl +*.tar.gz +*.swp +*.log +*.tar.gz +source.sh +tensorboard.sh +.DS_Store +replace.sh +result.png + +# Pytorch +*.pth +*.pt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..48fe7547 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,37 @@ +repos: + - repo: https://gitlab.com/pycqa/flake8.git + rev: 4.0.0 + hooks: + - id: flake8 + exclude: thirdparty/|examples/ + - repo: https://github.com/PyCQA/isort.git + rev: 4.3.21 + hooks: + - id: isort + exclude: examples + - repo: https://github.com/pre-commit/mirrors-yapf.git + rev: v0.30.0 + hooks: + - id: yapf + exclude: thirdparty/|examples/ + - repo: https://github.com/pre-commit/pre-commit-hooks.git + rev: v3.1.0 + hooks: + - id: trailing-whitespace + exclude: thirdparty/ + - id: check-yaml + exclude: thirdparty/ + - id: end-of-file-fixer + exclude: thirdparty/ + - id: requirements-txt-fixer + exclude: thirdparty/ + - id: double-quote-string-fixer + exclude: thirdparty/ + - id: check-merge-conflict + exclude: thirdparty/ + - id: fix-encoding-pragma + exclude: thirdparty/ + args: ["--remove"] + - id: mixed-line-ending + exclude: thirdparty/ + args: ["--fix=lf"] diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml new file mode 100644 index 00000000..0b2e2f39 --- /dev/null +++ b/.pre-commit-config_local.yaml @@ -0,0 +1,37 @@ +repos: + - repo: /home/admin/pre-commit/flake8 + rev: 4.0.0 + hooks: + - id: flake8 + exclude: thirdparty/|examples/ + - repo: /home/admin/pre-commit/isort + rev: 4.3.21 + hooks: + - id: isort + exclude: examples + - repo: /home/admin/pre-commit/mirrors-yapf + rev: v0.30.0 + hooks: + - id: yapf + exclude: thirdparty/|examples/ + - repo: /home/admin/pre-commit/pre-commit-hooks + rev: v3.1.0 + hooks: + - id: trailing-whitespace + exclude: thirdparty/ + - id: check-yaml + exclude: thirdparty/ + - id: end-of-file-fixer + exclude: thirdparty/ + - id: requirements-txt-fixer + exclude: thirdparty/ + - id: double-quote-string-fixer + exclude: thirdparty/ + - id: check-merge-conflict + exclude: thirdparty/ + - id: fix-encoding-pragma + exclude: thirdparty/ + args: ["--remove"] + - id: mixed-line-ending + exclude: thirdparty/ + args: ["--fix=lf"] diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..f7b9c7ea --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,28 @@ +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.7" + # You can also specify other tool versions: + # nodejs: "16" + # rust: "1.55" + # golang: "1.17" + jobs: + post_checkout: + - echo "dummy" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +formats: all + +python: + install: + - requirements: requirements/docs.txt + - requirements: requirements/readthedocs.txt + - requirements: requirements/framework.txt diff --git a/LICENSE b/LICENSE index 137069b8..80a72b64 100644 --- a/LICENSE +++ b/LICENSE @@ -71,3 +71,207 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +======= +Copyright 2022-2023 Alibaba ModelScope. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020-2022 Alibaba ModelScope. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..665d7e90 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include modelscope/configs *.py diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..96532199 --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +WHL_BUILD_DIR :=package +DOC_BUILD_DIR :=docs/build/ + +# default rule +default: whl docs + +.PHONY: docs +docs: + bash .dev_scripts/build_docs.sh + +.PHONY: linter +linter: + bash .dev_scripts/linter.sh + +.PHONY: test +test: + bash .dev_scripts/citest.sh + +.PHONY: whl +whl: + python setup.py sdist bdist_wheel + +.PHONY: clean +clean: + rm -rf $(WHL_BUILD_DIR) $(DOC_BUILD_DIR) diff --git a/Makefile.docker b/Makefile.docker new file mode 100644 index 00000000..97400318 --- /dev/null +++ b/Makefile.docker @@ -0,0 +1,67 @@ +DOCKER_REGISTRY = registry.cn-shanghai.aliyuncs.com +DOCKER_ORG = modelscope +DOCKER_IMAGE = modelscope +DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) + +# CUDA_VERSION = 11.3 +# CUDNN_VERSION = 8 +BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel + + +MODELSCOPE_VERSION = $(shell git describe --tags --always) + +# Can be either official / dev +BUILD_TYPE = dev +BUILD_PROGRESS = auto +BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) + +EXTRA_DOCKER_BUILD_FLAGS ?= --network=host +# DOCKER_BUILD = DOCKER_BUILDKIT=1 \ +# docker build \ +# --progress=$(BUILD_PROGRESS) \ +# $(EXTRA_DOCKER_BUILD_FLAGS) \ +# --target $(BUILD_TYPE) \ +# -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) \ +# $(BUILD_ARGS) \ +# -f docker/pytorch.dockerfile . +DOCKER_BUILD = DOCKER_BUILDKIT=1 \ + docker build \ + $(EXTRA_DOCKER_BUILD_FLAGS) \ + -t $(DOCKER_FULL_NAME):$(DOCKER_TAG) \ + $(BUILD_ARGS) \ + -f docker/pytorch.dockerfile . +DOCKER_PUSH = docker push $(DOCKER_FULL_NAME):$(DOCKER_TAG) + +.PHONY: all +all: devel-image + +.PHONY: devel-image +devel-image: BASE_IMAGE := $(BASE_DEVEL) +devel-image: DOCKER_TAG := $(MODELSCOPE_VERSION)-devel +devel-image: + $(DOCKER_BUILD) + +.PHONY: devel-push +devel-push: BASE_IMAGE := $(BASE_DEVEL) +devel-push: DOCKER_TAG := $(MODELSCOPE_VERSION)-devel +devel-push: + $(DOCKER_PUSH) + +.PHONY: runtime-image +runtime-image: BASE_IMAGE := $(BASE_RUNTIME) +runtime-image: DOCKER_TAG := $(MODELSCOPE_VERSION)-runtime +runtime-image: + $(DOCKER_BUILD) + docker tag $(DOCKER_FULL_NAME):$(DOCKER_TAG) $(DOCKER_FULL_NAME):latest + +.PHONY: runtime-push +runtime-push: BASE_IMAGE := $(BASE_RUNTIME) +runtime-push: DOCKER_TAG := $(MODELSCOPE_VERSION)-runtime +runtime-push: + $(DOCKER_PUSH) + +.PHONY: clean +clean: + -docker rmi -f $(shell docker images -q $(DOCKER_FULL_NAME)) diff --git a/README.md b/README.md index 4ae93e1d..33cb61c7 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,29 @@ # ModelScope +======= +# Introduction + +[ModelScope]( https://www.modelscope.cn) is a “Model-as-a-Service” (MaaS) platform that seeks to bring together most advanced machine learning models from the AI community, and to streamline the process of leveraging AI models in real applications. The core ModelScope library enables developers to perform inference, training and evaluation, through rich layers of API designs that facilitate a unified experience across state-of-the-art models from different AI domains. + +The Python library offers the layered-APIs necessary for model contributors to integrate models from CV, NLP, Speech, Multi-Modality, as well as Scientific-computation, into the ModelScope ecosystem. Implementations for all these different models are encapsulated within the library in a way that allows easy and unified access. With such integration, model inference, finetuning, and evaluations can be done with only a few lines of codes. In the meantime, flexibilities are provided so that different components in the model applications can be customized as well, where necessary. + +Apart from harboring implementations of various models, ModelScope library also enables the necessary interactions with ModelScope backend services, particularly with the Model-Hub and Dataset-Hub. Such interactions facilitate management of various entities (models and datasets) to be performed seamlessly under-the-hood, including entity lookup, version control, cache management, and many others. + +# Installation + +Please refer to [installation](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85). + +# Get Started + +You can refer to [quick_start](https://modelscope.cn/docs/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B) for quick start. + +We also provide other documentations including: +* [Introduction to tasks](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D) +* [Use pipeline for model inference](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline) +* [Finetune example](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train) +* [Preprocessing of data](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86) +* [Evaluation metrics](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0) + +# License + +This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 00000000..9c042744 --- /dev/null +++ b/configs/README.md @@ -0,0 +1 @@ +Each model should be associated with a configuration.json file hosted on modelscope model-hub, together with the model binaries. This folder serves the purpose of hosting example configuration, for reference. diff --git a/configs/cv/configuration.json b/configs/cv/configuration.json new file mode 100644 index 00000000..ae07fa10 --- /dev/null +++ b/configs/cv/configuration.json @@ -0,0 +1,178 @@ +{ + "framework": "pytorch", + + "task": "image_classification", + + "model": { + "type": "classification", + "pretrained": null, + "backbone": { + "type": "ResNet", + "depth": 50, + "out_indices": [ + 4 + ], + "norm_cfg": { + "type": "BN" + } + }, + "head": { + "type": "ClsHead", + "with_avg_pool": true, + "in_channels": 2048, + "loss_config": { + "type": "CrossEntropyLossWithLabelSmooth", + "label_smooth": 0 + }, + "num_classes": 1000 + } + }, + + "dataset": { + "train": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/train_labeled.txt", + "root": "data/imagenet_raw/train/", + "type": "ClsSourceImageList" + } + }, + "val": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/val_labeled.txt", + "root": "data/imagenet_raw/validation/", + "type": "ClsSourceImageList" + } + }, + "test": {} + }, + + + "preprocessor":{ + "train": [ + { + "type": "RandomResizedCrop", + "size": 224 + }, + { + "type": "RandomHorizontalFlip" + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ], + "val": [ + { + "type": "Resize", + "size": 256 + }, + { + "type": "CenterCrop", + "size": 224 + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ] + }, + + "train": { + "work_dir": "./work_dir", + "dataloader": { + "batch_size_per_gpu": 2, + "workers_per_gpu": 1 + }, + "optimizer": { + "type": "SGD", + "lr": 0.01, + "options": { + "grad_clip": { + "max_norm": 2.0 + } + } + }, + "lr_scheduler": { + "type": "StepLR", + "step_size": 2, + "options": { + "warmup": { + "type": "LinearWarmup", + "warmup_iters": 2 + + } + } + }, + "hooks": + [ + { + "type": "CheckpointHook", + "interval": 2 + }, + { + "type": "TextLoggerHook", + "interval": 1 + }, + { + "type": "IterTimerHook" + }, + { + "type": "EvaluationHook", + "interval": 1 + } + ] + }, + + "evaluation": { + "dataloader": { + "batch_size_per_gpu": 2, + "workers_per_gpu": 1, + "shuffle": false + }, + "metrics": ["accuracy", "precision", "recall"] + }, + "pipeline": { + "type": "dummy" + } + +} diff --git a/configs/examples/configuration.json b/configs/examples/configuration.json new file mode 100644 index 00000000..551c7a50 --- /dev/null +++ b/configs/examples/configuration.json @@ -0,0 +1,7 @@ +{ + "a": 1, + "b" : { + "c": [1,2,3], + "d" : "dd" + } +} diff --git a/configs/examples/configuration.py b/configs/examples/configuration.py new file mode 100644 index 00000000..beafb8ee --- /dev/null +++ b/configs/examples/configuration.py @@ -0,0 +1,2 @@ +a = 1 +b = dict(c=[1, 2, 3], d='dd') diff --git a/configs/examples/configuration.yaml b/configs/examples/configuration.yaml new file mode 100644 index 00000000..d69dfed3 --- /dev/null +++ b/configs/examples/configuration.yaml @@ -0,0 +1,4 @@ +a: 1 +b: + c: [1,2,3] + d: dd diff --git a/configs/examples/plain_args.yaml b/configs/examples/plain_args.yaml new file mode 100644 index 00000000..0698b089 --- /dev/null +++ b/configs/examples/plain_args.yaml @@ -0,0 +1,5 @@ +model_dir: path/to/model +lr: 0.01 +optimizer: Adam +weight_decay: 1e-6 +save_checkpoint_epochs: 20 diff --git a/configs/examples/train.json b/configs/examples/train.json new file mode 100644 index 00000000..fbfde923 --- /dev/null +++ b/configs/examples/train.json @@ -0,0 +1,131 @@ +{ + "framework": "pytorch", + + "task": "image_classification", + + "model": { + "type": "Resnet50ForImageClassification", + "pretrained": null, + "backbone": { + "type": "ResNet", + "depth": 50, + "out_indices": [ + 4 + ], + "norm_cfg": { + "type": "BN" + } + }, + "head": { + "type": "ClsHead", + "with_avg_pool": true, + "in_channels": 2048, + "loss_config": { + "type": "CrossEntropyLossWithLabelSmooth", + "label_smooth": 0 + }, + "num_classes": 1000 + } + }, + + "dataset": { + "train": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/train_labeled.txt", + "root": "data/imagenet_raw/train/", + "type": "ClsSourceImageList" + } + }, + "val": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/val_labeled.txt", + "root": "data/imagenet_raw/validation/", + "type": "ClsSourceImageList" + } + } + }, + + + "preprocessor":{ + "train": [ + { + "type": "RandomResizedCrop", + "size": 224 + }, + { + "type": "RandomHorizontalFlip" + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ], + "val": [ + { + "type": "Resize", + "size": 256 + }, + { + "type": "CenterCrop", + "size": 224 + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ] + }, + + "train": { + "batch_size": 32, + "learning_rate": 0.00001, + "lr_scheduler_type": "cosine", + "num_epochs": 20 + }, + + "evaluation": { + "batch_size": 32, + "metrics": ["accuracy", "precision", "recall"] + } + +} diff --git a/configs/nlp/sbert_sentence_similarity.json b/configs/nlp/sbert_sentence_similarity.json new file mode 100644 index 00000000..9320e0d7 --- /dev/null +++ b/configs/nlp/sbert_sentence_similarity.json @@ -0,0 +1,87 @@ +{ + "framework": "pytorch", + "task": "sentence-similarity", + "preprocessor": { + "type": "sen-sim-tokenizer", + "first_sequence": "sentence1", + "second_sequence": "sentence2" + }, + "model": { + "type": "text-classification", + "backbone": { + "type": "structbert", + "prefix": "encoder", + "attention_probs_dropout_prob": 0.1, + "easynlp_version": "0.0.3", + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 21128 + }, + "head": { + "type": "text-classification", + "hidden_dropout_prob": 0.1, + "hidden_size": 768 + } + }, + "pipeline": { + "type": "sentence-similarity" + }, + "train": { + "work_dir": "/tmp", + "dataloader": { + "batch_size_per_gpu": 2, + "workers_per_gpu": 1 + }, + "optimizer": { + "type": "SGD", + "lr": 0.01, + "options": { + "grad_clip": { + "max_norm": 2.0 + } + } + }, + "lr_scheduler": { + "type": "StepLR", + "step_size": 2, + "options": { + "warmup": { + "type": "LinearWarmup", + "warmup_iters": 2 + } + } + }, + "hooks": [{ + "type": "CheckpointHook", + "interval": 1 + }, { + "type": "TextLoggerHook", + "interval": 1 + }, { + "type": "IterTimerHook" + }, { + "type": "EvaluationHook", + "interval": 1 + }] + }, + "evaluation": { + "dataloader": { + "batch_size_per_gpu": 2, + "workers_per_gpu": 1, + "shuffle": false + } + } + } diff --git a/configs/nlp/sequence_classification_trainer.yaml b/configs/nlp/sequence_classification_trainer.yaml new file mode 100644 index 00000000..0dd16b91 --- /dev/null +++ b/configs/nlp/sequence_classification_trainer.yaml @@ -0,0 +1,62 @@ +# In current version, many arguments are not used in pipelines, so, +# a tag `[being used]` will indicate which argument is being used +version: v0.1 +framework: pytorch +task: text-classification + +model: + path: bert-base-sst2 + backbone: + type: bert + prefix: bert + attention_probs_dropout_prob: 0.1 + bos_token_id: 0 + eos_token_id: 2 + hidden_act: elu + hidden_dropout_prob: 0.1 + hidden_size: 768 + initializer_range: 0.02 + intermediate_size: 3072 + layer_norm_eps: 1e-05 + max_position_embeddings: 514 + model_type: roberta + num_attention_heads: 12 + num_hidden_layers: 12 + pad_token_id: 1 + type_vocab_size: 1 + vocab_size: 50265 + num_classes: 5 + + +col_index: &col_indexs + text_col: 0 + label_col: 1 + +dataset: + train: + <<: *col_indexs + file: ~ + valid: + <<: *col_indexs + file: glue/sst2 # [being used] + test: + <<: *col_indexs + file: ~ + +preprocessor: + type: Tokenize + tokenizer_name: /workspace/bert-base-sst2 + +train: + batch_size: 256 + learning_rate: 0.00001 + lr_scheduler_type: cosine + num_steps: 100000 + +evaluation: # [being used] + model_path: .cache/easynlp/ + max_sequence_length: 128 + batch_size: 32 + metrics: + - accuracy + - f1 diff --git a/data/test/audios/1ch_nihaomiya.wav b/data/test/audios/1ch_nihaomiya.wav new file mode 100644 index 00000000..4618d412 --- /dev/null +++ b/data/test/audios/1ch_nihaomiya.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d +size 1440044 diff --git a/data/test/audios/3ch_nihaomiya.wav b/data/test/audios/3ch_nihaomiya.wav new file mode 100644 index 00000000..57d9f061 --- /dev/null +++ b/data/test/audios/3ch_nihaomiya.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559 +size 4325096 diff --git a/data/test/audios/asr_example.wav b/data/test/audios/asr_example.wav new file mode 100644 index 00000000..5c61b555 --- /dev/null +++ b/data/test/audios/asr_example.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87bde7feb3b40d75dec27e5824dd1077911f867e3f125c4bf603ec0af954d4db +size 77864 diff --git a/data/test/audios/asr_example_8K.wav b/data/test/audios/asr_example_8K.wav new file mode 100644 index 00000000..956aad27 --- /dev/null +++ b/data/test/audios/asr_example_8K.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f +size 38954 diff --git a/data/test/audios/asr_example_cn_dialect.wav b/data/test/audios/asr_example_cn_dialect.wav new file mode 100644 index 00000000..e18fb05d --- /dev/null +++ b/data/test/audios/asr_example_cn_dialect.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4 +size 93676 diff --git a/data/test/audios/asr_example_cn_en.wav b/data/test/audios/asr_example_cn_en.wav new file mode 100644 index 00000000..8baf3193 --- /dev/null +++ b/data/test/audios/asr_example_cn_en.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb +size 472012 diff --git a/data/test/audios/asr_example_en.wav b/data/test/audios/asr_example_en.wav new file mode 100644 index 00000000..fa996eec --- /dev/null +++ b/data/test/audios/asr_example_en.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b +size 71244 diff --git a/data/test/audios/asr_example_es.wav b/data/test/audios/asr_example_es.wav new file mode 100644 index 00000000..95b22dc3 --- /dev/null +++ b/data/test/audios/asr_example_es.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61 +size 181964 diff --git a/data/test/audios/asr_example_id.wav b/data/test/audios/asr_example_id.wav new file mode 100644 index 00000000..54c30614 --- /dev/null +++ b/data/test/audios/asr_example_id.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425 +size 112078 diff --git a/data/test/audios/asr_example_ja.wav b/data/test/audios/asr_example_ja.wav new file mode 100644 index 00000000..e953fee2 --- /dev/null +++ b/data/test/audios/asr_example_ja.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281 +size 200556 diff --git a/data/test/audios/asr_example_ko.wav b/data/test/audios/asr_example_ko.wav new file mode 100644 index 00000000..0dad1be3 --- /dev/null +++ b/data/test/audios/asr_example_ko.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b +size 162636 diff --git a/data/test/audios/asr_example_ru.wav b/data/test/audios/asr_example_ru.wav new file mode 100644 index 00000000..b0cb8f2f --- /dev/null +++ b/data/test/audios/asr_example_ru.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61 +size 160204 diff --git a/data/test/audios/farend_speech.wav b/data/test/audios/farend_speech.wav new file mode 100644 index 00000000..4e96d842 --- /dev/null +++ b/data/test/audios/farend_speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be +size 320042 diff --git a/data/test/audios/kws_bofangyinyue.wav b/data/test/audios/kws_bofangyinyue.wav new file mode 100644 index 00000000..c8bf69b7 --- /dev/null +++ b/data/test/audios/kws_bofangyinyue.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba +size 69110 diff --git a/data/test/audios/kws_xiaoyunxiaoyun.wav b/data/test/audios/kws_xiaoyunxiaoyun.wav new file mode 100644 index 00000000..8afe6b7c --- /dev/null +++ b/data/test/audios/kws_xiaoyunxiaoyun.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1 +size 297684 diff --git a/data/test/audios/nearend_mic.wav b/data/test/audios/nearend_mic.wav new file mode 100644 index 00000000..e055c2e0 --- /dev/null +++ b/data/test/audios/nearend_mic.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469 +size 320042 diff --git a/data/test/audios/noise_2ch.wav b/data/test/audios/noise_2ch.wav new file mode 100644 index 00000000..c754e39a --- /dev/null +++ b/data/test/audios/noise_2ch.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d +size 2327764 diff --git a/data/test/audios/speech_with_noise.wav b/data/test/audios/speech_with_noise.wav new file mode 100644 index 00000000..d57488c9 --- /dev/null +++ b/data/test/audios/speech_with_noise.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f +size 76770 diff --git a/data/test/audios/wake_word_with_label_xyxy.wav b/data/test/audios/wake_word_with_label_xyxy.wav new file mode 100644 index 00000000..b7999777 --- /dev/null +++ b/data/test/audios/wake_word_with_label_xyxy.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 +size 68524 diff --git a/data/test/images/Solvay_conference_1927.png b/data/test/images/Solvay_conference_1927.png new file mode 100755 index 00000000..0c97101d --- /dev/null +++ b/data/test/images/Solvay_conference_1927.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa8ab905e8374a0f94b4bfbfc81da14e762c71eaf64bae85bdd03b07cdf884c2 +size 859206 diff --git a/data/test/images/auto_demo.jpg b/data/test/images/auto_demo.jpg new file mode 100644 index 00000000..30393e53 --- /dev/null +++ b/data/test/images/auto_demo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 +size 53219 diff --git a/data/test/images/bird.JPEG b/data/test/images/bird.JPEG new file mode 100755 index 00000000..897eb3c8 --- /dev/null +++ b/data/test/images/bird.JPEG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19fb781a44aec9349a8b73850e53b7eb9b0623d54ebd0cd8577c13bf463b5004 +size 74237 diff --git a/data/test/images/card_detection.jpg b/data/test/images/card_detection.jpg new file mode 100644 index 00000000..86728c2c --- /dev/null +++ b/data/test/images/card_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecbc9d0827cfb92e93e7d75868b1724142685dc20d3b32023c3c657a7b688a9c +size 254845 diff --git a/data/test/images/coco_cn/train.json b/data/test/images/coco_cn/train.json new file mode 100644 index 00000000..634706bd --- /dev/null +++ b/data/test/images/coco_cn/train.json @@ -0,0 +1,51 @@ +[ + { + "image": "train/COCO_train2014_000000496606.jpg", + "caption": [ + "一只黄色的小狗趴在长椅上" + ] + }, + { + "image": "val/COCO_val2014_000000043734.jpg", + "caption": [ + "两只黑色的狗从水里翻着水花游过来" + ] + }, + { + "image": "train/COCO_train2014_000000404748.jpg", + "caption": [ + "两只长颈鹿站在岩石旁的草地上" + ] + }, + { + "image": "val/COCO_val2014_000000574392.jpg", + "caption": [ + "一个木制的公园长椅在森林里。" + ] + }, + { + "image": "train/COCO_train2014_000000563734.jpg", + "caption": [ + "许多公交车排成队在广场上停着。" + ] + }, + { + "image": "train/COCO_train2014_000000197406.jpg", + "caption": [ + "一个男人和一只长颈鹿站在沙滩上" + ] + }, + { + "image": "val/COCO_val2014_000000473869.jpg", + "caption": [ + "一个微笑的男人在厨房里做饭。" + ] + }, + { + "image": "train/COCO_train2014_000000021183.jpg", + "caption": [ + "一个年龄比较大,坐在街道旁座椅上的男人手里握着一个装着写有标语的板子的手推车", + "一个年老的男人坐在街道上的长椅上,手搭在面前放着告示牌的小推车上" + ] + } +] diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg new file mode 100644 index 00000000..6d684e76 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000021183.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff26436ee5ca4146a5c7218c8a1814a324574e92114736792dcc768ac1e566f +size 134292 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg new file mode 100644 index 00000000..57d9c322 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000177625.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07fb36fb94301aa067c1c7f9ca4c8c04d6d7282b4a5494e392c54928d242a56b +size 149178 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg new file mode 100644 index 00000000..fbc2aeea --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000197406.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33473d2a21e669196271e28eca437696625e4a5e11eb6efc5b57e7961f15cf0d +size 68914 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg new file mode 100644 index 00000000..0b36fd13 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000275612.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c1f6dc406b0e08b43668e73f9700e63420eb4e384a53c539062e89315b64ad6 +size 84248 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg new file mode 100644 index 00000000..43633b77 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000404748.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed36ab05878caee478d6532777c862af11a4c62182ba989dfb3bf32e41277c65 +size 239503 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg new file mode 100644 index 00000000..f8f4a2e9 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000493952.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95815c59443288b019e496d0c81cf8e734b347e8a31d996a9f1463eb506f3717 +size 177175 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg new file mode 100644 index 00000000..292457aa --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000496606.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7482e789876cbdd18e1e5f0487d2a10f40be1cf4ce696d8e203da80418ec580b +size 195821 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg new file mode 100644 index 00000000..e1083b01 --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000563734.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df3cee336d965ca249b5e4acd9618d0e2d0e267267222408b6565bb331a5fb23 +size 198775 diff --git a/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg b/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg new file mode 100644 index 00000000..bfdaeb4d --- /dev/null +++ b/data/test/images/coco_cn/train/COCO_train2014_000000573854.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c03f9d1eb963d6b385e22f3a26d202bea7637d3effd347af53435f5ad9434d72 +size 179422 diff --git a/data/test/images/coco_cn/val.json b/data/test/images/coco_cn/val.json new file mode 100644 index 00000000..cab5adf0 --- /dev/null +++ b/data/test/images/coco_cn/val.json @@ -0,0 +1,52 @@ +[ + { + "image": "train/COCO_train2014_000000573854.jpg", + "caption": [ + "机场跑道的喷气式飞机正准备起飞。" + ] + }, + { + "image": "val/COCO_val2014_000000412975.jpg", + "caption": [ + "一个女孩走下台阶。" + ] + }, + { + "image": "val/COCO_val2014_000000341725.jpg", + "caption": [ + "窗台上蓝色的花瓶里有一束粉色的郁金香。" + ] + }, + { + "image": "val/COCO_val2014_000000163020.jpg", + "caption": [ + "一只海鸥在水面上飞翔。" + ] + }, + { + "image": "train/COCO_train2014_000000177625.jpg", + "caption": [ + "一男一女在聚会上玩电子游戏,男人的脚边趴着一只狗" + ] + }, + { + "image": "train/COCO_train2014_000000275612.jpg", + "caption": [ + "厕所中的一个马桶", + "浴室里,高档的马桶与各式洗浴用品一应俱全。" + ] + }, + { + "image": "train/COCO_train2014_000000493952.jpg", + "caption": [ + "一辆黑色轿车停在一栋大楼前。" + ] + }, + { + "image": "val/COCO_val2014_000000044723.jpg", + "caption": [ + "阴天下一张伦敦塔的照片。", + "一座大楼的顶端悬挂着钟表。" + ] + } +] diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg new file mode 100644 index 00000000..b9293cce --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000043734.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c26dc7c54a1202744d50bc2186ea2a49865879a3a3a174099c4e9ecc1199a16a +size 93126 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg new file mode 100644 index 00000000..afaf372d --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000044723.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24cd7ad56cf00d57a7b2d182957a8ad6b44d5eb55dfe3bc69ad5a292151d482e +size 122140 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg new file mode 100644 index 00000000..de16ebc5 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000163020.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee9c78c8c141d1bb3cd064f2a003f0786a19c0b2cc54e0cfa2ee2459daf7bebe +size 63796 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg new file mode 100644 index 00000000..85f5d815 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000341725.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2c08174e59610f65797f50e9eea968eec0ba092c5aca69574e70a6e98862da7 +size 92038 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg new file mode 100644 index 00000000..5ba16dbd --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000412975.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a324c2f213442f8ab2fcc5f16f59d2d31ec08993b27b13a623b3a32dd4c408ac +size 182587 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg new file mode 100644 index 00000000..e6ac2baa --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000473869.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2f1193dc4c0cd50e0a233810fde1875f7211e936c35d9a3754bf71c2c8da84e +size 109371 diff --git a/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg b/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg new file mode 100644 index 00000000..b62feea9 --- /dev/null +++ b/data/test/images/coco_cn/val/COCO_val2014_000000574392.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74f3626546a174ca28da8ce35eeea6d62d230da5ff74fd73d37211557c35d83e +size 377231 diff --git a/data/test/images/crowd_counting.jpg b/data/test/images/crowd_counting.jpg new file mode 100644 index 00000000..0468fe5b --- /dev/null +++ b/data/test/images/crowd_counting.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03c9b0ae20b5000b083e8211e2c119176b88db0ea4f48e29b86dcf2f901e382b +size 130079 diff --git a/data/test/images/dogs.jpg b/data/test/images/dogs.jpg new file mode 100644 index 00000000..450a969d --- /dev/null +++ b/data/test/images/dogs.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d +size 129862 diff --git a/data/test/images/face_detection.png b/data/test/images/face_detection.png new file mode 100644 index 00000000..3b572877 --- /dev/null +++ b/data/test/images/face_detection.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa3963d1c54e6d3d46e9a59872a99ed955d4050092f5cfe5f591e03d740b7042 +size 653006 diff --git a/data/test/images/face_detection2.jpeg b/data/test/images/face_detection2.jpeg new file mode 100644 index 00000000..7f6025fa --- /dev/null +++ b/data/test/images/face_detection2.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d510ab26ddc58ffea882c8ef850c1f9bd4444772f2bce7ebea3e76944536c3ae +size 48909 diff --git a/data/test/images/face_emotion.jpg b/data/test/images/face_emotion.jpg new file mode 100644 index 00000000..54f22280 --- /dev/null +++ b/data/test/images/face_emotion.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:712b5525e37080d33f62d6657609dbef20e843ccc04ee5c788ea11aa7c08545e +size 123341 diff --git a/data/test/images/face_enhancement/gt/000000.jpg b/data/test/images/face_enhancement/gt/000000.jpg new file mode 100644 index 00000000..13c18e3b --- /dev/null +++ b/data/test/images/face_enhancement/gt/000000.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cd14710143ba1a912e3ef574d0bf71c7e40bf9897522cba07ecae2567343064 +size 850603 diff --git a/data/test/images/face_enhancement/gt/000001.jpg b/data/test/images/face_enhancement/gt/000001.jpg new file mode 100644 index 00000000..d0b7afc0 --- /dev/null +++ b/data/test/images/face_enhancement/gt/000001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7f166ecb3a6913dbd05a1eb271399cbaa731d1074ac03184c13ae245ca66819 +size 800380 diff --git a/data/test/images/face_enhancement/lq/000000.png b/data/test/images/face_enhancement/lq/000000.png new file mode 100644 index 00000000..8503d219 --- /dev/null +++ b/data/test/images/face_enhancement/lq/000000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e95d11661485fc0e6f326398f953459dcb3e65b7f4a6c892611266067cf8fe3a +size 245773 diff --git a/data/test/images/face_enhancement/lq/000001.png b/data/test/images/face_enhancement/lq/000001.png new file mode 100644 index 00000000..9afb2a0e --- /dev/null +++ b/data/test/images/face_enhancement/lq/000001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03972400b20b3e6f1d056b359d9c9f12952653a67a73b36018504ce9ee9edf9d +size 254261 diff --git a/data/test/images/face_human_hand_detection.jpg b/data/test/images/face_human_hand_detection.jpg new file mode 100644 index 00000000..f94bb547 --- /dev/null +++ b/data/test/images/face_human_hand_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487 +size 23889 diff --git a/data/test/images/face_recognition_1.png b/data/test/images/face_recognition_1.png new file mode 100644 index 00000000..eefe2138 --- /dev/null +++ b/data/test/images/face_recognition_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48e541daeb2692907efef47018e41abb5ae6bcd88eb5ff58290d7fe5dc8b2a13 +size 462584 diff --git a/data/test/images/face_recognition_2.png b/data/test/images/face_recognition_2.png new file mode 100644 index 00000000..1292d8cb --- /dev/null +++ b/data/test/images/face_recognition_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9565b43d9f65361b9bad6553b327c2c6f02fd063a4c8dc0f461e88ea461989d +size 357166 diff --git a/data/test/images/facial_expression_recognition.jpg b/data/test/images/facial_expression_recognition.jpg new file mode 100644 index 00000000..a943fa72 --- /dev/null +++ b/data/test/images/facial_expression_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdb1cef5a5fd5f938a856311011c4820ddc45946a470b9929c61e59b6a065633 +size 161535 diff --git a/data/test/images/generative_multimodal.jpg b/data/test/images/generative_multimodal.jpg new file mode 100644 index 00000000..b7b32939 --- /dev/null +++ b/data/test/images/generative_multimodal.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24b78db10990c809380508b962decb53cb16db582135cb3c7d56c48f71d5ceb8 +size 39683 diff --git a/data/test/images/hand_keypoints.jpg b/data/test/images/hand_keypoints.jpg new file mode 100644 index 00000000..cb445c26 --- /dev/null +++ b/data/test/images/hand_keypoints.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c05d58edee7398de37b8e479410676d6b97cfde69cc003e8356a348067e71988 +size 7750 diff --git a/data/test/images/hand_static.jpg b/data/test/images/hand_static.jpg new file mode 100644 index 00000000..43ae28b1 --- /dev/null +++ b/data/test/images/hand_static.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94b8e281d77ee6d3ea2a8a0c9408ecdbd29fe75f33ea5399b6ea00070ba77bd6 +size 13090 diff --git a/data/test/images/image-text-retrieval.jpg b/data/test/images/image-text-retrieval.jpg new file mode 100644 index 00000000..2d20374a --- /dev/null +++ b/data/test/images/image-text-retrieval.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab +size 218143 diff --git a/data/test/images/image_body_reshaping.jpg b/data/test/images/image_body_reshaping.jpg new file mode 100644 index 00000000..d78acb8f --- /dev/null +++ b/data/test/images/image_body_reshaping.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d +size 1127557 diff --git a/data/test/images/image_captioning.png b/data/test/images/image_captioning.png new file mode 100644 index 00000000..de3f1918 --- /dev/null +++ b/data/test/images/image_captioning.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af83a94899a6d23339c3ecc5c4c58c57c835af57b531a2f4c50461184f820141 +size 603621 diff --git a/data/test/images/image_classification.png b/data/test/images/image_classification.png new file mode 100644 index 00000000..3d1a2f8c --- /dev/null +++ b/data/test/images/image_classification.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bdb9627c3a40897e84ee186b2a959f272790571644224e1d2efca443f867e12 +size 202823 diff --git a/data/test/images/image_color_enhance.png b/data/test/images/image_color_enhance.png new file mode 100644 index 00000000..ffb4d188 --- /dev/null +++ b/data/test/images/image_color_enhance.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b +size 950676 diff --git a/data/test/images/image_color_enhance/gt/1.png b/data/test/images/image_color_enhance/gt/1.png new file mode 100644 index 00000000..ffb4d188 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b +size 950676 diff --git a/data/test/images/image_color_enhance/gt/2.png b/data/test/images/image_color_enhance/gt/2.png new file mode 100644 index 00000000..a84f2543 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a4a8a60501976b2c5e753814a346519ef6faff052b53359cf44b4e597e62aaf +size 902214 diff --git a/data/test/images/image_color_enhance/gt/3.png b/data/test/images/image_color_enhance/gt/3.png new file mode 100644 index 00000000..dc04f4bc --- /dev/null +++ b/data/test/images/image_color_enhance/gt/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:326c5e3907926a4af6fec382050026d505d78aab8c5f2e0ecc85ac863abbb94c +size 856195 diff --git a/data/test/images/image_color_enhance/gt/4.png b/data/test/images/image_color_enhance/gt/4.png new file mode 100644 index 00000000..4e888582 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:455f364c008be76a392085e7590b9050a628853a9df1e608a40c75a15bc41c5f +size 951993 diff --git a/data/test/images/image_color_enhance/lq/1.png b/data/test/images/image_color_enhance/lq/1.png new file mode 100644 index 00000000..a9641037 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f806b26557317f856e7583fb128713579df3354016b368ef32791b283e3be051 +size 932493 diff --git a/data/test/images/image_color_enhance/lq/2.png b/data/test/images/image_color_enhance/lq/2.png new file mode 100644 index 00000000..79176bd1 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ec66811ec4f1ec8735b7f0eb897100f80939ba5dc150028fa91bfcd15b5164c +size 896481 diff --git a/data/test/images/image_color_enhance/lq/3.png b/data/test/images/image_color_enhance/lq/3.png new file mode 100644 index 00000000..93f52409 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9517b185b0cffc0c830270fd52551e145054daa00c704ed4132589b24ab46e9 +size 828266 diff --git a/data/test/images/image_color_enhance/lq/4.png b/data/test/images/image_color_enhance/lq/4.png new file mode 100644 index 00000000..6a1f659a --- /dev/null +++ b/data/test/images/image_color_enhance/lq/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a233195949ed1c3db9c9a182baf3d8f014620d28bab823aa4d4cc203e602bc6 +size 927552 diff --git a/data/test/images/image_detection.jpg b/data/test/images/image_detection.jpg new file mode 100644 index 00000000..37447ce3 --- /dev/null +++ b/data/test/images/image_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0218020651b6cdcc0051563f75750c8200d34fc49bf34cc053cd59c1f13cad03 +size 128624 diff --git a/data/test/images/image_inpainting/image_inpainting.png b/data/test/images/image_inpainting/image_inpainting.png new file mode 100644 index 00000000..e141012d --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5 +size 2747938 diff --git a/data/test/images/image_inpainting/image_inpainting_mask.png b/data/test/images/image_inpainting/image_inpainting_mask.png new file mode 100644 index 00000000..e30f67e7 --- /dev/null +++ b/data/test/images/image_inpainting/image_inpainting_mask.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15 +size 27616 diff --git a/data/test/images/image_instance_segmentation.jpg b/data/test/images/image_instance_segmentation.jpg new file mode 100644 index 00000000..f390fc90 --- /dev/null +++ b/data/test/images/image_instance_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e9ab135da7eacabdeeeee11ba4b7bcdd1bfac128cf92a9de9c79f984060ae1e +size 259865 diff --git a/data/test/images/image_matting.png b/data/test/images/image_matting.png new file mode 100644 index 00000000..de3f1918 --- /dev/null +++ b/data/test/images/image_matting.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af83a94899a6d23339c3ecc5c4c58c57c835af57b531a2f4c50461184f820141 +size 603621 diff --git a/data/test/images/image_mplug_vqa.jpg b/data/test/images/image_mplug_vqa.jpg new file mode 100644 index 00000000..57919471 --- /dev/null +++ b/data/test/images/image_mplug_vqa.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692 +size 181725 diff --git a/data/test/images/image_ocr_recognition.jpg b/data/test/images/image_ocr_recognition.jpg new file mode 100644 index 00000000..b41287cd --- /dev/null +++ b/data/test/images/image_ocr_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747 +size 959 diff --git a/data/test/images/image_panoptic_segmentation.jpg b/data/test/images/image_panoptic_segmentation.jpg new file mode 100644 index 00000000..2a8d826b --- /dev/null +++ b/data/test/images/image_panoptic_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a +size 245864 diff --git a/data/test/images/image_reid_person.jpg b/data/test/images/image_reid_person.jpg new file mode 100644 index 00000000..078468ec --- /dev/null +++ b/data/test/images/image_reid_person.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9a7e42edc7065c16972ff56267aad63f5233e36aa5a699b84939f5bad73276 +size 2451 diff --git a/data/test/images/image_salient_detection.jpg b/data/test/images/image_salient_detection.jpg new file mode 100644 index 00000000..9c0632d3 --- /dev/null +++ b/data/test/images/image_salient_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70ea0c06f9cfe3882253f7175221d47e394ab9c469076ab220e880b17dbcdd02 +size 48552 diff --git a/data/test/images/image_segmentation.jpg b/data/test/images/image_segmentation.jpg new file mode 100644 index 00000000..a9c0875c --- /dev/null +++ b/data/test/images/image_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af6fa61274e497ecc170de5adc4b8e7ac89eba2bc22a6aa119b08ec7adbe9459 +size 146140 diff --git a/data/test/images/image_semantic_segmentation.jpg b/data/test/images/image_semantic_segmentation.jpg new file mode 100644 index 00000000..2a8d826b --- /dev/null +++ b/data/test/images/image_semantic_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a +size 245864 diff --git a/data/test/images/image_wolf.jpeg b/data/test/images/image_wolf.jpeg new file mode 100644 index 00000000..32d0c567 --- /dev/null +++ b/data/test/images/image_wolf.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbe3c719d25c2c90349c3c280e74f46f315a490443655ceba8b8a203af0f7259 +size 171378 diff --git a/data/test/images/img2img_input.jpg b/data/test/images/img2img_input.jpg new file mode 100644 index 00000000..2da79e75 --- /dev/null +++ b/data/test/images/img2img_input.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e4cbf844cd16a892a7d2f2764b1537c346675d3b0145016d6836441ba907366 +size 9195 diff --git a/data/test/images/img2img_input_mask.png b/data/test/images/img2img_input_mask.png new file mode 100644 index 00000000..131fc37a --- /dev/null +++ b/data/test/images/img2img_input_mask.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33b3d3076e191fa92511bf69fa76e1222b3b3be0049e711c948a1218b587510c +size 4805 diff --git a/data/test/images/img2img_input_masked_img.png b/data/test/images/img2img_input_masked_img.png new file mode 100644 index 00000000..7f7c256b --- /dev/null +++ b/data/test/images/img2img_input_masked_img.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99c2b02a927b86ff194287ea4c5a05349dd800cff2b523212d1dad378c252feb +size 103334 diff --git a/data/test/images/img2img_style.jpg b/data/test/images/img2img_style.jpg new file mode 100644 index 00000000..1b361f11 --- /dev/null +++ b/data/test/images/img2img_style.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef06465535002fd565f3e50d16772bdcb8e47f474fb7d7c318510fff49ab1090 +size 212790 diff --git a/data/test/images/keypoints_detect/000000438304.jpg b/data/test/images/keypoints_detect/000000438304.jpg new file mode 100644 index 00000000..5d03c471 --- /dev/null +++ b/data/test/images/keypoints_detect/000000438304.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64ab6a5556b022cbd398d98cd5bb243a4ee6e4ea6e3285f433eb78b76b53fd4e +size 269177 diff --git a/data/test/images/keypoints_detect/000000438862.jpg b/data/test/images/keypoints_detect/000000438862.jpg new file mode 100644 index 00000000..47946a91 --- /dev/null +++ b/data/test/images/keypoints_detect/000000438862.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3689831ed23f734ebab9405f48ffbfbbefb778e9de3101a9d56e421ea45288cf +size 248595 diff --git a/data/test/images/keypoints_detect/000000439522.jpg b/data/test/images/keypoints_detect/000000439522.jpg new file mode 100644 index 00000000..32b59e7a --- /dev/null +++ b/data/test/images/keypoints_detect/000000439522.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:663545f71af556370c7cba7fd8010a665d00c0b477075562a3d7669c6d853ad3 +size 107685 diff --git a/data/test/images/keypoints_detect/000000440336.jpg b/data/test/images/keypoints_detect/000000440336.jpg new file mode 100644 index 00000000..b61d7c8d --- /dev/null +++ b/data/test/images/keypoints_detect/000000440336.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5c2df473a26427ae57950acec86d1e4d3a49cdf1a18d427cd1a354465408f00 +size 102909 diff --git a/data/test/images/keypoints_detect/000000442836.jpg b/data/test/images/keypoints_detect/000000442836.jpg new file mode 100644 index 00000000..9642df68 --- /dev/null +++ b/data/test/images/keypoints_detect/000000442836.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44b225eaff012bd016fcfe8a3dbeace93fd418164f40e4b5f5b9f0d76f39097b +size 308635 diff --git a/data/test/images/keypoints_detect/000000447088.jpg b/data/test/images/keypoints_detect/000000447088.jpg new file mode 100644 index 00000000..8d4f1752 --- /dev/null +++ b/data/test/images/keypoints_detect/000000447088.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:510da487b16303646cf4b500cae0a4168cba2feb3dd706c007a3f5c64400501c +size 148413 diff --git a/data/test/images/keypoints_detect/000000447917.jpg b/data/test/images/keypoints_detect/000000447917.jpg new file mode 100644 index 00000000..542c7b3a --- /dev/null +++ b/data/test/images/keypoints_detect/000000447917.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbaa52b9ecc59b899500db9200ce65b17aa8b87172c8c70de585fa27c80e7ad1 +size 238442 diff --git a/data/test/images/keypoints_detect/000000448263.jpg b/data/test/images/keypoints_detect/000000448263.jpg new file mode 100644 index 00000000..474563e2 --- /dev/null +++ b/data/test/images/keypoints_detect/000000448263.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72fcff7fd4da5ede2d3c1a31449769b0595685f7250597f05cd176c4c80ced03 +size 37753 diff --git a/data/test/images/keypoints_detect/img_test_wholebody.jpg b/data/test/images/keypoints_detect/img_test_wholebody.jpg new file mode 100644 index 00000000..40a9f3f8 --- /dev/null +++ b/data/test/images/keypoints_detect/img_test_wholebody.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd +size 167407 diff --git a/data/test/images/keypoints_detect/test_img_face_2d_keypoints.png b/data/test/images/keypoints_detect/test_img_face_2d_keypoints.png new file mode 100644 index 00000000..00311c33 --- /dev/null +++ b/data/test/images/keypoints_detect/test_img_face_2d_keypoints.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:331ead75033fa2f01f6be72a2f8e34d581fcb593308067815d4bb136bb13b766 +size 54390 diff --git a/data/test/images/marilyn_monroe_4.jpg b/data/test/images/marilyn_monroe_4.jpg new file mode 100644 index 00000000..cdcf22b0 --- /dev/null +++ b/data/test/images/marilyn_monroe_4.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b425fb89442e4c6c32c71c17c1c1afef8a2c5bc9ec9529b5a0fc21c53e1a02b +size 39248 diff --git a/data/test/images/mog_face_detection.jpg b/data/test/images/mog_face_detection.jpg new file mode 100644 index 00000000..c95881fe --- /dev/null +++ b/data/test/images/mog_face_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 +size 87228 diff --git a/data/test/images/mtcnn_face_detection.jpg b/data/test/images/mtcnn_face_detection.jpg new file mode 100644 index 00000000..c95881fe --- /dev/null +++ b/data/test/images/mtcnn_face_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 +size 87228 diff --git a/data/test/images/multimodal_similarity.jpg b/data/test/images/multimodal_similarity.jpg new file mode 100644 index 00000000..70a2b844 --- /dev/null +++ b/data/test/images/multimodal_similarity.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f24abbba43782d733dedbb0b4f416635af50263862e5632963ac9263e430555 +size 88542 diff --git a/data/test/images/noisy-demo-0.png b/data/test/images/noisy-demo-0.png new file mode 100644 index 00000000..e3321ecb --- /dev/null +++ b/data/test/images/noisy-demo-0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:403034182fa320130dae0d75b92e85e0850771378e674d65455c403a4958e29c +size 170716 diff --git a/data/test/images/noisy-demo-1.png b/data/test/images/noisy-demo-1.png new file mode 100644 index 00000000..f79c51cd --- /dev/null +++ b/data/test/images/noisy-demo-1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebd5dacad9b75ef80f87eb785d7818421dadb63257da0e91e123766c5913f855 +size 149971 diff --git a/data/test/images/ocr_detection.jpg b/data/test/images/ocr_detection.jpg new file mode 100644 index 00000000..c347810e --- /dev/null +++ b/data/test/images/ocr_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c8435db5583400be5d11a2c17910c96133b462c8a99ccaf0e19f4aac34e0a94 +size 141149 diff --git a/data/test/images/ocr_recognition.jpg b/data/test/images/ocr_recognition.jpg new file mode 100644 index 00000000..069ac03d --- /dev/null +++ b/data/test/images/ocr_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d68cfcaa7cc7b8276877c2dfa022deebe82076bc178ece1bfe7fd5423cd5b99 +size 60009 diff --git a/data/test/images/ocr_recognition_document.png b/data/test/images/ocr_recognition_document.png new file mode 100644 index 00000000..d74018bb --- /dev/null +++ b/data/test/images/ocr_recognition_document.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29f2ad929c852f6456367054d13e113078cf06b763fe54d73fd324f789331aa3 +size 61611 diff --git a/data/test/images/product_embed_bag.jpg b/data/test/images/product_embed_bag.jpg new file mode 100644 index 00000000..8427c028 --- /dev/null +++ b/data/test/images/product_embed_bag.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08691a9373aa6d05b236a4ba788f3eccdea4c37aa77b30fc94b02ec3e1f18210 +size 367017 diff --git a/data/test/images/product_segmentation.jpg b/data/test/images/product_segmentation.jpg new file mode 100644 index 00000000..c188a69e --- /dev/null +++ b/data/test/images/product_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec +size 108357 diff --git a/data/test/images/retina_face_detection.jpg b/data/test/images/retina_face_detection.jpg new file mode 100644 index 00000000..c95881fe --- /dev/null +++ b/data/test/images/retina_face_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 +size 87228 diff --git a/data/test/images/shop_segmentation.jpg b/data/test/images/shop_segmentation.jpg new file mode 100644 index 00000000..ec02881d --- /dev/null +++ b/data/test/images/shop_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5ecc371c8b0ca09d0e11df89bc549000937eafc451929586426fe657ade25a0 +size 238607 diff --git a/data/test/images/skin_retouching.png b/data/test/images/skin_retouching.png new file mode 100644 index 00000000..a0b8df2a --- /dev/null +++ b/data/test/images/skin_retouching.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fcd36e0ada8a506bb09d3e0f3594e2be978194ea4123e066331c0bcb7fc79bc +size 683425 diff --git a/data/test/images/style_transfer_content.jpg b/data/test/images/style_transfer_content.jpg new file mode 100644 index 00000000..5602662d --- /dev/null +++ b/data/test/images/style_transfer_content.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f33a6ad9fcd7367cec2e81b8b0e4234d4f5f7d1be284d48085a25bb6d03782d7 +size 72130 diff --git a/data/test/images/style_transfer_style.jpg b/data/test/images/style_transfer_style.jpg new file mode 100644 index 00000000..820b093f --- /dev/null +++ b/data/test/images/style_transfer_style.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8 +size 376101 diff --git a/data/test/images/text_driven_segmentation.jpg b/data/test/images/text_driven_segmentation.jpg new file mode 100644 index 00000000..e3320b1f --- /dev/null +++ b/data/test/images/text_driven_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4 +size 67861 diff --git a/data/test/images/ulfd_face_detection.jpg b/data/test/images/ulfd_face_detection.jpg new file mode 100644 index 00000000..c95881fe --- /dev/null +++ b/data/test/images/ulfd_face_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9 +size 87228 diff --git a/data/test/images/virtual_tryon_cloth.jpg b/data/test/images/virtual_tryon_cloth.jpg new file mode 100644 index 00000000..baa4d3aa --- /dev/null +++ b/data/test/images/virtual_tryon_cloth.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ce0d25b3392f140bf35fba9c6711fdcfc2efde536600aa48dace35462e81adf +size 8825 diff --git a/data/test/images/virtual_tryon_model.jpg b/data/test/images/virtual_tryon_model.jpg new file mode 100644 index 00000000..2862a8be --- /dev/null +++ b/data/test/images/virtual_tryon_model.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb76a61306d3d311d440c5c695958909166e04fb34c827d74d766ba830945d6f +size 5034 diff --git a/data/test/images/virtual_tryon_pose.jpg b/data/test/images/virtual_tryon_pose.jpg new file mode 100644 index 00000000..41804706 --- /dev/null +++ b/data/test/images/virtual_tryon_pose.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ab9baf18074b6b5655ee546794789395757486d6e2180c2627aad47b819e505 +size 11778 diff --git a/data/test/images/visual_grounding.png b/data/test/images/visual_grounding.png new file mode 100644 index 00000000..a37791ec --- /dev/null +++ b/data/test/images/visual_grounding.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b89734b9c9d89342e58fbe406d3b9bdc8e07447cb170a4ae2743000471fc969 +size 23069 diff --git a/data/test/images/visual_question_answering.png b/data/test/images/visual_question_answering.png new file mode 100644 index 00000000..e39d34a0 --- /dev/null +++ b/data/test/images/visual_question_answering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d53e9fbdd129b234dcbec9b9fe6a15a0e05820e802a873f95955574267bbd2ff +size 121141 diff --git a/data/test/regression/fill_mask_bert_zh.bin b/data/test/regression/fill_mask_bert_zh.bin new file mode 100644 index 00000000..17c28b81 --- /dev/null +++ b/data/test/regression/fill_mask_bert_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:541183383bb06aa3ca2c44a68cd51c1be5e3e984a1dee2c58092b9552660f3ce +size 61883 diff --git a/data/test/regression/fill_mask_sbert_en.bin b/data/test/regression/fill_mask_sbert_en.bin new file mode 100644 index 00000000..09aaf300 --- /dev/null +++ b/data/test/regression/fill_mask_sbert_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f0afcd9d2aa5ac9569114203bd9db4f1a520c903a88fd4854370cdde0e7eab7 +size 119940 diff --git a/data/test/regression/fill_mask_sbert_zh.bin b/data/test/regression/fill_mask_sbert_zh.bin new file mode 100644 index 00000000..62581a26 --- /dev/null +++ b/data/test/regression/fill_mask_sbert_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4eae921001139d7e3c06331c9ef2213f8fc1c23512acd95751559866fb770e96 +size 121855 diff --git a/data/test/regression/fill_mask_veco_en.bin b/data/test/regression/fill_mask_veco_en.bin new file mode 100644 index 00000000..4d2dba7d --- /dev/null +++ b/data/test/regression/fill_mask_veco_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f97d34d7450d17d0a93647129ab10d16b1f6e70c34a73b6f7687b79519ee4f71 +size 121563 diff --git a/data/test/regression/fill_mask_veco_zh.bin b/data/test/regression/fill_mask_veco_zh.bin new file mode 100644 index 00000000..a6eb5621 --- /dev/null +++ b/data/test/regression/fill_mask_veco_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8355f27a3235209f206b5e75f4400353e5989e94cf4d71270b42ded8821d536 +size 121563 diff --git a/data/test/regression/sbert-base-tnews.bin b/data/test/regression/sbert-base-tnews.bin new file mode 100644 index 00000000..d2c63ab0 --- /dev/null +++ b/data/test/regression/sbert-base-tnews.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:344ef971bdf310b76c6571d1f4994ab6abc5edc659654d71a4f75b14a30960c2 +size 152926 diff --git a/data/test/regression/sbert_nli.bin b/data/test/regression/sbert_nli.bin new file mode 100644 index 00000000..52e31692 --- /dev/null +++ b/data/test/regression/sbert_nli.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0aeb07b6c9b40a0cfa7492e839431764e9bece93c906833a07c05e83520a399 +size 63161 diff --git a/data/test/regression/sbert_sen_sim.bin b/data/test/regression/sbert_sen_sim.bin new file mode 100644 index 00000000..1c8efb81 --- /dev/null +++ b/data/test/regression/sbert_sen_sim.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7aa5c7a2565ccf0d2eea4baf8adbd0e020dbe36a7159b31156c53141cc9b2df2 +size 63165 diff --git a/data/test/regression/sbert_ws_en.bin b/data/test/regression/sbert_ws_en.bin new file mode 100644 index 00000000..3ad45356 --- /dev/null +++ b/data/test/regression/sbert_ws_en.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc6de82a8485fbfa008f6c2d5411cd07ba03e4a780bcb4e67efc6fba3c6ce92f +size 63597 diff --git a/data/test/regression/sbert_ws_zh.bin b/data/test/regression/sbert_ws_zh.bin new file mode 100644 index 00000000..a85d787f --- /dev/null +++ b/data/test/regression/sbert_ws_zh.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d98ac11a4e9e2744a7402a5cc912da991a41938bbc5dd60f15ee5c6b3196030 +size 63349 diff --git a/data/test/regression/sbert_zero_shot.bin b/data/test/regression/sbert_zero_shot.bin new file mode 100644 index 00000000..04171523 --- /dev/null +++ b/data/test/regression/sbert_zero_shot.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01f9b9bf6f8bbf9bb377d4cb6f399b2e5e065381f5b7332343e0db7b4fae72a5 +size 62519 diff --git a/data/test/videos/Walking.54138969.mp4 b/data/test/videos/Walking.54138969.mp4 new file mode 100644 index 00000000..d4355290 --- /dev/null +++ b/data/test/videos/Walking.54138969.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7663f9a32ea57086bf66c4b9e9ebe0fd418986c67716c7be02ca917e72ddc0ba +size 8155895 diff --git a/data/test/videos/action_detection_test_video.mp4 b/data/test/videos/action_detection_test_video.mp4 new file mode 100644 index 00000000..e2ea1d80 --- /dev/null +++ b/data/test/videos/action_detection_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92 +size 3191059 diff --git a/data/test/videos/action_recognition_test_video.mp4 b/data/test/videos/action_recognition_test_video.mp4 new file mode 100644 index 00000000..9197b770 --- /dev/null +++ b/data/test/videos/action_recognition_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24dc4237b1197321ee8486bb983fa01fd47e2b4afdb3c2df24229e5f2bd20119 +size 1475924 diff --git a/data/test/videos/dog.avi b/data/test/videos/dog.avi new file mode 100644 index 00000000..afcda087 --- /dev/null +++ b/data/test/videos/dog.avi @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:469090fb217a34a2c096cfd42c251da69dca9fcd1a3c1faae7d29183c1816c14 +size 12834294 diff --git a/data/test/videos/live_category_test_video.mp4 b/data/test/videos/live_category_test_video.mp4 new file mode 100644 index 00000000..30529812 --- /dev/null +++ b/data/test/videos/live_category_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c09750586d1693b6c521d98907c3290d78635a2fb33c76db0132cd2b8ef90f0 +size 1019267 diff --git a/data/test/videos/mask_dir/mask_00000_00320.png b/data/test/videos/mask_dir/mask_00000_00320.png new file mode 100644 index 00000000..2eae71a1 --- /dev/null +++ b/data/test/videos/mask_dir/mask_00000_00320.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c +size 2766 diff --git a/data/test/videos/mask_dir/mask_00321_00633.png b/data/test/videos/mask_dir/mask_00321_00633.png new file mode 100644 index 00000000..89633eb6 --- /dev/null +++ b/data/test/videos/mask_dir/mask_00321_00633.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3 +size 2761 diff --git a/data/test/videos/movie_scene_segmentation_test_video.mp4 b/data/test/videos/movie_scene_segmentation_test_video.mp4 new file mode 100644 index 00000000..21ea3cb1 --- /dev/null +++ b/data/test/videos/movie_scene_segmentation_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03002807dc2aa180c3ae104e764c7a4d6c421d186a5d552f97d338467ae6c443 +size 12722029 diff --git a/data/test/videos/multi_modal_test_video_9770.mp4 b/data/test/videos/multi_modal_test_video_9770.mp4 new file mode 100644 index 00000000..45245b52 --- /dev/null +++ b/data/test/videos/multi_modal_test_video_9770.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33e21c16d5388684b61d7251b9d4e418f8146c3ba3fa400ebd8d913058687cfc +size 431888 diff --git a/data/test/videos/referring_video_object_segmentation_test_video.mp4 b/data/test/videos/referring_video_object_segmentation_test_video.mp4 new file mode 100644 index 00000000..529595a5 --- /dev/null +++ b/data/test/videos/referring_video_object_segmentation_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a49c9bc74a60860c360a4bf4509fe9db915279aaabd953f354f2c38e9be1e6cb +size 2924691 diff --git a/data/test/videos/test_realtime_vod.mp4 b/data/test/videos/test_realtime_vod.mp4 new file mode 100644 index 00000000..a0e44852 --- /dev/null +++ b/data/test/videos/test_realtime_vod.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f58df1d25590c158ae0a04b3999bd44b610cdaddb17d78afd84c34b3f00d4e87 +size 4068783 diff --git a/data/test/videos/video_category_test_video.mp4 b/data/test/videos/video_category_test_video.mp4 new file mode 100644 index 00000000..195af371 --- /dev/null +++ b/data/test/videos/video_category_test_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfc935328ecace53338050a6789250e08b9d17a52efa2339b0e133edc1fae9d4 +size 3943349 diff --git a/data/test/videos/video_inpainting_test.mp4 b/data/test/videos/video_inpainting_test.mp4 new file mode 100644 index 00000000..61f96fac --- /dev/null +++ b/data/test/videos/video_inpainting_test.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56 +size 2957783 diff --git a/docker/.dockerignore b/docker/.dockerignore new file mode 100644 index 00000000..14284cb6 --- /dev/null +++ b/docker/.dockerignore @@ -0,0 +1,4 @@ +*.sh +*.md +*.dockerfile +*.zip diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu new file mode 100644 index 00000000..6dafbc3e --- /dev/null +++ b/docker/Dockerfile.ubuntu @@ -0,0 +1,85 @@ +ARG BASE_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.3.0-cudnn8-devel +FROM $BASE_IMAGE +ARG DEBIAN_FRONTEND=noninteractive +ENV TZ=Asia/Shanghai +ENV CONDA_DIR /opt/conda +ENV PATH="${CONDA_DIR}/bin:${PATH}" +ENV arch=x86_64 +SHELL ["/bin/bash", "-c"] +COPY docker/rcfiles /tmp/resources +RUN apt-get update && apt-get install -y --reinstall ca-certificates && \ + cp /tmp/resources/ubuntu20.04_sources.tuna /etc/apt/sources.list && \ + apt-get update && \ + apt-get install -y locales wget git vim ffmpeg libsm6 tzdata language-pack-zh-hans ttf-wqy-microhei ttf-wqy-zenhei xfonts-wqy libxext6 build-essential ninja-build && \ + wget https://packagecloud.io/github/git-lfs/packages/debian/bullseye/git-lfs_3.2.0_amd64.deb/download -O ./git-lfs_3.2.0_amd64.deb && \ + dpkg -i ./git-lfs_3.2.0_amd64.deb && \ + rm -f ./git-lfs_3.2.0_amd64.deb && \ + locale-gen zh_CN && \ + locale-gen zh_CN.utf8 && \ + update-locale LANG=zh_CN.UTF-8 LC_ALL=zh_CN.UTF-8 LANGUAGE=zh_CN.UTF-8 && \ + ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + dpkg-reconfigure --frontend noninteractive tzdata && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV LANG=zh_CN.UTF-8 LANGUAGE=zh_CN.UTF-8 LC_ALL=zh_CN.UTF-8 + +#install and config python +ARG PYTHON_VERSION=3.7.13 +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${arch}.sh -O ./miniconda.sh && \ + /bin/bash miniconda.sh -b -p /opt/conda && \ + rm -f miniconda.sh && \ + ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ + echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ + cp /tmp/resources/conda.tuna ~/.condarc && \ + source /root/.bashrc && \ + conda install --yes python==${PYTHON_VERSION} && \ + pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ + pip config set install.trusted-host pypi.tuna.tsinghua.edu.cn + +ARG USE_GPU=True + +# install pytorch +ARG TORCH_VERSION=1.12.0 +ARG CUDATOOLKIT_VERSION=11.3 +RUN if [ "$USE_GPU" = "True" ] ; then \ + pip install --no-cache-dir torch==$TORCH_VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113; \ + else \ + pip install --no-cache-dir torch==$TORCH_VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu; \ + fi + +# install tensorflow +ARG TENSORFLOW_VERSION=1.15.5 +RUN if [ "$USE_GPU" = "True" ] ; then \ + pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \ + else \ + pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \ + fi + +RUN if [ "$USE_GPU" = "True" ] ; then \ + CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6" MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_CUDA=1 pip install --no-cache-dir mmcv-full && pip cache purge; \ + else \ + MMCV_WITH_OPS=1 MAX_JOBS=8 pip install --no-cache-dir mmcv-full && pip cache purge; \ + fi + +# install modelscope +COPY requirements /var/modelscope +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip install --no-cache-dir -r /var/modelscope/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip install --no-cache-dir -r /var/modelscope/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip install --no-cache-dir -r /var/modelscope/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip install --no-cache-dir -r /var/modelscope/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \ + pip cache purge + +# default shell bash +ENV SHELL=/bin/bash + +# install special package +RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq fasttext https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl + +RUN if [ "$USE_GPU" = "True" ] ; then \ + pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \ + else \ + pip install --no-cache-dir dgl dglgo -f https://data.dgl.ai/wheels/repo.html; \ + fi diff --git a/docker/pytorch.dockerfile b/docker/pytorch.dockerfile new file mode 100644 index 00000000..a1fe5b15 --- /dev/null +++ b/docker/pytorch.dockerfile @@ -0,0 +1,54 @@ +# syntax = docker/dockerfile:experimental +# +# NOTE: To build this you will need a docker version > 18.06 with +# experimental enabled and DOCKER_BUILDKIT=1 +# +# If you do not use buildkit you are not going to have a good time +# +# For reference: +# https://docs.docker.com/develop/develop-images/build_enhancements/ + +# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +# FROM ${BASE_IMAGE} as dev-base + +# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base +FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel +# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime +# config pip source +RUN mkdir /root/.pip +COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf +COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list + +# Install essential Ubuntu packages +RUN apt-get update &&\ + apt-get install -y software-properties-common \ + build-essential \ + git \ + wget \ + vim \ + curl \ + zip \ + zlib1g-dev \ + unzip \ + pkg-config \ + libsndfile1 + +# install modelscope and its python env +WORKDIR /opt/modelscope +COPY . . +RUN pip install -r requirements.txt +# RUN --mount=type=cache,target=/opt/ccache \ +# python setup.py install + +# opencv-python-headless conflict with opencv-python installed +RUN python setup.py install \ + && pip uninstall -y opencv-python-headless + +# prepare modelscope libs +COPY docker/scripts/install_libs.sh /tmp/ +RUN bash /tmp/install_libs.sh && \ + rm -rf /tmp/install_libs.sh + +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/modelscope/lib64 + +WORKDIR /workspace diff --git a/docker/rcfiles/conda.tuna b/docker/rcfiles/conda.tuna new file mode 100644 index 00000000..ce8a2908 --- /dev/null +++ b/docker/rcfiles/conda.tuna @@ -0,0 +1,15 @@ +channels: + - defaults +show_channel_urls: true +default_channels: + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 +custom_channels: + conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + pytorch-lts: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud diff --git a/docker/rcfiles/pip.conf.tsinghua b/docker/rcfiles/pip.conf.tsinghua new file mode 100644 index 00000000..4242075a --- /dev/null +++ b/docker/rcfiles/pip.conf.tsinghua @@ -0,0 +1,2 @@ +[global] +index-url=https://pypi.tuna.tsinghua.edu.cn/simple diff --git a/docker/rcfiles/sources.list.aliyun b/docker/rcfiles/sources.list.aliyun new file mode 100644 index 00000000..120bb1f1 --- /dev/null +++ b/docker/rcfiles/sources.list.aliyun @@ -0,0 +1,25 @@ +deb http://mirrors.aliyun.com/ubuntu/ bionic main restricted +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic main restricted + +deb http://mirrors.aliyun.com/ubuntu/ bionic-updates main restricted +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic-updates main restricted + +deb http://mirrors.aliyun.com/ubuntu/ bionic universe +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic universe +deb http://mirrors.aliyun.com/ubuntu/ bionic-updates universe +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic-updates universe + +deb http://mirrors.aliyun.com/ubuntu/ bionic multiverse +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic multiverse +deb http://mirrors.aliyun.com/ubuntu/ bionic-updates multiverse +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic-updates multiverse + +deb http://mirrors.aliyun.com/ubuntu/ bionic-backports main restricted universe multiverse +# deb-src http://mirrors.aliyun.com/ubuntu/ bionic-backports main restricted universe multiverse + +deb http://mirrors.aliyun.com/ubuntu bionic-security main restricted +# deb-src http://mirrors.aliyun.com/ubuntu bionic-security main restricted +deb http://mirrors.aliyun.com/ubuntu bionic-security universe +# deb-src http://mirrors.aliyun.com/ubuntu bionic-security universe +deb http://mirrors.aliyun.com/ubuntu bionic-security multiverse +# deb-src http://mirrors.aliyun.com/ubuntu bionic-security multiverse diff --git a/docker/rcfiles/ubuntu20.04_sources.tuna b/docker/rcfiles/ubuntu20.04_sources.tuna new file mode 100644 index 00000000..a247bbfa --- /dev/null +++ b/docker/rcfiles/ubuntu20.04_sources.tuna @@ -0,0 +1,13 @@ +# 默认注释了源码镜像以提高 apt update 速度,如有需要可自行取消注释 +deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse +# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse +deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse +# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse +deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse +# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse +deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse +# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse + +# 预发布软件源,不建议启用 +# deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse +# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-proposed main restricted universe multiverse diff --git a/docker/rcfiles/user.vimrc b/docker/rcfiles/user.vimrc new file mode 100644 index 00000000..590aca43 --- /dev/null +++ b/docker/rcfiles/user.vimrc @@ -0,0 +1,10 @@ +set nocompatible +set encoding=utf-8 +set hlsearch +set smartindent +set ruler +set number +set ts=2 +set sw=2 +set expandtab +autocmd FileType make setlocal noexpandtab diff --git a/docker/scripts/install_libs.sh b/docker/scripts/install_libs.sh new file mode 100644 index 00000000..dea0dc19 --- /dev/null +++ b/docker/scripts/install_libs.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -eo pipefail + +ModelScopeLib=/usr/local/modelscope/lib64 + +if [ ! -d /usr/local/modelscope ]; then + mkdir -p $ModelScopeLib +fi + +# audio libs +wget "http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/libs/audio/libmitaec_pyio.so" -O ${ModelScopeLib}/libmitaec_pyio.so diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..a051c6be --- /dev/null +++ b/docs/README.md @@ -0,0 +1,37 @@ +## maintain docs +1. build docs + ```shell + # in root directory: + make docs + ``` + +2. doc string format + + We adopt the google style docstring format as the standard, please refer to the following documents. + 1. Google Python style guide docstring [link](http://google.github.io/styleguide/pyguide.html#381-docstrings) + 2. Google docstring example [link](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) + 3. sample:torch.nn.modules.conv [link](https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d) + 4. load function as an example: + + ```python + def load(file, file_format=None, **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml". + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('oss://path/of/your/file') # file is storaged in petrel + + Returns: + The content from the file. + """ + ``` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..3d64bb3a --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/api/modelscope.fileio.format.rst b/docs/source/api/modelscope.fileio.format.rst new file mode 100644 index 00000000..2c7b11de --- /dev/null +++ b/docs/source/api/modelscope.fileio.format.rst @@ -0,0 +1,34 @@ +modelscope.fileio.format package +================================ + +.. automodule:: modelscope.fileio.format + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.fileio.format.base module +------------------------------------ + +.. automodule:: modelscope.fileio.format.base + :members: + :undoc-members: + :show-inheritance: + +modelscope.fileio.format.json module +------------------------------------ + +.. automodule:: modelscope.fileio.format.json + :members: + :undoc-members: + :show-inheritance: + +modelscope.fileio.format.yaml module +------------------------------------ + +.. automodule:: modelscope.fileio.format.yaml + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.fileio.rst b/docs/source/api/modelscope.fileio.rst new file mode 100644 index 00000000..3f4ae1ca --- /dev/null +++ b/docs/source/api/modelscope.fileio.rst @@ -0,0 +1,34 @@ +modelscope.fileio package +========================= + +.. automodule:: modelscope.fileio + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.fileio.format + +Submodules +---------- + +modelscope.fileio.file module +----------------------------- + +.. automodule:: modelscope.fileio.file + :members: + :undoc-members: + :show-inheritance: + +modelscope.fileio.io module +--------------------------- + +.. automodule:: modelscope.fileio.io + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.hub.rst b/docs/source/api/modelscope.hub.rst new file mode 100644 index 00000000..47d210c2 --- /dev/null +++ b/docs/source/api/modelscope.hub.rst @@ -0,0 +1,50 @@ +modelscope.hub package +========================= + +.. automodule:: modelscope.hub + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.hub.utils + +Submodules +---------- + +modelscope.hub.api module +----------------------------- + +.. automodule:: modelscope.hub.api + :members: + :undoc-members: + :show-inheritance: + +modelscope.hub.git module +--------------------------- + +.. automodule:: modelscope.hub.git + :members: + :undoc-members: + :show-inheritance: + +modelscope.hub.file_download module +--------------------------- + +.. automodule:: modelscope.hub.file_download + :members: + :undoc-members: + :show-inheritance: + +modelscope.hub.snapshot_download module +--------------------------- + +.. automodule:: modelscope.hub.snapshot_download + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.hub.utils.rst b/docs/source/api/modelscope.hub.utils.rst new file mode 100644 index 00000000..74d8ae96 --- /dev/null +++ b/docs/source/api/modelscope.hub.utils.rst @@ -0,0 +1,26 @@ +modelscope.hub.utils package +=============================== + +.. automodule:: modelscope.hub.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.hub.utils.caching module +------------------------------------------------------- + +.. automodule:: modelscope.hub.utils.caching + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.cv.image\_matting\_pipeline module +------------------------------------------------------- + +.. automodule:: modelscope.hub.utils.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.cv.cartoon.facelib.LK.rst b/docs/source/api/modelscope.models.cv.cartoon.facelib.LK.rst new file mode 100644 index 00000000..848c7d67 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.cartoon.facelib.LK.rst @@ -0,0 +1,18 @@ +modelscope.models.cv.cartoon.facelib.LK package +=============================================== + +.. automodule:: modelscope.models.cv.cartoon.facelib.LK + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.models.cv.cartoon.facelib.LK.lk module +------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.facelib.LK.lk + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.cv.cartoon.facelib.rst b/docs/source/api/modelscope.models.cv.cartoon.facelib.rst new file mode 100644 index 00000000..a81536b0 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.cartoon.facelib.rst @@ -0,0 +1,50 @@ +modelscope.models.cv.cartoon.facelib package +============================================ + +.. automodule:: modelscope.models.cv.cartoon.facelib + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.models.cv.cartoon.facelib.LK + +Submodules +---------- + +modelscope.models.cv.cartoon.facelib.config module +-------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.facelib.config + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.cv.cartoon.facelib.face\_detector module +---------------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.facelib.face_detector + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.cv.cartoon.facelib.face\_landmark module +---------------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.facelib.face_landmark + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.cv.cartoon.facelib.facer module +------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.facelib.facer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.rst b/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.rst new file mode 100644 index 00000000..b5845af7 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.rst @@ -0,0 +1,15 @@ +modelscope.models.cv.cartoon.mtcnn\_pytorch package +=================================================== + +.. automodule:: modelscope.models.cv.cartoon.mtcnn_pytorch + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.models.cv.cartoon.mtcnn_pytorch.src diff --git a/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.src.rst b/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.src.rst new file mode 100644 index 00000000..715cc292 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.cartoon.mtcnn_pytorch.src.rst @@ -0,0 +1,26 @@ +modelscope.models.cv.cartoon.mtcnn\_pytorch.src package +======================================================= + +.. automodule:: modelscope.models.cv.cartoon.mtcnn_pytorch.src + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.models.cv.cartoon.mtcnn\_pytorch.src.align\_trans module +------------------------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.cv.cartoon.mtcnn\_pytorch.src.matlab\_cp2tform module +----------------------------------------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.mtcnn_pytorch.src.matlab_cp2tform + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.cv.cartoon.rst b/docs/source/api/modelscope.models.cv.cartoon.rst new file mode 100644 index 00000000..5a262e03 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.cartoon.rst @@ -0,0 +1,27 @@ +modelscope.models.cv.cartoon package +==================================== + +.. automodule:: modelscope.models.cv.cartoon + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.models.cv.cartoon.facelib + modelscope.models.cv.cartoon.mtcnn_pytorch + +Submodules +---------- + +modelscope.models.cv.cartoon.utils module +----------------------------------------- + +.. automodule:: modelscope.models.cv.cartoon.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.cv.rst b/docs/source/api/modelscope.models.cv.rst new file mode 100644 index 00000000..47ce3916 --- /dev/null +++ b/docs/source/api/modelscope.models.cv.rst @@ -0,0 +1,15 @@ +modelscope.models.cv package +============================ + +.. automodule:: modelscope.models.cv + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.models.cv.cartoon diff --git a/docs/source/api/modelscope.models.nlp.rst b/docs/source/api/modelscope.models.nlp.rst new file mode 100644 index 00000000..6cc411d4 --- /dev/null +++ b/docs/source/api/modelscope.models.nlp.rst @@ -0,0 +1,90 @@ +modelscope.models.nlp package +============================= + +.. automodule:: modelscope.models.nlp + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.models.nlp.bert\_for\_sequence\_classification module +------------------------------------------------------------ + +.. automodule:: modelscope.models.nlp.bert_for_sequence_classification + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.palm\_for\_text\_generation module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.palm_for_text_generation + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.csanmt\_for\_translation module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.palm_for_text_generation + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.masked\_language module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.masked_language + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_nil module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_nil + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_sentence\_similarity module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_sentence_similarity + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_sentiment\_classification module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_sentiment_classification + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_sequence\_classification module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_sequence_classification + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_token\_classification module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_token_classification + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.nlp.sbert\_for\_zero\_shot\_classification module +---------------------------------------------------- + +.. automodule:: modelscope.models.nlp.sbert_for_zero_shot_classification + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.models.rst b/docs/source/api/modelscope.models.rst new file mode 100644 index 00000000..2eaa1a6b --- /dev/null +++ b/docs/source/api/modelscope.models.rst @@ -0,0 +1,37 @@ +modelscope.models package +========================= + +.. automodule:: modelscope.models + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.models.cv + modelscope.models.nlp + modelscope.models.multi_modal + modelscope.models.audio + +Submodules +---------- + +modelscope.models.base module +----------------------------- + +.. automodule:: modelscope.models.base + :members: + :undoc-members: + :show-inheritance: + +modelscope.models.builder module +-------------------------------- + +.. automodule:: modelscope.models.builder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.msdatasets.rst b/docs/source/api/modelscope.msdatasets.rst new file mode 100644 index 00000000..53b858a8 --- /dev/null +++ b/docs/source/api/modelscope.msdatasets.rst @@ -0,0 +1,18 @@ +modelscope.msdatasets package +============================= + +.. automodule:: modelscope.msdatasets + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.msdatasets.ms\_dataset module +---------------------------------------- + +.. automodule:: modelscope.msdatasets.ms_dataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.pipelines.audio.rst b/docs/source/api/modelscope.pipelines.audio.rst new file mode 100644 index 00000000..f162893f --- /dev/null +++ b/docs/source/api/modelscope.pipelines.audio.rst @@ -0,0 +1,7 @@ +modelscope.pipelines.audio package +================================== + +.. automodule:: modelscope.pipelines.audio + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.pipelines.cv.rst b/docs/source/api/modelscope.pipelines.cv.rst new file mode 100644 index 00000000..3f2da3f4 --- /dev/null +++ b/docs/source/api/modelscope.pipelines.cv.rst @@ -0,0 +1,26 @@ +modelscope.pipelines.cv package +=============================== + +.. automodule:: modelscope.pipelines.cv + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.pipelines.cv.image\_cartoon\_pipeline module +------------------------------------------------------- + +.. automodule:: modelscope.pipelines.cv.image_cartoon_pipeline + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.cv.image\_matting\_pipeline module +------------------------------------------------------- + +.. automodule:: modelscope.pipelines.cv.image_matting_pipeline + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.pipelines.multi_modal.rst b/docs/source/api/modelscope.pipelines.multi_modal.rst new file mode 100644 index 00000000..4bc3982f --- /dev/null +++ b/docs/source/api/modelscope.pipelines.multi_modal.rst @@ -0,0 +1,42 @@ +modelscope.pipelines.multi\_modal package +========================================= + +.. automodule:: modelscope.pipelines.multi_modal + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.pipelines.multi\_modal.image\_captioning\_pipeline module +---------------------------------------------------------- + +.. automodule:: modelscope.pipelines.multi_modal.image_captioning_pipeline + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.multi\_modal.multi\_modal\_embedding\_pipeline module +---------------------------------------------------------- + +.. automodule:: modelscope.pipelines.multi_modal.multi_modal_embedding_pipeline + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.multi\_modal.text\_to\_image\_synthesis\_pipeline module +---------------------------------------------------------- + +.. automodule:: modelscope.pipelines.multi_modal.text_to_image_synthesis_pipeline + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.multi\_modal.visual\_question\_answering\_pipeline module +---------------------------------------------------------- + +.. automodule:: modelscope.pipelines.multi_modal.visual_question_answering_pipeline + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.pipelines.nlp.rst b/docs/source/api/modelscope.pipelines.nlp.rst new file mode 100644 index 00000000..836d914f --- /dev/null +++ b/docs/source/api/modelscope.pipelines.nlp.rst @@ -0,0 +1,26 @@ +modelscope.pipelines.nlp package +================================ + +.. automodule:: modelscope.pipelines.nlp + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.pipelines.nlp.sequence\_classification\_pipeline module +------------------------------------------------------------------ + +.. automodule:: modelscope.pipelines.nlp.sequence_classification_pipeline + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.nlp.text\_generation\_pipeline module +---------------------------------------------------------- + +.. automodule:: modelscope.pipelines.nlp.text_generation_pipeline + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.pipelines.rst b/docs/source/api/modelscope.pipelines.rst new file mode 100644 index 00000000..e56a9a87 --- /dev/null +++ b/docs/source/api/modelscope.pipelines.rst @@ -0,0 +1,53 @@ +modelscope.pipelines package +============================ + +.. automodule:: modelscope.pipelines + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.pipelines.cv + modelscope.pipelines.nlp + modelscope.pipelines.multi_modal + modelscope.pipelines.audio + +Submodules +---------- + +modelscope.pipelines.builder module +----------------------------------- + +.. automodule:: modelscope.pipelines.builder + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.base module +----------------------------------- + +.. automodule:: modelscope.pipelines.base + :members: + :undoc-members: + :show-inheritance: + +modelscope.outputs module +----------------------------------- + +.. automodule:: modelscope.outputs + :members: + :undoc-members: + :show-inheritance: + +modelscope.pipelines.util module +-------------------------------- + +.. automodule:: modelscope.pipelines.util + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.preprocessors.rst b/docs/source/api/modelscope.preprocessors.rst new file mode 100644 index 00000000..b555198d --- /dev/null +++ b/docs/source/api/modelscope.preprocessors.rst @@ -0,0 +1,50 @@ +modelscope.preprocessors package +================================ + +.. automodule:: modelscope.preprocessors + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.preprocessors.base module +------------------------------------ + +.. automodule:: modelscope.preprocessors.base + :members: + :undoc-members: + :show-inheritance: + +modelscope.preprocessors.builder module +--------------------------------------- + +.. automodule:: modelscope.preprocessors.builder + :members: + :undoc-members: + :show-inheritance: + +modelscope.preprocessors.common module +-------------------------------------- + +.. automodule:: modelscope.preprocessors.common + :members: + :undoc-members: + :show-inheritance: + +modelscope.preprocessors.image module +------------------------------------- + +.. automodule:: modelscope.preprocessors.image + :members: + :undoc-members: + :show-inheritance: + +modelscope.preprocessors.nlp module +----------------------------------- + +.. automodule:: modelscope.preprocessors.nlp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.rst b/docs/source/api/modelscope.rst new file mode 100644 index 00000000..d38654a4 --- /dev/null +++ b/docs/source/api/modelscope.rst @@ -0,0 +1,33 @@ +modelscope package +================== + +.. automodule:: modelscope + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.fileio + modelscope.models + modelscope.pipelines + modelscope.preprocessors + modelscope.msdatasets + modelscope.trainers + modelscope.utils + modelscope.hub + +Submodules +---------- + +modelscope.version module +------------------------- + +.. automodule:: modelscope.version + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.trainers.nlp.rst b/docs/source/api/modelscope.trainers.nlp.rst new file mode 100644 index 00000000..4bc2f875 --- /dev/null +++ b/docs/source/api/modelscope.trainers.nlp.rst @@ -0,0 +1,18 @@ +modelscope.trainers.nlp package +=============================== + +.. automodule:: modelscope.trainers.nlp + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.trainers.nlp.sequence\_classification\_trainer module +---------------------------------------------------------------- + +.. automodule:: modelscope.trainers.nlp.sequence_classification_trainer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.trainers.rst b/docs/source/api/modelscope.trainers.rst new file mode 100644 index 00000000..aac4fb99 --- /dev/null +++ b/docs/source/api/modelscope.trainers.rst @@ -0,0 +1,34 @@ +modelscope.trainers package +=========================== + +.. automodule:: modelscope.trainers + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + modelscope.trainers.nlp + +Submodules +---------- + +modelscope.trainers.base module +------------------------------- + +.. automodule:: modelscope.trainers.base + :members: + :undoc-members: + :show-inheritance: + +modelscope.trainers.builder module +---------------------------------- + +.. automodule:: modelscope.trainers.builder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/modelscope.utils.rst b/docs/source/api/modelscope.utils.rst new file mode 100644 index 00000000..3d705cfb --- /dev/null +++ b/docs/source/api/modelscope.utils.rst @@ -0,0 +1,58 @@ +modelscope.utils package +======================== + +.. automodule:: modelscope.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +modelscope.utils.config module +------------------------------ + +.. automodule:: modelscope.utils.config + :members: + :undoc-members: + :show-inheritance: + +modelscope.utils.constant module +-------------------------------- + +.. automodule:: modelscope.utils.constant + :members: + :undoc-members: + :show-inheritance: + +modelscope.utils.hub module +--------------------------- + +.. automodule:: modelscope.utils.hub + :members: + :undoc-members: + :show-inheritance: + +modelscope.utils.logger module +------------------------------ + +.. automodule:: modelscope.utils.logger + :members: + :undoc-members: + :show-inheritance: + +modelscope.utils.registry module +-------------------------------- + +.. automodule:: modelscope.utils.registry + :members: + :undoc-members: + :show-inheritance: + +modelscope.utils.type\_assert module +------------------------------------ + +.. automodule:: modelscope.utils.type_assert + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/change_log.md b/docs/source/change_log.md new file mode 100644 index 00000000..1081c148 --- /dev/null +++ b/docs/source/change_log.md @@ -0,0 +1,70 @@ +## v 0.2.2 (05/07/2022) +Second internal release. + +### Highlights + +### Algorithms +#### CV +* add cv-person-image-cartoon pipeline +* add action recognition pipeline +* add ocr detection pipeline +* add animal recognition model +* add cmdssl video embedding extraction pipeline + +#### NLP +* add speech AEC pipeline +* add palm2.0 +* add space model +* add MPLUG model +* add dialog_intent, dialog_modeling, dialog state tracking pipleline +* add maskedlm model and fill_mask pipeline +* add nli pipeline +* add sentence similarity pipeline +* add sentiment_classification pipeline +* add text generation pipeline +* add translation pipeline +* add chinese word segmentation pipeline +* add zero-shot classification + +#### Audio +* add tts pipeline +* add kws kwsbp pipline +* add linear aec pipeline +* add ans pipeline + +#### Multi-Modal +* add image captioning pipeline +* add multi-modal feature extraction pipeline +* add text to image synthesis pipeline +* add VQA pipeline + +### Framework +* add msdataset interface +* add hub interface and cache support +* support multiple models in single pipeline +* add default model configuration for each pipeline +* remove task field image and video, using cv instead +* dockerfile support +* multi-level tests support +* sphinx-docs use book theme +* formalize the output of pipeline and make pipeline reusable +* pipeline refactor and standardize module_name +* self-host repo support + +### Bug Fix +* support kwargs in pipeline +* fix errors in task name definition + +## v 0.1.0 (20/05/2022) + +First internal release for pipeline inference + +* provide basic modules including fileio, logging +* config file parser +* module registry and build, which support group management +* add modules including preprocessor, model and pipeline +* image loading and nlp tokenize support in preprocessor +* add two pipeline: image-matting pipeline and text-classification pipeline +* add task constants according to PRD +* citest support +* makefile and scripts which support packaging whl, build docs, unittest diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..39e0d881 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,104 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +import sphinx_book_theme + +sys.path.insert(0, os.path.abspath('../../')) +# -- Project information ----------------------------------------------------- + +project = 'modelscope' +copyright = '2022-2023, Alibaba ModelScope' +author = 'modelscope Authors' +version_file = '../../modelscope/version.py' + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +# The full version, including alpha/beta/rc tags +version = get_version() +release = version + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'myst_parser', + 'sphinx_markdown_tables', + 'sphinx_copybutton', +] + +autodoc_mock_imports = [ + 'matplotlib', 'pycocotools', 'terminaltables', 'mmcv.ops' +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +# The master toctree document. +master_doc = 'index' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['build', 'Thumbs.db', '.DS_Store'] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_book_theme' +html_theme_path = [sphinx_book_theme.get_html_theme_path()] +html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +# html_css_files = ['css/readthedocs.css'] + +# -- Options for HTMLHelp output --------------------------------------------- +# Output file base name for HTML help builder. +htmlhelp_basename = 'modelscope_doc' + +# -- Extension configuration ------------------------------------------------- +# Ignore >>> when copying code +copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_is_regexp = True + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'https://docs.python.org/': None} + +autodoc_default_options = { + 'member-order': 'bysource', + 'special-members': '__init__', +} diff --git a/docs/source/develop.md b/docs/source/develop.md new file mode 100644 index 00000000..62801353 --- /dev/null +++ b/docs/source/develop.md @@ -0,0 +1,167 @@ +# Develop + +## 1. Code Style +We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. + +We use the following toolsseed isortseed isortseed isort for linting and formatting: +- [flake8](http://flake8.pycqa.org/en/latest/): linter +- [yapf](https://github.com/google/yapf): formatter +- [isort](https://github.com/timothycrosley/isort): sort imports + +Style configurations of yapf and isort can be found in [setup.cfg](../../setup.cfg). +We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `seed-isort-config`, `isort`, `trailing whitespaces`, +fixes `end-of-files`, sorts `requirments.txt` automatically on every commit. +The config for a pre-commit hook is stored in [.pre-commit-config](../../.pre-commit-config.yaml). +After you clone the repository, you will need to install initialize pre-commit hook. +```bash +pip install -r requirements/tests.txt +``` +From the repository folder +```bash +pre-commit install +``` + +After this on every commit check code linters and formatter will be enforced. + +If you want to use pre-commit to check all the files, you can run +```bash +pre-commit run --all-files +``` + +If you only want to format and lint your code, you can run +```bash +make linter +``` + +## 2. Test + +### 2.1 Test level + +There are mainly three test levels: + +* level 0: tests for basic interface and function of framework, such as `tests/trainers/test_trainer_base.py` +* level 1: important functional test which test end2end workflow, such as `tests/pipelines/test_image_matting.py` +* level 2: scenario tests for all the implemented modules such as model, pipeline in different algorithm filed. + +Default test level is 0, which will only run those cases of level 0, you can set test level +via environment variable `TEST_LEVEL`. + + +```bash +# run all tests +TEST_LEVEL=2 make test + +# run important functional tests +TEST_LEVEL=1 make test + +# run core UT and basic functional tests +make test +``` + +When writing test cases, you should assign a test level for your test case using +following code. If left default, the test level will be 0, it will run in each +test stage. + +File test_module.py +```python +from modelscope.utils.test_utils import test_level + +class ImageCartoonTest(unittest.TestCase): + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_by_direct_model_download(self): + pass +``` + +### 2.2 Run tests + +1. Run your own single test case to test your self-implemented function. You can run your +test file directly, if it fails to run, pls check if variable `TEST_LEVEL` +exists in the environment and unset it. +```bash +python tests/path/to/your_test.py +``` + +2. Remember to run core tests in local environment before start a codereview, by default it will +only run test cases with level 0. +```bash +make tests +``` + +3. After you start a code review, ci tests will be triggered which will run test cases with level 1 + +4. Daily regression tests will run all cases at 0 am each day using master branch. + +### 2.3 Test data storage + +As we need a lot of data for testing, including images, videos, models. We use git lfs +to store those large files. + +1. install git-lfs(version>=2.5.0) +for mac +```bash +brew install git-lfs +git lfs install +``` + +for centos, please download rpm from git-lfs github release [website](https://github.com/git-lfs/git-lfs/releases/tag/v3.2.0) +```bash +wget http://101374-public.oss-cn-hangzhou-zmf.aliyuncs.com/git-lfs-3.2.0-1.el7.x86_64.rpm +sudo rpm -ivh git-lfs-3.2.0-1.el7.x86_64.rpm +git lfs install +``` + +for ubuntu +```bash +curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash +sudo apt-get install git-lfs +git lfs install +``` + +2. track your data type using git lfs, for example, to track png files +```bash +git lfs track "*.png" +``` + +3. add your test files to `data/test/` folder, you can make directories if you need. +```bash +git add data/test/test.png +``` + +4. commit your test data to remote branch +```bash +git commit -m "xxx" +``` + +To pull data from remote repo, just as the same way you pull git files. +```bash +git pull origin branch_name +``` + + + + +## Development and Code Review +1. Get the latest master code and checkout a new branch for local development. + ```shell + git pull origin master --rebase + git checout -b dev/my-dev-branch + ``` + note: replace "dev/my-dev-branch" with a meaningful branch name. We recommend using a new dev branch for every change. +2. Make your local changes. +3. Commit your local changes. + ```shell + git add . + git commit -m "[to #42322933] my commit message" + ``` + note: you may replace [to #42322933] with your own aone issue id (if any). +4. Push your change: + ```shell + git push --set-upstream origin dev/my-dev-branch + ``` + Note that you may push multiple times to the same branch with 'git push' commands later. +5. Create a pull request on github to merge your code into master. + +## Build pip package +```bash +make whl +``` diff --git a/docs/source/faq.md b/docs/source/faq.md new file mode 100644 index 00000000..f4881c5e --- /dev/null +++ b/docs/source/faq.md @@ -0,0 +1,48 @@ +# 常见问题 + + + +### 1. macOS环境pip方式安装tokenizers报错 + +对于tokenizers库, pypi上缺乏针对`macOS`环境预编译包,需要搭建源码编译环境后才能正确安装,步骤如下: + +1. 安装rust + ```shell + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + pip install setuptools_rust + + ``` + +2. 更新rust环境变量 + + ```shell + source $HOME/.cargo/env + ``` +3. 安装tokenziers + ```shell + pip install tokenziers + ``` +reference: [https://huggingface.co/docs/tokenizers/installation#installation-from-sources](https://huggingface.co/docs/tokenizers/installation#installation-from-sources) + +### 2. pip 安装包冲突 + +> ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. + +由于依赖库之间的版本不兼容,可能会存在版本冲突的情况,大部分情况下不影响正常运行。 + +### 3. 安装pytorch出现版本错误 + +> ERROR: Ignored the following versions that require a different python version: 1.1.0 Requires-Python >=3.8; 1.1.0rc1 Requires-Python >=3.8; 1.1.1 Requires-Python >=3.8 +> ERROR: Could not find a version that satisfies the requirement torch==1.8.1+cu111 (from versions: 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0) +> ERROR: No matching distribution found for torch==1.8.1+cu111 + +安装时使用如下命令: + +```shell +pip install -f https://download.pytorch.org/whl/torch_stable.html -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt +``` +### 4. zsh: no matches found: modelscope-0.2.2-py3-none-any.whl[all] +mac终端的zsh 对于[]需要做转义,执行如下命令 +```shell +pip install modelscope\[all\] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..aba54341 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,51 @@ +.. modelscope documentation file, + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +ModelScope DOCUMENTATION +======================================= + +ModelScope doc + +.. toctree:: + :maxdepth: 2 + :caption: USER GUIDE + + quick_start.md + develop.md + faq.md + +.. toctree:: + :maxdepth: 2 + :caption: Tutorials + + tutorials/index + + + +.. toctree:: + :maxdepth: 2 + :caption: Changelog + + change_log.md + +.. toctree:: +.. :maxdepth: 10 +.. :caption: API Doc + +.. api/modelscope.preprocessors +.. api/modelscope.models +.. api/modelscope.pipelines +.. api/modelscope.fileio +.. api/modelscope.utils +.. api/modelscope.hub +.. api/modelscope.msdatasets +.. api/modelscope.tools +.. api/modelscope.trainers + + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md new file mode 100644 index 00000000..7cefa048 --- /dev/null +++ b/docs/source/quick_start.md @@ -0,0 +1,118 @@ +ModelScope Library目前支持tensorflow,pytorch深度学习框架进行模型训练、推理, 在Python 3.7+, Pytorch 1.8+, Tensorflow1.15,Tensorflow 2.x上测试可运行。 + +**注: **`**语音相关**`**的功能仅支持 python3.7,tensorflow1.15的**`**linux**`**环境使用。 其他功能可以在windows、mac上安装使用。** + +## python环境配置 + +首先,参考[文档](https://docs.anaconda.com/anaconda/install/) 安装配置Anaconda环境。 +安装完成后,执行如下命令为modelscope library创建对应的python环境。 + +```shell +conda create -n modelscope python=3.7 +conda activate modelscope +``` + +## 安装深度学习框架 + +- 安装pytorch[参考链接](https://pytorch.org/get-started/locally/)。 + +```shell +pip3 install torch torchvision torchaudio +``` + +- 安装Tensorflow[参考链接](https://www.tensorflow.org/install/pip)。 + +```shell +pip install --upgrade tensorflow +``` + +## ModelScope library 安装 + +注: 如果在安装过程中遇到错误,请前往[常见问题](faq.md)查找解决方案。 + +### pip安装 +执行如下命令可以安装所有领域依赖: +```shell +pip install "modelscope[cv,nlp,audio,multi-modal]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验`语音功能`,请执行如下命令: +```shell +pip install "modelscope[audio]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验CV功能,可执行如下命令安装依赖: +```shell +pip install "modelscope[cv]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验NLP功能,可执行如下命令安装依赖: +```shell +pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验多模态功能,可执行如下命令安装依赖: +```shell +pip install "modelscope[multi-modal]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` +**注**: + +1. `**语音相关**`**的功能仅支持 python3.7,tensorflow1.15的**`**linux**`**环境使用。 其他功能可以在windows、mac上安装使用。** + +2. 语音领域中一部分模型使用了三方库SoundFile进行wav文件处理,**在Linux系统上用户需要手动安装SoundFile的底层依赖库libsndfile**,在Windows和MacOS上会自动安装不需要用户操作。详细信息可参考[SoundFile官网](https://github.com/bastibe/python-soundfile#installation)。以Ubuntu系统为>例,用户需要执行如下命令: + + ```shell + sudo apt-get update + sudo apt-get install libsndfile1 + ``` + +3. **CV功能使用需要安装mmcv-full, 请参考mmcv**[**安装手册**](https://github.com/open-mmlab/mmcv#installation)**进行安装** + +### 使用源码安装 + +适合本地开发调试使用,修改源码后可以直接执行。 +ModelScope的源码可以直接clone到本地: + +```shell +git clone git@github.com:modelscope/modelscope.git +cd modelscope +git fetch origin master +git checkout master + +``` + + +安装依赖 +如需安装所有依赖,请执行如下命令 +```shell +pip install -e ".[cv,nlp,audio,multi-modal]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + + + +如需体验`语音功能`,请单独执行如下命令: +```shell +pip install -e ".[audio]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验CV功能,可执行如下命令安装依赖: +```shell +pip install -e ".[cv]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` +如仅需体验NLP功能,可执行如下命令安装依赖: +```shell +pip install -e ".[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +如仅需体验多模态功能,可执行如下命令安装依赖: +```shell +pip install -e ".[multi-modal]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html +``` + +### 安装验证 + +安装成功后,可以执行如下命令进行验证安装是否正确: + +```shell +python -c "from modelscope.pipelines import pipeline;print(pipeline('word-segmentation')('今天天气不错,适合 出去游玩'))" +``` diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst new file mode 100644 index 00000000..9d8528c2 --- /dev/null +++ b/docs/source/tutorials/index.rst @@ -0,0 +1,6 @@ +.. toctree:: + :maxdepth: 2 + :caption: Tutorials + + pipeline.md + trainer.md diff --git a/docs/source/tutorials/pipeline.md b/docs/source/tutorials/pipeline.md new file mode 100644 index 00000000..ebdc06f3 --- /dev/null +++ b/docs/source/tutorials/pipeline.md @@ -0,0 +1,61 @@ +# Pipeline使用教程 +本文简单介绍如何使用`pipeline`函数加载模型进行推理。`pipeline`函数支持按照任务类型、模型名称从模型仓库拉取模型进行进行推理,包含以下几个方面: +* 使用pipeline()函数进行推理 +* 指定特定预处理、特定模型进行推理 +* 不同场景推理任务示例 +## 环境准备 +详细步骤可以参考 [快速开始](../quick_start.md) +## Pipeline基本用法 +下面以中文分词任务为例,说明pipeline函数的基本用法 + +1. pipeline函数支持指定特定任务名称,加载任务默认模型,创建对应pipeline对象 + 执行如下python代码 + ```python + from modelscope.pipelines import pipeline + word_segmentation = pipeline('word-segmentation') + ``` + +2. 输入文本 + ``` python + input = '今天天气不错,适合出去游玩' + print(word_segmentation(input)) + {'output': '今天 天气 不错 , 适合 出去 游玩'} + ``` + +3. 输入多条样本 + +pipeline对象也支持传入多个样本列表输入,返回对应输出列表,每个元素对应输入样本的返回结果 + + ```python + inputs = ['今天天气不错,适合出去游玩','这本书很好,建议你看看'] + print(word_segmentation(inputs)) + [{'output': '今天 天气 不错 , 适合 出去 游玩'}, {'output': '这 本 书 很 好 , 建议 你 看看'}] + ``` +## 指定预处理、模型进行推理 +pipeline函数支持传入实例化的预处理对象、模型对象,从而支持用户在推理过程中定制化预处理、模型。 + +1. 首先,创建预处理方法和模型 +```python +from modelscope.models import Model +from modelscope.preprocessors import TokenClassificationPreprocessor +model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') +tokenizer = TokenClassificationPreprocessor(model.model_dir) +``` + +2. 使用tokenizer和模型对象创建pipeline +```python +from modelscope.pipelines import pipeline +word_seg = pipeline('word-segmentation', model=model, preprocessor=tokenizer) +input = '今天天气不错,适合出去游玩' +print(word_seg(input)) +{'output': '今天 天气 不错 , 适合 出去 游玩'} +``` +## 不同场景任务推理示例 +下面以一个图像任务:人像抠图('image-matting')为例,进一步说明pipeline的用法 +```python +import cv2 +from modelscope.pipelines import pipeline +img_matting = pipeline('image-matting') +result = img_matting('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matting.png') +cv2.imwrite('result.png', result['output_png']) +``` diff --git a/docs/source/tutorials/trainer.md b/docs/source/tutorials/trainer.md new file mode 100644 index 00000000..1dfdb9cf --- /dev/null +++ b/docs/source/tutorials/trainer.md @@ -0,0 +1,54 @@ +# Trainer使用教程 +Modelscope提供了众多预训练模型,你可以使用其中任意一个,利用公开数据集或者私有数据集针对特定任务进行模型训练,在本篇文章中将介绍如何使用Modelscope的`Trainer`模块进行Finetuning和评估。 + +## 环境准备 +详细步骤可以参考 [快速开始](../quick_start.md) + +### 准备数据集 + +在开始Finetuning前,需要准备一个数据集用以训练和评估,详细可以参考数据集使用教程。 + +```python +from datasets import Dataset +train_dataset = MsDataset.load'afqmc_small', namespace='modelscope', split='train') +eval_dataset = MsDataset.load('afqmc_small', namespace='modelscope', split='validation') +``` +### 训练 +ModelScope把所有训练相关的配置信息全部放到了模型仓库下的`configuration.json`中,因此我们只需要创建Trainer,加载配置文件,传入数据集即可完成训练。 + +首先,通过工厂方法创建Trainer, 需要传入模型仓库路径, 训练数据集对象,评估数据集对象,训练目录 +```python +kwargs = dict( + model='damo/nlp_structbert_sentiment-classification_chinese-base', + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir='work_dir') + +trainer = build_trainer(default_args=kwargs) +``` + +启动训练。 +```python +trainer.train() +``` + +如果需要调整训练参数,可以在模型仓库页面下载`configuration.json`文件到本地,修改参数后,指定配置文件路径,创建trainer +```python +kwargs = dict( + model='damo/nlp_structbert_sentiment-classification_chinese-base', + train_dataset=train_dataset, + eval_dataset=eval_dataset, + cfg_file='你的配置文件路径' + work_dir='work_dir') + +trainer = build_trainer(default_args=kwargs) +trainer.train() +``` + + +### 评估 +训练过程中会定期使用验证集进行评估测试, Trainer模块也支持指定特定轮次保存的checkpoint路径,进行单次评估。 +```python +eval_results = trainer.evaluate('work_dir/epoch_10.pth') +print(eval_results) +``` diff --git a/modelscope/__init__.py b/modelscope/__init__.py new file mode 100644 index 00000000..81fdf505 --- /dev/null +++ b/modelscope/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .version import __release_datetime__, __version__ + +__all__ = ['__version__', '__release_datetime__'] diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py new file mode 100644 index 00000000..a597114f --- /dev/null +++ b/modelscope/exporters/__init__.py @@ -0,0 +1,4 @@ +from .base import Exporter +from .builder import build_exporter +from .nlp import SbertForSequenceClassificationExporter +from .torch_model_exporter import TorchModelExporter diff --git a/modelscope/exporters/base.py b/modelscope/exporters/base.py new file mode 100644 index 00000000..c8b7900e --- /dev/null +++ b/modelscope/exporters/base.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from abc import ABC, abstractmethod + +from modelscope.models import Model +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from .builder import build_exporter + + +class Exporter(ABC): + """Exporter base class to output model to onnx, torch_script, graphdef, etc. + """ + + def __init__(self): + self.model = None + + @classmethod + def from_model(cls, model: Model, **kwargs): + """Build the Exporter instance. + + Args: + model: A Model instance. it will be used to generate the intermediate format file, + and the configuration.json in its model_dir field will be used to create the exporter instance. + kwargs: Extra kwargs used to create the Exporter instance. + + Returns: + The Exporter instance + """ + cfg = Config.from_file( + os.path.join(model.model_dir, ModelFile.CONFIGURATION)) + task_name = cfg.task + model_cfg = cfg.model + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + export_cfg = ConfigDict({'type': model_cfg.type}) + if hasattr(cfg, 'export'): + export_cfg.update(cfg.export) + exporter = build_exporter(export_cfg, task_name, kwargs) + exporter.model = model + return exporter + + @abstractmethod + def export_onnx(self, outputs: str, opset=11, **kwargs): + """Export the model as onnx format files. + + In some cases, several files may be generated, + So please return a dict which contains the generated name with the file path. + + Args: + opset: The version of the ONNX operator set to use. + outputs: The output dir. + kwargs: In this default implementation, + kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape). + + Returns: + A dict contains the model name with the model file path. + """ + pass diff --git a/modelscope/exporters/builder.py b/modelscope/exporters/builder.py new file mode 100644 index 00000000..90699c12 --- /dev/null +++ b/modelscope/exporters/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import ConfigDict +from modelscope.utils.registry import Registry, build_from_cfg + +EXPORTERS = Registry('exporters') + + +def build_exporter(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build exporter by the given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for exporter object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, EXPORTERS, group_key=task_name, default_args=default_args) diff --git a/modelscope/exporters/nlp/__init__.py b/modelscope/exporters/nlp/__init__.py new file mode 100644 index 00000000..fdfd2711 --- /dev/null +++ b/modelscope/exporters/nlp/__init__.py @@ -0,0 +1,2 @@ +from .sbert_for_sequence_classification_exporter import \ + SbertForSequenceClassificationExporter diff --git a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py new file mode 100644 index 00000000..7cee331b --- /dev/null +++ b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py @@ -0,0 +1,86 @@ +import os +from collections import OrderedDict +from typing import Any, Dict, Mapping, Tuple + +from torch.utils.data.dataloader import default_collate + +from modelscope.exporters.builder import EXPORTERS +from modelscope.exporters.torch_model_exporter import TorchModelExporter +from modelscope.metainfo import Models +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.config import Config +from modelscope.utils.constant import ModeKeys, Tasks + + +@EXPORTERS.register_module( + Tasks.sentence_similarity, module_name=Models.structbert) +@EXPORTERS.register_module( + Tasks.sentiment_classification, module_name=Models.structbert) +@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert) +@EXPORTERS.register_module( + Tasks.zero_shot_classification, module_name=Models.structbert) +class SbertForSequenceClassificationExporter(TorchModelExporter): + + def generate_dummy_inputs(self, + shape: Tuple = None, + pair: bool = False, + **kwargs) -> Dict[str, Any]: + """Generate dummy inputs for model exportation to onnx or other formats by tracing. + + Args: + shape: A tuple of input shape which should have at most two dimensions. + shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. + shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. + pair(bool, `optional`): Whether to generate sentence pairs or single sentences. + + Returns: + Dummy inputs. + """ + + cfg = Config.from_file( + os.path.join(self.model.model_dir, 'configuration.json')) + field_name = Tasks.find_field_by_task(cfg.task) + if 'type' not in cfg.preprocessor and 'val' in cfg.preprocessor: + cfg = cfg.preprocessor.val + else: + cfg = cfg.preprocessor + + batch_size = 1 + sequence_length = {} + if shape is not None: + if len(shape) == 1: + batch_size = shape[0] + elif len(shape) == 2: + batch_size, max_length = shape + sequence_length = {'sequence_length': max_length} + + cfg.update({ + 'model_dir': self.model.model_dir, + 'mode': ModeKeys.TRAIN, + **sequence_length + }) + preprocessor: Preprocessor = build_preprocessor(cfg, field_name) + if pair: + first_sequence = preprocessor.tokenizer.unk_token + second_sequence = preprocessor.tokenizer.unk_token + else: + first_sequence = preprocessor.tokenizer.unk_token + second_sequence = None + + batched = [] + for _ in range(batch_size): + batched.append(preprocessor((first_sequence, second_sequence))) + return default_collate(batched) + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + dynamic_axis = {0: 'batch', 1: 'sequence'} + return OrderedDict([ + ('input_ids', dynamic_axis), + ('attention_mask', dynamic_axis), + ('token_type_ids', dynamic_axis), + ]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict({'logits': {0: 'batch'}}) diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py new file mode 100644 index 00000000..1d332591 --- /dev/null +++ b/modelscope/exporters/torch_model_exporter.py @@ -0,0 +1,345 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from contextlib import contextmanager +from itertools import chain +from typing import Any, Dict, Mapping + +import torch +from torch import nn +from torch.onnx import export as onnx_export + +from modelscope.models import TorchModel +from modelscope.outputs import ModelOutputBase +from modelscope.pipelines.base import collate_fn +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.regress_test_utils import (compare_arguments_nested, + numpify_tensor_nested) +from .base import Exporter + +logger = get_logger(__name__) + + +class TorchModelExporter(Exporter): + """The torch base class of exporter. + + This class provides the default implementations for exporting onnx and torch script. + Each specific model may implement its own exporter by overriding the export_onnx/export_torch_script, + and to provide implementations for generate_dummy_inputs/inputs/outputs methods. + """ + + def export_onnx(self, output_dir: str, opset=13, **kwargs): + """Export the model as onnx format files. + + In some cases, several files may be generated, + So please return a dict which contains the generated name with the file path. + + Args: + opset: The version of the ONNX operator set to use. + output_dir: The output dir. + kwargs: + model: A model instance which will replace the exporting of self.model. + In this default implementation, + you can pass the arguments needed by _torch_export_onnx, other unrecognized args + will be carried to generate_dummy_inputs as extra arguments (such as input shape). + + Returns: + A dict containing the model key - model file path pairs. + """ + model = self.model if 'model' not in kwargs else kwargs.pop('model') + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + model = model.model + onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) + self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) + return {'model': onnx_file} + + def export_torch_script(self, output_dir: str, **kwargs): + """Export the model as torch script files. + + In some cases, several files may be generated, + So please return a dict which contains the generated name with the file path. + + Args: + output_dir: The output dir. + kwargs: + model: A model instance which will replace the exporting of self.model. + In this default implementation, + you can pass the arguments needed by _torch_export_torch_script, other unrecognized args + will be carried to generate_dummy_inputs as extra arguments (like input shape). + + Returns: + A dict contains the model name with the model file path. + """ + model = self.model if 'model' not in kwargs else kwargs.pop('model') + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + model = model.model + ts_file = os.path.join(output_dir, ModelFile.TS_MODEL_FILE) + # generate ts by tracing + self._torch_export_torch_script(model, ts_file, **kwargs) + return {'model': ts_file} + + def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: + """Generate dummy inputs for model exportation to onnx or other formats by tracing. + + Returns: + Dummy inputs. + """ + return None + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + """Return an ordered dict contains the model's input arguments name with their dynamic axis. + + About the information of dynamic axis please check the dynamic_axes argument of torch.onnx.export function + """ + return None + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + """Return an ordered dict contains the model's output arguments name with their dynamic axis. + + About the information of dynamic axis please check the dynamic_axes argument of torch.onnx.export function + """ + return None + + @staticmethod + def _decide_input_format(model, args): + import inspect + + def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, 'forward', model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError('model has no forward method and is not callable') + + try: + sig = _signature(model) + except ValueError as e: + logger.warn('%s, skipping _decide_input_format' % e) + return args + try: + ordered_list_keys = list(sig.parameters.keys()) + if ordered_list_keys[0] == 'self': + ordered_list_keys = ordered_list_keys[1:] + args_dict: Dict = {} + if isinstance(args, list): + args_list = args + elif isinstance(args, tuple): + args_list = list(args) + else: + args_list = [args] + if isinstance(args_list[-1], Mapping): + args_dict = args_list[-1] + args_list = args_list[:-1] + n_nonkeyword = len(args_list) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args_list.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default != param.empty: + args_list.append(param.default) + args = args_list if isinstance(args, list) else tuple(args_list) + # Cases of models with no input args + except IndexError: + logger.warn('No input args, skipping _decide_input_format') + except Exception as e: + logger.warn('Skipping _decide_input_format\n {}'.format(e.args[0])) + + return args + + def _torch_export_onnx(self, + model: nn.Module, + output: str, + opset: int = 13, + device: str = 'cpu', + validation: bool = True, + rtol: float = None, + atol: float = None, + **kwargs): + """Export the model to an onnx format file. + + Args: + model: A torch.nn.Module instance to export. + output: The output file. + opset: The version of the ONNX operator set to use. + device: The device used to forward. + validation: Whether validate the export file. + rtol: The rtol used to regress the outputs. + atol: The atol used to regress the outputs. + kwargs: + dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). + inputs: An inputs structure which will replace the calling of self.inputs. + outputs: An outputs structure which will replace the calling of self.outputs. + """ + + dummy_inputs = self.generate_dummy_inputs( + **kwargs) if 'dummy_inputs' not in kwargs else kwargs.pop( + 'dummy_inputs') + inputs = self.inputs if 'inputs' not in kwargs else kwargs.pop( + 'inputs') + outputs = self.outputs if 'outputs' not in kwargs else kwargs.pop( + 'outputs') + if dummy_inputs is None or inputs is None or outputs is None: + raise NotImplementedError( + 'Model property dummy_inputs,inputs,outputs must be set.') + + with torch.no_grad(): + model.eval() + device = torch.device(device) + model.to(device) + dummy_inputs = collate_fn(dummy_inputs, device) + + if isinstance(dummy_inputs, Mapping): + dummy_inputs = dict(dummy_inputs) + onnx_outputs = list(outputs.keys()) + + with replace_call(): + onnx_export( + model, + (dummy_inputs, ), + f=output, + input_names=list(inputs.keys()), + output_names=onnx_outputs, + dynamic_axes={ + name: axes + for name, axes in chain(inputs.items(), + outputs.items()) + }, + do_constant_folding=True, + opset_version=opset, + ) + + if validation: + try: + import onnx + import onnxruntime as ort + except ImportError: + logger.warn( + 'Cannot validate the exported onnx file, because ' + 'the installation of onnx or onnxruntime cannot be found') + return + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model) + ort_session = ort.InferenceSession(output) + with torch.no_grad(): + model.eval() + outputs_origin = model.forward( + *self._decide_input_format(model, dummy_inputs)) + if isinstance(outputs_origin, (Mapping, ModelOutputBase)): + outputs_origin = list( + numpify_tensor_nested(outputs_origin).values()) + elif isinstance(outputs_origin, (tuple, list)): + outputs_origin = list(numpify_tensor_nested(outputs_origin)) + outputs = ort_session.run( + onnx_outputs, + numpify_tensor_nested(dummy_inputs), + ) + outputs = numpify_tensor_nested(outputs) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + elif isinstance(outputs, tuple): + outputs = list(outputs) + + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested('Onnx model output match failed', + outputs, outputs_origin, **tols): + raise RuntimeError( + 'export onnx failed because of validation error.') + + def _torch_export_torch_script(self, + model: nn.Module, + output: str, + device: str = 'cpu', + validation: bool = True, + rtol: float = None, + atol: float = None, + strict: bool = True, + **kwargs): + """Export the model to a torch script file. + + Args: + model: A torch.nn.Module instance to export. + output: The output file. + device: The device used to forward. + validation: Whether validate the export file. + rtol: The rtol used to regress the outputs. + atol: The atol used to regress the outputs. + strict: strict mode in torch script tracing. + kwargs: + dummy_inputs: A dummy inputs which will replace the calling of self.generate_dummy_inputs(). + """ + + model.eval() + dummy_param = 'dummy_inputs' not in kwargs + dummy_inputs = self.generate_dummy_inputs( + **kwargs) if dummy_param else kwargs.pop('dummy_inputs') + if dummy_inputs is None: + raise NotImplementedError( + 'Model property dummy_inputs must be set.') + dummy_inputs = collate_fn(dummy_inputs, device) + if isinstance(dummy_inputs, Mapping): + dummy_inputs_filter = [] + for _input in self._decide_input_format(model, dummy_inputs): + if _input is not None: + dummy_inputs_filter.append(_input) + else: + break + + if len(dummy_inputs) != len(dummy_inputs_filter): + logger.warn( + f'Dummy inputs is not continuous in the forward method, ' + f'origin length: {len(dummy_inputs)}, ' + f'the length after filtering: {len(dummy_inputs_filter)}') + dummy_inputs = dummy_inputs_filter + + with torch.no_grad(): + model.eval() + with replace_call(): + traced_model = torch.jit.trace( + model, tuple(dummy_inputs), strict=strict) + torch.jit.save(traced_model, output) + + if validation: + ts_model = torch.jit.load(output) + with torch.no_grad(): + model.eval() + ts_model.eval() + outputs = ts_model.forward(*dummy_inputs) + outputs = numpify_tensor_nested(outputs) + outputs_origin = model.forward(*dummy_inputs) + outputs_origin = numpify_tensor_nested(outputs_origin) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + if isinstance(outputs_origin, dict): + outputs_origin = list(outputs_origin.values()) + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested( + 'Torch script model output match failed', outputs, + outputs_origin, **tols): + raise RuntimeError( + 'export torch script failed because of validation error.') + + +@contextmanager +def replace_call(): + """This function is used to recover the original call method. + + The Model class of modelscope overrides the call method. When exporting to onnx or torchscript, torch will + prepare the parameters as the prototype of forward method, and trace the call method, this causes + problems. Here we recover the call method to the default implementation of torch.nn.Module, and change it + back after the tracing was done. + """ + TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl + yield + TorchModel.__call__ = TorchModel.call_origin + del TorchModel.call_origin diff --git a/modelscope/fileio/__init__.py b/modelscope/fileio/__init__.py new file mode 100644 index 00000000..385cd02c --- /dev/null +++ b/modelscope/fileio/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .file import File, LocalStorage +from .io import dump, dumps, load diff --git a/modelscope/fileio/file.py b/modelscope/fileio/file.py new file mode 100644 index 00000000..3fff80c8 --- /dev/null +++ b/modelscope/fileio/file.py @@ -0,0 +1,326 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import contextlib +import os +import tempfile +from abc import ABCMeta, abstractmethod +from pathlib import Path +from typing import Generator, Union + +import requests + + +class Storage(metaclass=ABCMeta): + """Abstract class of storage. + + All backends need to implement two apis: ``read()`` and ``read_text()``. + ``read()`` reads the file as a byte stream and ``read_text()`` reads + the file as texts. + """ + + @abstractmethod + def read(self, filepath: str): + pass + + @abstractmethod + def read_text(self, filepath: str): + pass + + @abstractmethod + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + pass + + @abstractmethod + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + pass + + +class LocalStorage(Storage): + """Local hard disk storage""" + + def read(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + content = f.read() + return content + + def read_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + with open(filepath, 'r', encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``write`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + with open(filepath, 'wb') as f: + f.write(obj) + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``write_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + @contextlib.contextmanager + def as_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + yield filepath + + +class HTTPStorage(Storage): + """HTTP and HTTPS storage.""" + + def read(self, url): + # TODO @wenmeng.zwm add progress bar if file is too large + r = requests.get(url) + r.raise_for_status() + return r.content + + def read_text(self, url): + r = requests.get(url) + r.raise_for_status() + return r.text + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = HTTPStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, url: Union[str, Path]) -> None: + raise NotImplementedError('write is not supported by HTTP Storage') + + def write_text(self, + obj: str, + url: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'write_text is not supported by HTTP Storage') + + +class OSSStorage(Storage): + """OSS storage.""" + + def __init__(self, oss_config_file=None): + # read from config file or env var + raise NotImplementedError( + 'OSSStorage.__init__ to be implemented in the future') + + def read(self, filepath): + raise NotImplementedError( + 'OSSStorage.read to be implemented in the future') + + def read_text(self, filepath, encoding='utf-8'): + raise NotImplementedError( + 'OSSStorage.read_text to be implemented in the future') + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = OSSStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + raise NotImplementedError( + 'OSSStorage.write to be implemented in the future') + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'OSSStorage.write_text to be implemented in the future') + + +G_STORAGES = {} + + +class File(object): + _prefix_to_storage: dict = { + 'oss': OSSStorage, + 'http': HTTPStorage, + 'https': HTTPStorage, + 'local': LocalStorage, + } + + @staticmethod + def _get_storage(uri): + assert isinstance(uri, + str), f'uri should be str type, but got {type(uri)}' + + if '://' not in uri: + # local path + storage_type = 'local' + else: + prefix, _ = uri.split('://') + storage_type = prefix + + assert storage_type in File._prefix_to_storage, \ + f'Unsupported uri {uri}, valid prefixs: '\ + f'{list(File._prefix_to_storage.keys())}' + + if storage_type not in G_STORAGES: + G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]() + + return G_STORAGES[storage_type] + + @staticmethod + def read(uri: str) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + storage = File._get_storage(uri) + return storage.read(uri) + + @staticmethod + def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + storage = File._get_storage(uri) + return storage.read_text(uri) + + @staticmethod + def write(obj: bytes, uri: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``write`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + storage = File._get_storage(uri) + return storage.write(obj, uri) + + @staticmethod + def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``write_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + storage = File._get_storage(uri) + return storage.write_text(obj, uri) + + @contextlib.contextmanager + def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + storage = File._get_storage(uri) + with storage.as_local_path(uri) as local_path: + yield local_path diff --git a/modelscope/fileio/format/__init__.py b/modelscope/fileio/format/__init__.py new file mode 100644 index 00000000..68518266 --- /dev/null +++ b/modelscope/fileio/format/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .base import FormatHandler +from .json import JsonHandler +from .yaml import YamlHandler diff --git a/modelscope/fileio/format/base.py b/modelscope/fileio/format/base.py new file mode 100644 index 00000000..6303c3b3 --- /dev/null +++ b/modelscope/fileio/format/base.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABCMeta, abstractmethod + + +class FormatHandler(metaclass=ABCMeta): + # if `text_format` is True, file + # should use text mode otherwise binary mode + text_mode = True + + @abstractmethod + def load(self, file, **kwargs): + pass + + @abstractmethod + def dump(self, obj, file, **kwargs): + pass + + @abstractmethod + def dumps(self, obj, **kwargs): + pass diff --git a/modelscope/fileio/format/json.py b/modelscope/fileio/format/json.py new file mode 100644 index 00000000..9979c023 --- /dev/null +++ b/modelscope/fileio/format/json.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np + +from .base import FormatHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f'{type(obj)} is unsupported for json dump') + + +class JsonHandler(FormatHandler): + """Use jsonplus, serialization of Python types to JSON that "just works".""" + + def load(self, file): + import jsonplus + return jsonplus.loads(file.read()) + + def dump(self, obj, file, **kwargs): + file.write(self.dumps(obj, **kwargs)) + + def dumps(self, obj, **kwargs): + import jsonplus + kwargs.setdefault('default', set_default) + return jsonplus.dumps(obj, **kwargs) diff --git a/modelscope/fileio/format/yaml.py b/modelscope/fileio/format/yaml.py new file mode 100644 index 00000000..783af7f3 --- /dev/null +++ b/modelscope/fileio/format/yaml.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import yaml + +try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader, Dumper # type: ignore + +from .base import FormatHandler # isort:skip + + +class YamlHandler(FormatHandler): + + def load(self, file, **kwargs): + kwargs.setdefault('Loader', Loader) + return yaml.load(file, **kwargs) + + def dump(self, obj, file, **kwargs): + kwargs.setdefault('Dumper', Dumper) + yaml.dump(obj, file, **kwargs) + + def dumps(self, obj, **kwargs): + kwargs.setdefault('Dumper', Dumper) + return yaml.dump(obj, **kwargs) diff --git a/modelscope/fileio/io.py b/modelscope/fileio/io.py new file mode 100644 index 00000000..1b23997a --- /dev/null +++ b/modelscope/fileio/io.py @@ -0,0 +1,127 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +# Copyright (c) OpenMMLab. All rights reserved. +from io import BytesIO, StringIO +from pathlib import Path + +from .file import File +from .format import JsonHandler, YamlHandler + +format_handlers = { + 'json': JsonHandler(), + 'yaml': YamlHandler(), + 'yml': YamlHandler(), +} + + +def load(file, file_format=None, **kwargs): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml". + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('oss://path/of/your/file') # file is storaged in petrel + + Returns: + The content from the file. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None and isinstance(file, str): + file_format = file.split('.')[-1] + if file_format not in format_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = format_handlers[file_format] + if isinstance(file, str): + if handler.text_mode: + with StringIO(File.read_text(file)) as f: + obj = handler.load(f, **kwargs) + else: + with BytesIO(File.read(file)) as f: + obj = handler.load(f, **kwargs) + elif hasattr(file, 'read'): + obj = handler.load(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump(obj, file=None, file_format=None, **kwargs): + """Dump data to json/yaml strings or files. + + This method provides a unified api for dumping data as strings or to files. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 'oss://path/of/your/file') # oss + + Returns: + bool: True for success, False otherwise. + """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if isinstance(file, str): + file_format = file.split('.')[-1] + elif file is None: + raise ValueError( + 'file_format must be specified since file is None') + if file_format not in format_handlers: + raise TypeError(f'Unsupported format: {file_format}') + + handler = format_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif isinstance(file, str): + if handler.text_mode: + with StringIO() as f: + handler.dump(obj, f, **kwargs) + File.write_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump(obj, f, **kwargs) + File.write(f.getvalue(), file) + elif hasattr(file, 'write'): + handler.dump(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') + + +def dumps(obj, format, **kwargs): + """Dump data to json/yaml strings or files. + + This method provides a unified api for dumping data as strings or to files. + + Args: + obj (any): The python object to be dumped. + format (str, optional): Same as file_format :func:`load`. + + Examples: + >>> dumps('hello world', 'json') # json + >>> dumps('hello world', 'yaml') # yaml + + Returns: + bool: True for success, False otherwise. + """ + if format not in format_handlers: + raise TypeError(f'Unsupported format: {format}') + + handler = format_handlers[format] + return handler.dumps(obj, **kwargs) diff --git a/modelscope/hub/__init__.py b/modelscope/hub/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py new file mode 100644 index 00000000..f2ff822d --- /dev/null +++ b/modelscope/hub/api.py @@ -0,0 +1,800 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +# yapf: disable +import datetime +import os +import pickle +import platform +import shutil +import tempfile +import uuid +from collections import defaultdict +from http import HTTPStatus +from http.cookiejar import CookieJar +from os.path import expanduser +from typing import Dict, List, Optional, Tuple, Union + +import requests + +from modelscope import __version__ +from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, + API_RESPONSE_FIELD_EMAIL, + API_RESPONSE_FIELD_GIT_ACCESS_TOKEN, + API_RESPONSE_FIELD_MESSAGE, + API_RESPONSE_FIELD_USERNAME, + DEFAULT_CREDENTIALS_PATH, + MODELSCOPE_ENVIRONMENT, + MODELSCOPE_USERNAME, ONE_YEAR_SECONDS, + Licenses, ModelVisibility) +from modelscope.hub.errors import (InvalidParameter, NotExistError, + NotLoginException, NoValidRevisionError, + RequestError, datahub_raise_on_error, + handle_http_post_error, + handle_http_response, is_ok, + raise_for_http_status, raise_on_error) +from modelscope.hub.git import GitCommandWrapper +from modelscope.hub.repository import Repository +from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_MODEL_REVISION, + DEFAULT_REPOSITORY_REVISION, + MASTER_MODEL_BRANCH, DatasetFormations, + DatasetMetaFormats, DownloadChannel, + DownloadMode, ModelFile) +from modelscope.utils.logger import get_logger +from .utils.utils import (get_endpoint, get_release_datetime, + model_id_to_group_owner_name) + +logger = get_logger() + + +class HubApi: + + def __init__(self, endpoint=None): + self.endpoint = endpoint if endpoint is not None else get_endpoint() + self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} + + def login( + self, + access_token: str, + ) -> tuple(): + """ + Login with username and password + + Args: + access_token(`str`): user access token on modelscope. + Returns: + cookies: to authenticate yourself to ModelScope open-api + gitlab token: to access private repos + + + You only have to login once within 30 days. + + """ + path = f'{self.endpoint}/api/v1/login' + r = requests.post( + path, json={'AccessToken': access_token}, headers=self.headers) + raise_for_http_status(r) + d = r.json() + raise_on_error(d) + + token = d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_GIT_ACCESS_TOKEN] + cookies = r.cookies + + # save token and cookie + ModelScopeConfig.save_token(token) + ModelScopeConfig.save_cookies(cookies) + ModelScopeConfig.save_user_info( + d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_USERNAME], + d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_EMAIL]) + + return d[API_RESPONSE_FIELD_DATA][ + API_RESPONSE_FIELD_GIT_ACCESS_TOKEN], cookies + + def create_model( + self, + model_id: str, + visibility: str, + license: str, + chinese_name: Optional[str] = None, + ) -> str: + """ + Create model repo at ModelScopeHub + + Args: + model_id:(`str`): The model id + visibility(`int`): visibility of the model(1-private, 5-public), default public. + license(`str`): license of the model, default none. + chinese_name(`str`, *optional*): chinese name of the model + Returns: + name of the model created + + + model_id = {owner}/{name} + + """ + if model_id is None: + raise InvalidParameter('model_id is required!') + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise ValueError('Token does not exist, please login first.') + + path = f'{self.endpoint}/api/v1/models' + owner_or_group, name = model_id_to_group_owner_name(model_id) + body = { + 'Path': owner_or_group, + 'Name': name, + 'ChineseName': chinese_name, + 'Visibility': visibility, # server check + 'License': license + } + r = requests.post( + path, json=body, cookies=cookies, headers=self.headers) + handle_http_post_error(r, path, body) + raise_on_error(r.json()) + model_repo_url = f'{get_endpoint()}/{model_id}' + return model_repo_url + + def delete_model(self, model_id): + """_summary_ + + Args: + model_id (str): The model id. + + model_id = {owner}/{name} + + """ + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise ValueError('Token does not exist, please login first.') + path = f'{self.endpoint}/api/v1/models/{model_id}' + + r = requests.delete(path, cookies=cookies, headers=self.headers) + raise_for_http_status(r) + raise_on_error(r.json()) + + def get_model_url(self, model_id): + return f'{self.endpoint}/api/v1/models/{model_id}.git' + + def get_model( + self, + model_id: str, + revision: str = DEFAULT_MODEL_REVISION, + ) -> str: + """ + Get model information at modelscope_hub + + Args: + model_id(`str`): The model id. + revision(`str`): revision of model + Returns: + The model detail information. + Raises: + NotExistError: If the model is not exist, will throw NotExistError + + model_id = {owner}/{name} + + """ + cookies = ModelScopeConfig.get_cookies() + owner_or_group, name = model_id_to_group_owner_name(model_id) + if revision: + path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}' + else: + path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}' + + r = requests.get(path, cookies=cookies, headers=self.headers) + handle_http_response(r, logger, cookies, model_id) + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + return r.json()[API_RESPONSE_FIELD_DATA] + else: + raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + + def push_model(self, + model_id: str, + model_dir: str, + visibility: int = ModelVisibility.PUBLIC, + license: str = Licenses.APACHE_V2, + chinese_name: Optional[str] = None, + commit_message: Optional[str] = 'upload model', + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION): + """ + Upload model from a given directory to given repository. A valid model directory + must contain a configuration.json file. + + This function upload the files in given directory to given repository. If the + given repository is not exists in remote, it will automatically create it with + given visibility, license and chinese_name parameters. If the revision is also + not exists in remote repository, it will create a new branch for it. + + This function must be called before calling HubApi's login with a valid token + which can be obtained from ModelScope's website. + + Args: + model_id (`str`): + The model id to be uploaded, caller must have write permission for it. + model_dir(`str`): + The Absolute Path of the finetune result. + visibility(`int`, defaults to `0`): + Visibility of the new created model(1-private, 5-public). If the model is + not exists in ModelScope, this function will create a new model with this + visibility and this parameter is required. You can ignore this parameter + if you make sure the model's existence. + license(`str`, defaults to `None`): + License of the new created model(see License). If the model is not exists + in ModelScope, this function will create a new model with this license + and this parameter is required. You can ignore this parameter if you + make sure the model's existence. + chinese_name(`str`, *optional*, defaults to `None`): + chinese name of the new created model. + commit_message(`str`, *optional*, defaults to `None`): + commit message of the push request. + revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION): + which branch to push. If the branch is not exists, It will create a new + branch and push to it. + """ + if model_id is None: + raise InvalidParameter('model_id cannot be empty!') + if model_dir is None: + raise InvalidParameter('model_dir cannot be empty!') + if not os.path.exists(model_dir) or os.path.isfile(model_dir): + raise InvalidParameter('model_dir must be a valid directory.') + cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) + if not os.path.exists(cfg_file): + raise ValueError(f'{model_dir} must contain a configuration.json.') + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise NotLoginException('Must login before upload!') + files_to_save = os.listdir(model_dir) + try: + self.get_model(model_id=model_id) + except Exception: + if visibility is None or license is None: + raise InvalidParameter( + 'visibility and license cannot be empty if want to create new repo' + ) + logger.info('Create new model %s' % model_id) + self.create_model( + model_id=model_id, + visibility=visibility, + license=license, + chinese_name=chinese_name) + tmp_dir = tempfile.mkdtemp() + git_wrapper = GitCommandWrapper() + try: + repo = Repository(model_dir=tmp_dir, clone_from=model_id) + branches = git_wrapper.get_remote_branches(tmp_dir) + if revision not in branches: + logger.info('Create new branch %s' % revision) + git_wrapper.new_branch(tmp_dir, revision) + git_wrapper.checkout(tmp_dir, revision) + files_in_repo = os.listdir(tmp_dir) + for f in files_in_repo: + if f[0] != '.': + src = os.path.join(tmp_dir, f) + if os.path.isfile(src): + os.remove(src) + else: + shutil.rmtree(src, ignore_errors=True) + for f in files_to_save: + if f[0] != '.': + src = os.path.join(model_dir, f) + if os.path.isdir(src): + shutil.copytree(src, os.path.join(tmp_dir, f)) + else: + shutil.copy(src, tmp_dir) + if not commit_message: + date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + commit_message = '[automsg] push model %s to hub at %s' % ( + model_id, date) + repo.push(commit_message=commit_message, local_branch=revision, remote_branch=revision) + except Exception: + raise + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + def list_models(self, + owner_or_group: str, + page_number=1, + page_size=10) -> dict: + """List models in owner or group. + + Args: + owner_or_group(`str`): owner or group. + page_number(`int`): The page number, default: 1 + page_size(`int`): The page size, default: 10 + Returns: + dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group} + """ + cookies = ModelScopeConfig.get_cookies() + path = f'{self.endpoint}/api/v1/models/' + r = requests.put( + path, + data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' % + (owner_or_group, page_number, page_size), + cookies=cookies, + headers=self.headers) + handle_http_response(r, logger, cookies, 'list_model') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def _check_cookie(self, + use_cookies: Union[bool, + CookieJar] = False) -> CookieJar: + cookies = None + if isinstance(use_cookies, CookieJar): + cookies = use_cookies + elif use_cookies: + cookies = ModelScopeConfig.get_cookies() + if cookies is None: + raise ValueError('Token does not exist, please login first.') + return cookies + + def list_model_revisions( + self, + model_id: str, + cutoff_timestamp: int = None, + use_cookies: Union[bool, CookieJar] = False) -> List[str]: + """Get model branch and tags. + + Args: + model_id (str): The model id + cutoff_timestamp (int): Tags created before the cutoff will be included. + The timestamp is represented by the seconds elasped from the epoch time. + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will + will load cookie from local. Defaults to False. + Returns: + Tuple[List[str], List[str]]: Return list of branch name and tags + """ + cookies = self._check_cookie(use_cookies) + if cutoff_timestamp is None: + cutoff_timestamp = get_release_datetime() + path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp + r = requests.get(path, cookies=cookies, headers=self.headers) + handle_http_response(r, logger, cookies, model_id) + d = r.json() + raise_on_error(d) + info = d[API_RESPONSE_FIELD_DATA] + # tags returned from backend are guaranteed to be ordered by create-time + tags = [x['Revision'] for x in info['RevisionMap']['Tags'] + ] if info['RevisionMap']['Tags'] else [] + return tags + + def get_valid_revision(self, model_id: str, revision=None, cookies: Optional[CookieJar] = None): + release_timestamp = get_release_datetime() + current_timestamp = int(round(datetime.datetime.now().timestamp())) + # for active development in library codes (non-release-branches), release_timestamp + # is set to be a far-away-time-in-the-future, to ensure that we shall + # get the master-HEAD version from model repo by default (when no revision is provided) + if release_timestamp > current_timestamp + ONE_YEAR_SECONDS: + branches, tags = self.get_model_branches_and_tags( + model_id, use_cookies=False if cookies is None else cookies) + if revision is None: + revision = MASTER_MODEL_BRANCH + logger.info('Model revision not specified, use default: %s in development mode' % revision) + if revision not in branches and revision not in tags: + raise NotExistError('The model: %s has no branch or tag : %s .' % revision) + logger.info('Development mode use revision: %s' % revision) + else: + if revision is None: # user not specified revision, use latest revision before release time + revisions = self.list_model_revisions( + model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) + if len(revisions) == 0: + raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) + # tags (revisions) returned from backend are guaranteed to be ordered by create-time + # we shall obtain the latest revision created earlier than release version of this branch + revision = revisions[0] + logger.info('Model revision not specified, use the latest revision: %s' % revision) + else: + # use user-specified revision + revisions = self.list_model_revisions( + model_id, cutoff_timestamp=current_timestamp, use_cookies=False if cookies is None else cookies) + if revision not in revisions: + raise NotExistError( + 'The model: %s has no revision: %s !' % (model_id, revision)) + logger.info('Use user-specified model revision: %s' % revision) + return revision + + def get_model_branches_and_tags( + self, + model_id: str, + use_cookies: Union[bool, CookieJar] = False, + ) -> Tuple[List[str], List[str]]: + """Get model branch and tags. + + Args: + model_id (str): The model id + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will + will load cookie from local. Defaults to False. + Returns: + Tuple[List[str], List[str]]: Return list of branch name and tags + """ + cookies = self._check_cookie(use_cookies) + + path = f'{self.endpoint}/api/v1/models/{model_id}/revisions' + r = requests.get(path, cookies=cookies, headers=self.headers) + handle_http_response(r, logger, cookies, model_id) + d = r.json() + raise_on_error(d) + info = d[API_RESPONSE_FIELD_DATA] + branches = [x['Revision'] for x in info['RevisionMap']['Branches'] + ] if info['RevisionMap']['Branches'] else [] + tags = [x['Revision'] for x in info['RevisionMap']['Tags'] + ] if info['RevisionMap']['Tags'] else [] + return branches, tags + + def get_model_files(self, + model_id: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + root: Optional[str] = None, + recursive: Optional[str] = False, + use_cookies: Union[bool, CookieJar] = False, + headers: Optional[dict] = {}) -> List[dict]: + """List the models files. + + Args: + model_id (str): The model id + revision (Optional[str], optional): The branch or tag name. + root (Optional[str], optional): The root path. Defaults to None. + recursive (Optional[str], optional): Is recursive list files. Defaults to False. + use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, + will load cookie from local. Defaults to False. + headers: request headers + + Raises: + ValueError: If user_cookies is True, but no local cookie. + + Returns: + List[dict]: Model file list. + """ + if revision: + path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % ( + self.endpoint, model_id, revision, recursive) + else: + path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % ( + self.endpoint, model_id, recursive) + cookies = self._check_cookie(use_cookies) + if root is not None: + path = path + f'&Root={root}' + + r = requests.get( + path, cookies=cookies, headers={ + **headers, + **self.headers + }) + + handle_http_response(r, logger, cookies, model_id) + d = r.json() + raise_on_error(d) + + files = [] + for file in d[API_RESPONSE_FIELD_DATA]['Files']: + if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': + continue + + files.append(file) + return files + + def list_datasets(self): + path = f'{self.endpoint}/api/v1/datasets' + params = {} + r = requests.get(path, params=params, headers=self.headers) + raise_for_http_status(r) + dataset_list = r.json()[API_RESPONSE_FIELD_DATA] + return [x['Name'] for x in dataset_list] + + def fetch_dataset_scripts( + self, + dataset_name: str, + namespace: str, + download_mode: Optional[DownloadMode], + revision: Optional[str] = DEFAULT_DATASET_REVISION): + if namespace is None: + raise ValueError( + f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}' + ) + revision = revision or DEFAULT_DATASET_REVISION + cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, namespace, + dataset_name, revision) + download_mode = DownloadMode(download_mode + or DownloadMode.REUSE_DATASET_IF_EXISTS) + if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists( + cache_dir): + shutil.rmtree(cache_dir) + os.makedirs(cache_dir, exist_ok=True) + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' + cookies = ModelScopeConfig.get_cookies() + r = requests.get(datahub_url, cookies=cookies) + resp = r.json() + datahub_raise_on_error(datahub_url, resp) + dataset_id = resp['Data']['Id'] + dataset_type = resp['Data']['Type'] + datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' + r = requests.get(datahub_url, cookies=cookies, headers=self.headers) + resp = r.json() + datahub_raise_on_error(datahub_url, resp) + file_list = resp['Data'] + if file_list is None: + raise NotExistError( + f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, ' + f'version = {revision}] dose not exist') + + file_list = file_list['Files'] + local_paths = defaultdict(list) + dataset_formation = DatasetFormations(dataset_type) + dataset_meta_format = DatasetMetaFormats[dataset_formation] + for file_info in file_list: + file_path = file_info['Path'] + extension = os.path.splitext(file_path)[-1] + if extension in dataset_meta_format: + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ + f'Revision={revision}&FilePath={file_path}' + r = requests.get(datahub_url, cookies=cookies) + raise_for_http_status(r) + local_path = os.path.join(cache_dir, file_path) + if os.path.exists(local_path): + logger.warning( + f"Reusing dataset {dataset_name}'s python file ({local_path})" + ) + local_paths[extension].append(local_path) + continue + with open(local_path, 'wb') as f: + f.write(r.content) + local_paths[extension].append(local_path) + + return local_paths, dataset_formation, cache_dir + + def get_dataset_file_url( + self, + file_name: str, + dataset_name: str, + namespace: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION): + if file_name.endswith('.csv'): + file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \ + f'Revision={revision}&FilePath={file_name}' + return file_name + + def get_dataset_access_config( + self, + dataset_name: str, + namespace: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION): + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ + f'ststoken?Revision={revision}' + return self.datahub_remote_call(datahub_url) + + def get_dataset_access_config_session( + self, + cookies: CookieJar, + dataset_name: str, + namespace: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION): + + datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ + f'ststoken?Revision={revision}' + + r = requests.get(url=datahub_url, cookies=cookies, headers=self.headers) + resp = r.json() + raise_on_error(resp) + return resp['Data'] + + def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, + is_recursive, is_filter_dir, revision): + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ + f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' + + cookies = ModelScopeConfig.get_cookies() + resp = requests.get(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Data'] + return resp + + def on_dataset_download(self, dataset_name: str, namespace: str) -> None: + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' + cookies = ModelScopeConfig.get_cookies() + r = requests.post(url, cookies=cookies, headers=self.headers) + raise_for_http_status(r) + + def delete_oss_dataset_object(self, object_name: str, dataset_name: str, + namespace: str, revision: str) -> str: + if not object_name or not dataset_name or not namespace or not revision: + raise ValueError('Args cannot be empty!') + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}' + + cookies = self.check_local_cookies(use_cookies=True) + resp = requests.delete(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Message'] + return resp + + def delete_oss_dataset_dir(self, object_name: str, dataset_name: str, + namespace: str, revision: str) -> str: + if not object_name or not dataset_name or not namespace or not revision: + raise ValueError('Args cannot be empty!') + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \ + f'&Revision={revision}' + + cookies = self.check_local_cookies(use_cookies=True) + resp = requests.delete(url=url, cookies=cookies) + resp = resp.json() + raise_on_error(resp) + resp = resp['Message'] + return resp + + @staticmethod + def datahub_remote_call(url): + cookies = ModelScopeConfig.get_cookies() + r = requests.get(url, cookies=cookies, headers={'user-agent': ModelScopeConfig.get_user_agent()}) + resp = r.json() + datahub_raise_on_error(url, resp) + return resp['Data'] + + def check_local_cookies(self, use_cookies) -> CookieJar: + return self._check_cookie(use_cookies=use_cookies) + + def dataset_download_uv(self, dataset_name: str, namespace: str): + if not dataset_name or not namespace: + raise ValueError('dataset_name or namespace cannot be empty!') + + # get channel and user_name + channel = DownloadChannel.LOCAL.value + user_name = '' + if MODELSCOPE_ENVIRONMENT in os.environ: + channel = os.environ[MODELSCOPE_ENVIRONMENT] + if MODELSCOPE_USERNAME in os.environ: + user_name = os.environ[MODELSCOPE_USERNAME] + + url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/{channel}?user={user_name}' + cookies = ModelScopeConfig.get_cookies() + r = requests.post(url, cookies=cookies, headers=self.headers) + resp = r.json() + raise_on_error(resp) + return resp['Message'] + + +class ModelScopeConfig: + path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) + COOKIES_FILE_NAME = 'cookies' + GIT_TOKEN_FILE_NAME = 'git_token' + USER_INFO_FILE_NAME = 'user' + USER_SESSION_ID_FILE_NAME = 'session' + + @staticmethod + def make_sure_credential_path_exist(): + os.makedirs(ModelScopeConfig.path_credential, exist_ok=True) + + @staticmethod + def save_cookies(cookies: CookieJar): + ModelScopeConfig.make_sure_credential_path_exist() + with open( + os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.COOKIES_FILE_NAME), 'wb+') as f: + pickle.dump(cookies, f) + + @staticmethod + def get_cookies(): + cookies_path = os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.COOKIES_FILE_NAME) + if os.path.exists(cookies_path): + with open(cookies_path, 'rb') as f: + cookies = pickle.load(f) + for cookie in cookies: + if cookie.is_expired(): + logger.warn( + 'Authentication has expired, please re-login') + return None + return cookies + return None + + @staticmethod + def get_user_session_id(): + session_path = os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.USER_SESSION_ID_FILE_NAME) + session_id = '' + if os.path.exists(session_path): + with open(session_path, 'rb') as f: + session_id = str(f.readline().strip(), encoding='utf-8') + return session_id + if session_id == '' or len(session_id) != 32: + session_id = str(uuid.uuid4().hex) + ModelScopeConfig.make_sure_credential_path_exist() + with open(session_path, 'w+') as wf: + wf.write(session_id) + + return session_id + + @staticmethod + def save_token(token: str): + ModelScopeConfig.make_sure_credential_path_exist() + with open( + os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.GIT_TOKEN_FILE_NAME), 'w+') as f: + f.write(token) + + @staticmethod + def save_user_info(user_name: str, user_email: str): + ModelScopeConfig.make_sure_credential_path_exist() + with open( + os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.USER_INFO_FILE_NAME), 'w+') as f: + f.write('%s:%s' % (user_name, user_email)) + + @staticmethod + def get_user_info() -> Tuple[str, str]: + try: + with open( + os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.USER_INFO_FILE_NAME), + 'r') as f: + info = f.read() + return info.split(':')[0], info.split(':')[1] + except FileNotFoundError: + pass + return None, None + + @staticmethod + def get_token() -> Optional[str]: + """ + Get token or None if not existent. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + + """ + token = None + try: + with open( + os.path.join(ModelScopeConfig.path_credential, + ModelScopeConfig.GIT_TOKEN_FILE_NAME), + 'r') as f: + token = f.read() + except FileNotFoundError: + pass + return token + + @staticmethod + def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: + """Formats a user-agent string with basic info about a request. + + Args: + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. + + Returns: + The formatted user-agent string. + """ + env = 'custom' + if MODELSCOPE_ENVIRONMENT in os.environ: + env = os.environ[MODELSCOPE_ENVIRONMENT] + user_name = 'unknown' + if MODELSCOPE_USERNAME in os.environ: + user_name = os.environ[MODELSCOPE_USERNAME] + + ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( + __version__, + platform.python_version(), + ModelScopeConfig.get_user_session_id(), + platform.platform(), + platform.processor(), + env, + user_name, + ) + if isinstance(user_agent, dict): + ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += ';' + user_agent + return ua diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py new file mode 100644 index 00000000..373a0cf4 --- /dev/null +++ b/modelscope/hub/constants.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from pathlib import Path + +MODELSCOPE_URL_SCHEME = 'http://' +DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn' +DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN + +DEFAULT_MODELSCOPE_GROUP = 'damo' +MODEL_ID_SEPARATOR = '/' +FILE_HASH = 'Sha256' +LOGGER_NAME = 'ModelScopeHub' +DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials') +API_RESPONSE_FIELD_DATA = 'Data' +API_RESPONSE_FIELD_GIT_ACCESS_TOKEN = 'AccessToken' +API_RESPONSE_FIELD_USERNAME = 'Username' +API_RESPONSE_FIELD_EMAIL = 'Email' +API_RESPONSE_FIELD_MESSAGE = 'Message' +MODELSCOPE_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' +MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG' +MODELSCOPE_USERNAME = 'MODELSCOPE_USERNAME' +ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 + + +class Licenses(object): + APACHE_V2 = 'Apache License 2.0' + GPL_V2 = 'GPL-2.0' + GPL_V3 = 'GPL-3.0' + LGPL_V2_1 = 'LGPL-2.1' + LGPL_V3 = 'LGPL-3.0' + AFL_V3 = 'AFL-3.0' + ECL_V2 = 'ECL-2.0' + MIT = 'MIT' + + +class ModelVisibility(object): + PRIVATE = 1 + INTERNAL = 3 + PUBLIC = 5 diff --git a/modelscope/hub/deploy.py b/modelscope/hub/deploy.py new file mode 100644 index 00000000..8cacde82 --- /dev/null +++ b/modelscope/hub/deploy.py @@ -0,0 +1,339 @@ +import urllib +from abc import ABC +from http import HTTPStatus +from typing import Optional + +import json +import requests +from attrs import asdict, define, field, validators + +from modelscope.hub.api import ModelScopeConfig +from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, + API_RESPONSE_FIELD_MESSAGE) +from modelscope.hub.errors import (NotLoginException, NotSupportError, + RequestError, handle_http_response, is_ok, + raise_for_http_status) +from modelscope.hub.utils.utils import get_endpoint +from modelscope.utils.logger import get_logger + +# yapf: enable + +logger = get_logger() + + +class Accelerator(object): + CPU = 'cpu' + GPU = 'gpu' + + +class Vendor(object): + EAS = 'eas' + + +class EASRegion(object): + beijing = 'cn-beijing' + hangzhou = 'cn-hangzhou' + + +class EASCpuInstanceType(object): + """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) + """ + tiny = 'ecs.c6.2xlarge' + small = 'ecs.c6.4xlarge' + medium = 'ecs.c6.6xlarge' + large = 'ecs.c6.8xlarge' + + +class EASGpuInstanceType(object): + """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) + """ + tiny = 'ecs.gn5-c28g1.7xlarge' + small = 'ecs.gn5-c8g1.4xlarge' + medium = 'ecs.gn6i-c24g1.12xlarge' + large = 'ecs.gn6e-c12g1.3xlarge' + + +def min_smaller_than_max(instance, attribute, value): + if value > instance.max_replica: + raise ValueError( + "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" + % (value, instance.max_replica)) + + +@define +class ServiceScalingConfig(object): + """Resource scaling config + Currently we ignore max_replica + Args: + max_replica: maximum replica + min_replica: minimum replica + """ + max_replica: int = field(default=1, validator=validators.ge(1)) + min_replica: int = field( + default=1, validator=[validators.ge(1), min_smaller_than_max]) + + +@define +class ServiceResourceConfig(object): + """Eas Resource request. + + Args: + accelerator: the accelerator(cpu|gpu) + instance_type: the instance type. + scaling: The instance scaling config. + """ + instance_type: str + scaling: ServiceScalingConfig + accelerator: str = field( + default=Accelerator.CPU, + validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) + + +@define +class ServiceProviderParameters(ABC): + pass + + +@define +class EASDeployParameters(ServiceProviderParameters): + """Parameters for EAS Deployment. + + Args: + resource_group: the resource group to deploy, current default. + region: The eas instance region(eg: cn-hangzhou). + access_key_id: The eas account access key id. + access_key_secret: The eas account access key secret. + vendor: must be 'eas' + """ + region: str + access_key_id: str + access_key_secret: str + resource_group: Optional[str] = None + vendor: str = field( + default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) + + +@define +class EASListParameters(ServiceProviderParameters): + """EAS instance list parameters. + + Args: + resource_group: the resource group to deploy, current default. + region: The eas instance region(eg: cn-hangzhou). + access_key_id: The eas account access key id. + access_key_secret: The eas account access key secret. + vendor: must be 'eas' + """ + access_key_id: str + access_key_secret: str + region: str = None + resource_group: str = None + vendor: str = field( + default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) + + +@define +class DeployServiceParameters(object): + """Deploy service parameters + + Args: + instance_name: the name of the service. + model_id: the modelscope model_id + revision: the modelscope model revision + resource: the resource requirement. + provider: the cloud service provider. + """ + instance_name: str + model_id: str + revision: str + resource: ServiceResourceConfig + provider: ServiceProviderParameters + + +class AttrsToQueryString(ABC): + """Convert the attrs class to json string. + + Args: + """ + + def to_query_str(self): + self_dict = asdict( + self.provider, filter=lambda attr, value: value is not None) + json_str = json.dumps(self_dict) + print(json_str) + safe_str = urllib.parse.quote_plus(json_str) + print(safe_str) + query_param = 'provider=%s' % safe_str + return query_param + + +@define +class ListServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + skip: int = 0 + limit: int = 100 + + +@define +class GetServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + + +@define +class DeleteServiceParameters(AttrsToQueryString): + provider: ServiceProviderParameters + + +class ServiceDeployer(object): + + def __init__(self, endpoint=None): + self.endpoint = endpoint if endpoint is not None else get_endpoint() + self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} + self.cookies = ModelScopeConfig.get_cookies() + if self.cookies is None: + raise NotLoginException( + 'Token does not exist, please login with HubApi first.') + + # deploy_model + def create(self, model_id: str, revision: str, instance_name: str, + resource: ServiceResourceConfig, + provider: ServiceProviderParameters): + """Deploy model to cloud, current we only support PAI EAS, this is an async API , + and the deployment could take a while to finish remotely. Please check deploy instance + status separately via checking the status. + + Args: + model_id (str): The deployed model id + revision (str): The model revision + instance_name (str): The deployed model instance name. + resource (ServiceResourceConfig): The service resource information. + provider (ServiceProviderParameters): The service provider parameter + + Raises: + NotLoginException: To use this api, you need login first. + NotSupportError: Not supported platform. + RequestError: The server return error. + + Returns: + ServiceInstanceInfo: The information of the deployed service instance. + """ + if provider.vendor != Vendor.EAS: + raise NotSupportError( + 'Not support vendor: %s ,only support EAS current.' % + (provider.vendor)) + create_params = DeployServiceParameters( + instance_name=instance_name, + model_id=model_id, + revision=revision, + resource=resource, + provider=provider) + path = f'{self.endpoint}/api/v1/deployer/endpoint' + body = asdict(create_params) + r = requests.post( + path, json=body, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'create_service') + if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def get(self, instance_name: str, provider: ServiceProviderParameters): + """Query the specified instance information. + + Args: + instance_name (str): The deployed instance name. + provider (ServiceProviderParameters): The cloud provider information, for eas + need region(eg: ch-hangzhou), access_key_id and access_key_secret. + + Raises: + NotLoginException: To use this api, you need login first. + RequestError: The request is failed from server. + + Returns: + Dict: The information of the requested service instance. + """ + params = GetServiceParameters(provider=provider) + path = '%s/api/v1/deployer/endpoint/%s?%s' % ( + self.endpoint, instance_name, params.to_query_str()) + r = requests.get(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'get_service') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def delete(self, instance_name: str, provider: ServiceProviderParameters): + """Delete deployed model, this api send delete command and return, it will take + some to delete, please check through the cloud console. + + Args: + instance_name (str): The instance name you want to delete. + provider (ServiceProviderParameters): The cloud provider information, for eas + need region(eg: ch-hangzhou), access_key_id and access_key_secret. + + Raises: + NotLoginException: To call this api, you need login first. + RequestError: The request is failed. + + Returns: + Dict: The deleted instance information. + """ + params = DeleteServiceParameters(provider=provider) + path = '%s/api/v1/deployer/endpoint/%s?%s' % ( + self.endpoint, instance_name, params.to_query_str()) + r = requests.delete(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'delete_service') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None + + def list(self, + provider: ServiceProviderParameters, + skip: int = 0, + limit: int = 100): + """List deployed model instances. + + Args: + provider (ServiceProviderParameters): The cloud service provider parameter, + for eas, need access_key_id and access_key_secret. + skip: start of the list, current not support. + limit: maximum number of instances return, current not support + Raises: + NotLoginException: To use this api, you need login first. + RequestError: The request is failed from server. + + Returns: + List: List of instance information + """ + + params = ListServiceParameters( + provider=provider, skip=skip, limit=limit) + path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, + params.to_query_str()) + r = requests.get(path, cookies=self.cookies, headers=self.headers) + handle_http_response(r, logger, self.cookies, 'list_service_instances') + if r.status_code == HTTPStatus.OK: + if is_ok(r.json()): + data = r.json()[API_RESPONSE_FIELD_DATA] + return data + else: + raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) + else: + raise_for_http_status(r) + return None diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py new file mode 100644 index 00000000..4c4e5dbd --- /dev/null +++ b/modelscope/hub/errors.py @@ -0,0 +1,136 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from http import HTTPStatus + +from requests.exceptions import HTTPError + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class NotSupportError(Exception): + pass + + +class NoValidRevisionError(Exception): + pass + + +class NotExistError(Exception): + pass + + +class RequestError(Exception): + pass + + +class GitError(Exception): + pass + + +class InvalidParameter(Exception): + pass + + +class NotLoginException(Exception): + pass + + +class FileIntegrityError(Exception): + pass + + +class FileDownloadError(Exception): + pass + + +def is_ok(rsp): + """ Check the request is ok + + Args: + rsp (_type_): The request response body + Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission', + 'RequestId': '', 'Success': False} + Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True} + """ + return rsp['Code'] == HTTPStatus.OK and rsp['Success'] + + +def handle_http_post_error(response, url, request_body): + try: + response.raise_for_status() + except HTTPError as error: + logger.error('Request %s with body: %s exception' % + (url, request_body)) + logger.error('Response details: %s' % response.content) + raise error + + +def handle_http_response(response, logger, cookies, model_id): + try: + response.raise_for_status() + except HTTPError as error: + if cookies is None: # code in [403] and + logger.error( + f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ + private. Please login first.') + logger.error('Response details: %s' % response.content) + raise error + + +def raise_on_error(rsp): + """If response error, raise exception + + Args: + rsp (_type_): The server response + """ + if rsp['Code'] == HTTPStatus.OK: + return True + else: + raise RequestError(rsp['Message']) + + +# TODO use raise_on_error instead if modelhub and datahub response have uniform structures, +def datahub_raise_on_error(url, rsp): + """If response error, raise exception + + Args: + rsp (_type_): The server response + """ + if rsp.get('Code') == HTTPStatus.OK: + return True + else: + raise RequestError( + f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}" + ) + + +def raise_for_http_status(rsp): + """ + Attempt to decode utf-8 first since some servers + localize reason strings, for invalid utf-8, fall back + to decoding with iso-8859-1. + """ + http_error_msg = '' + if isinstance(rsp.reason, bytes): + try: + reason = rsp.reason.decode('utf-8') + except UnicodeDecodeError: + reason = rsp.reason.decode('iso-8859-1') + else: + reason = rsp.reason + + if 400 <= rsp.status_code < 500: + http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code, + reason, rsp.url) + + elif 500 <= rsp.status_code < 600: + http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code, + reason, rsp.url) + + if http_error_msg: + req = rsp.request + if req.method == 'POST': + http_error_msg = u'%s, body: %s' % (http_error_msg, req.body) + raise HTTPError(http_error_msg, response=rsp) diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py new file mode 100644 index 00000000..042ea6a6 --- /dev/null +++ b/modelscope/hub/file_download.py @@ -0,0 +1,241 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import os +import tempfile +from functools import partial +from http.cookiejar import CookieJar +from pathlib import Path +from typing import Dict, Optional, Union + +import requests +from tqdm import tqdm + +from modelscope import __version__ +from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.logger import get_logger +from .constants import FILE_HASH +from .errors import FileDownloadError, NotExistError +from .utils.caching import ModelFileSystemCache +from .utils.utils import (file_integrity_validation, get_cache_dir, + get_endpoint, model_id_to_group_owner_name) + +logger = get_logger() + + +def model_file_download( + model_id: str, + file_path: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None, +) -> Optional[str]: # pragma: no cover + """ + Download from a given URL and cache it if it's not already present in the + local cache. + + Given a URL, this function looks for the corresponding file in the local + cache. If it's not there, download it. Then return the path to the cached + file. + + Args: + model_id (`str`): + The model to whom the file to be downloaded belongs. + file_path(`str`): + Path of the file to be downloaded, relative to the root of model repo + revision(`str`, *optional*): + revision of the model file to be downloaded. + Can be any of a branch, tag or commit hash + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + user_agent (`dict`, `str`, *optional*): + The user-agent info in the form of a dictionary or a string. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + if `False`, download the file anyway even it exists + + Returns: + Local path (string) of file or if networking is off, last version of + file cached on disk. + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) + if ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + + """ + if cache_dir is None: + cache_dir = get_cache_dir() + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + temporary_cache_dir = os.path.join(cache_dir, 'temp') + os.makedirs(temporary_cache_dir, exist_ok=True) + + group_or_owner, name = model_id_to_group_owner_name(model_id) + + cache = ModelFileSystemCache(cache_dir, group_or_owner, name) + + # if local_files_only is `True` and the file already exists in cached_path + # return the cached path + if local_files_only: + cached_file_path = cache.get_file_by_path(file_path) + if cached_file_path is not None: + logger.warning( + "File exists in local cache, but we're not sure it's up to date" + ) + return cached_file_path + else: + raise ValueError( + 'Cannot find the requested files in the cached path and outgoing' + ' traffic has been disabled. To enable model look-ups and downloads' + " online, set 'local_files_only' to False.") + + _api = HubApi() + headers = { + 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) + } + if cookies is None: + cookies = ModelScopeConfig.get_cookies() + + revision = _api.get_valid_revision( + model_id, revision=revision, cookies=cookies) + file_to_download_info = None + # we need to confirm the version is up-to-date + # we need to get the file list to check if the latest version is cached, if so return, otherwise download + model_files = _api.get_model_files( + model_id=model_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies) + + for model_file in model_files: + if model_file['Type'] == 'tree': + continue + + if model_file['Path'] == file_path: + if cache.exists(model_file): + logger.info( + f'File {model_file["Name"]} already in cache, skip downloading!' + ) + return cache.get_file_by_info(model_file) + else: + file_to_download_info = model_file + break + + if file_to_download_info is None: + raise NotExistError('The file path: %s not exist in: %s' % + (file_path, model_id)) + + # we need to download again + url_to_download = get_file_download_url(model_id, file_path, revision) + file_to_download_info = { + 'Path': file_path, + 'Revision': file_to_download_info['Revision'], + FILE_HASH: file_to_download_info[FILE_HASH] + } + + temp_file_name = next(tempfile._get_candidate_names()) + http_get_file( + url_to_download, + temporary_cache_dir, + temp_file_name, + headers=headers, + cookies=None if cookies is None else cookies.get_dict()) + temp_file_path = os.path.join(temporary_cache_dir, temp_file_name) + # for download with commit we can't get Sha256 + if file_to_download_info[FILE_HASH] is not None: + file_integrity_validation(temp_file_path, + file_to_download_info[FILE_HASH]) + return cache.put_file(file_to_download_info, + os.path.join(temporary_cache_dir, temp_file_name)) + + +def get_file_download_url(model_id: str, file_path: str, revision: str): + """ + Format file download url according to `model_id`, `revision` and `file_path`. + e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`, + the resulted download url is: https://modelscope.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md + """ + download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}' + return download_url_template.format( + endpoint=get_endpoint(), + model_id=model_id, + revision=revision, + file_path=file_path, + ) + + +def http_get_file( + url: str, + local_dir: str, + file_name: str, + cookies: CookieJar, + headers: Optional[Dict[str, str]] = None, +): + """ + Download remote file. Do not gobble up errors. + This method is only used by snapshot_download, since the behavior is quite different with single file download + TODO: consolidate with http_get_file() to avoild duplicate code + + Args: + url(`str`): + actual download url of the file + local_dir(`str`): + local directory where the downloaded file stores + file_name(`str`): + name of the file stored in `local_dir` + cookies(`CookieJar`): + cookies used to authentication the user, which is used for downloading private repos + headers(`Optional[Dict[str, str]] = None`): + http headers to carry necessary info when requesting the remote file + + """ + total = -1 + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) + + with temp_file_manager() as temp_file: + logger.info('downloading %s to %s', url, temp_file.name) + headers = copy.deepcopy(headers) + + r = requests.get(url, stream=True, headers=headers, cookies=cookies) + r.raise_for_status() + + content_length = r.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + + progress = tqdm( + unit='B', + unit_scale=True, + unit_divisor=1024, + total=total, + initial=0, + desc='Downloading', + ) + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + logger.info('storing %s in cache at %s', url, local_dir) + downloaded_length = os.path.getsize(temp_file.name) + if total != downloaded_length: + os.remove(temp_file.name) + msg = 'File %s download incomplete, content_length: %s but the \ + file downloaded length: %s, please download again' % ( + file_name, total, downloaded_length) + logger.error(msg) + raise FileDownloadError(msg) + os.replace(temp_file.name, os.path.join(local_dir, file_name)) diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py new file mode 100644 index 00000000..7943023b --- /dev/null +++ b/modelscope/hub/git.py @@ -0,0 +1,249 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import subprocess +from typing import List + +from modelscope.utils.logger import get_logger +from ..utils.constant import MASTER_MODEL_BRANCH +from .errors import GitError + +logger = get_logger() + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class GitCommandWrapper(metaclass=Singleton): + """Some git operation wrapper + """ + default_git_path = 'git' # The default git command line + + def __init__(self, path: str = None): + self.git_path = path or self.default_git_path + + def _run_git_command(self, *args) -> subprocess.CompletedProcess: + """Run git command, if command return 0, return subprocess.response + otherwise raise GitError, message is stdout and stderr. + + Raises: + GitError: Exception with stdout and stderr. + + Returns: + subprocess.CompletedProcess: the command response + """ + logger.debug(' '.join(args)) + git_env = os.environ.copy() + git_env['GIT_TERMINAL_PROMPT'] = '0' + response = subprocess.run( + [self.git_path, *args], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=git_env, + ) # compatible for python3.6 + try: + response.check_returncode() + return response + except subprocess.CalledProcessError as error: + if response.returncode == 1: + logger.info('Nothing to commit.') + return response + else: + logger.error( + 'There are error run git command, you may need to login first.' + ) + raise GitError('stdout: %s, stderr: %s' % + (response.stdout.decode('utf8'), + error.stderr.decode('utf8'))) + + def config_auth_token(self, repo_dir, auth_token): + url = self.get_repo_remote_url(repo_dir) + if '//oauth2' not in url: + auth_url = self._add_token(auth_token, url) + cmd_args = '-C %s remote set-url origin %s' % (repo_dir, auth_url) + cmd_args = cmd_args.split(' ') + rsp = self._run_git_command(*cmd_args) + logger.debug(rsp.stdout.decode('utf8')) + + def _add_token(self, token: str, url: str): + if token: + if '//oauth2' not in url: + url = url.replace('//', '//oauth2:%s@' % token) + return url + + def remove_token_from_url(self, url: str): + if url and '//oauth2' in url: + start_index = url.find('oauth2') + end_index = url.find('@') + url = url[:start_index] + url[end_index + 1:] + return url + + def is_lfs_installed(self): + cmd = ['lfs', 'env'] + try: + self._run_git_command(*cmd) + return True + except GitError: + return False + + def git_lfs_install(self, repo_dir): + cmd = ['git', '-C', repo_dir, 'lfs', 'install'] + try: + self._run_git_command(*cmd) + return True + except GitError: + return False + + def clone(self, + repo_base_dir: str, + token: str, + url: str, + repo_name: str, + branch: str = None): + """ git clone command wrapper. + For public project, token can None, private repo, there must token. + + Args: + repo_base_dir (str): The local base dir, the repository will be clone to local_dir/repo_name + token (str): The git token, must be provided for private project. + url (str): The remote url + repo_name (str): The local repository path name. + branch (str, optional): _description_. Defaults to None. + """ + url = self._add_token(token, url) + if branch: + clone_args = '-C %s clone %s %s --branch %s' % (repo_base_dir, url, + repo_name, branch) + else: + clone_args = '-C %s clone %s' % (repo_base_dir, url) + logger.debug(clone_args) + clone_args = clone_args.split(' ') + response = self._run_git_command(*clone_args) + logger.debug(response.stdout.decode('utf8')) + return response + + def add_user_info(self, repo_base_dir, repo_name): + from modelscope.hub.api import ModelScopeConfig + user_name, user_email = ModelScopeConfig.get_user_info() + if user_name and user_email: + # config user.name and user.email if exist + config_user_name_args = '-C %s/%s config user.name %s' % ( + repo_base_dir, repo_name, user_name) + response = self._run_git_command(*config_user_name_args.split(' ')) + logger.debug(response.stdout.decode('utf8')) + config_user_email_args = '-C %s/%s config user.email %s' % ( + repo_base_dir, repo_name, user_email) + response = self._run_git_command( + *config_user_email_args.split(' ')) + logger.debug(response.stdout.decode('utf8')) + + def add(self, + repo_dir: str, + files: List[str] = list(), + all_files: bool = False): + if all_files: + add_args = '-C %s add -A' % repo_dir + elif len(files) > 0: + files_str = ' '.join(files) + add_args = '-C %s add %s' % (repo_dir, files_str) + add_args = add_args.split(' ') + rsp = self._run_git_command(*add_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp + + def commit(self, repo_dir: str, message: str): + """Run git commit command + + Args: + message (str): commit message. + """ + commit_args = ['-C', '%s' % repo_dir, 'commit', '-m', "'%s'" % message] + rsp = self._run_git_command(*commit_args) + logger.info(rsp.stdout.decode('utf8')) + return rsp + + def checkout(self, repo_dir: str, revision: str): + cmds = ['-C', '%s' % repo_dir, 'checkout', '%s' % revision] + return self._run_git_command(*cmds) + + def new_branch(self, repo_dir: str, revision: str): + cmds = ['-C', '%s' % repo_dir, 'checkout', '-b', revision] + return self._run_git_command(*cmds) + + def get_remote_branches(self, repo_dir: str): + cmds = ['-C', '%s' % repo_dir, 'branch', '-r'] + rsp = self._run_git_command(*cmds) + info = [ + line.strip() + for line in rsp.stdout.decode('utf8').strip().split(os.linesep) + ] + if len(info) == 1: + return ['/'.join(info[0].split('/')[1:])] + else: + return ['/'.join(line.split('/')[1:]) for line in info[1:]] + + def pull(self, repo_dir: str): + cmds = ['-C', repo_dir, 'pull'] + return self._run_git_command(*cmds) + + def push(self, + repo_dir: str, + token: str, + url: str, + local_branch: str, + remote_branch: str, + force: bool = False): + url = self._add_token(token, url) + + push_args = '-C %s push %s %s:%s' % (repo_dir, url, local_branch, + remote_branch) + if force: + push_args += ' -f' + push_args = push_args.split(' ') + rsp = self._run_git_command(*push_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp + + def get_repo_remote_url(self, repo_dir: str): + cmd_args = '-C %s config --get remote.origin.url' % repo_dir + cmd_args = cmd_args.split(' ') + rsp = self._run_git_command(*cmd_args) + url = rsp.stdout.decode('utf8') + return url.strip() + + def list_lfs_files(self, repo_dir: str): + cmd_args = '-C %s lfs ls-files' % repo_dir + cmd_args = cmd_args.split(' ') + rsp = self._run_git_command(*cmd_args) + out = rsp.stdout.decode('utf8').strip() + files = [] + for line in out.split(os.linesep): + files.append(line.split(' ')[-1]) + + return files + + def tag(self, + repo_dir: str, + tag_name: str, + message: str, + ref: str = MASTER_MODEL_BRANCH): + cmd_args = [ + '-C', repo_dir, 'tag', tag_name, '-m', + '"%s"' % message, ref + ] + rsp = self._run_git_command(*cmd_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp + + def push_tag(self, repo_dir: str, tag_name): + cmd_args = ['-C', repo_dir, 'push', 'origin', tag_name] + rsp = self._run_git_command(*cmd_args) + logger.debug(rsp.stdout.decode('utf8')) + return rsp diff --git a/modelscope/hub/repository.py b/modelscope/hub/repository.py new file mode 100644 index 00000000..6b116f79 --- /dev/null +++ b/modelscope/hub/repository.py @@ -0,0 +1,277 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Optional + +from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, + DEFAULT_REPOSITORY_REVISION, + MASTER_MODEL_BRANCH) +from modelscope.utils.logger import get_logger +from .git import GitCommandWrapper +from .utils.utils import get_endpoint + +logger = get_logger() + + +class Repository: + """A local representation of the model git repository. + """ + + def __init__(self, + model_dir: str, + clone_from: str, + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, + auth_token: Optional[str] = None, + git_path: Optional[str] = None): + """ + Instantiate a Repository object by cloning the remote ModelScopeHub repo + Args: + model_dir(`str`): + The model root directory. + clone_from: + model id in ModelScope-hub from which git clone + revision(`Optional[str]`): + revision of the model you want to clone from. Can be any of a branch, tag or commit hash + auth_token(`Optional[str]`): + token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter + as the token is already saved when you login the first time, if None, we will use saved token. + git_path:(`Optional[str]`): + The git command line path, if None, we use 'git' + """ + self.model_dir = model_dir + self.model_base_dir = os.path.dirname(model_dir) + self.model_repo_name = os.path.basename(model_dir) + + if not revision: + err_msg = 'a non-default value of revision cannot be empty.' + raise InvalidParameter(err_msg) + + from modelscope.hub.api import ModelScopeConfig + if auth_token: + self.auth_token = auth_token + else: + self.auth_token = ModelScopeConfig.get_token() + + git_wrapper = GitCommandWrapper() + if not git_wrapper.is_lfs_installed(): + logger.error('git lfs is not installed, please install.') + + self.git_wrapper = GitCommandWrapper(git_path) + os.makedirs(self.model_dir, exist_ok=True) + url = self._get_model_id_url(clone_from) + if os.listdir(self.model_dir): # directory not empty. + remote_url = self._get_remote_url() + remote_url = self.git_wrapper.remove_token_from_url(remote_url) + if remote_url and remote_url == url: # need not clone again + return + self.git_wrapper.clone(self.model_base_dir, self.auth_token, url, + self.model_repo_name, revision) + + if git_wrapper.is_lfs_installed(): + git_wrapper.git_lfs_install(self.model_dir) # init repo lfs + + # add user info if login + self.git_wrapper.add_user_info(self.model_base_dir, + self.model_repo_name) + if self.auth_token: # config remote with auth token + self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) + + def _get_model_id_url(self, model_id): + url = f'{get_endpoint()}/{model_id}.git' + return url + + def _get_remote_url(self): + try: + remote = self.git_wrapper.get_repo_remote_url(self.model_dir) + except GitError: + remote = None + return remote + + def push(self, + commit_message: str, + local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, + remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION, + force: bool = False): + """Push local files to remote, this method will do. + git pull + git add + git commit + git push + Args: + commit_message (str): commit message + branch (Optional[str], optional): which branch to push. + force (Optional[bool]): whether to use forced-push. + """ + if commit_message is None or not isinstance(commit_message, str): + msg = 'commit_message must be provided!' + raise InvalidParameter(msg) + if not isinstance(force, bool): + raise InvalidParameter('force must be bool') + + if not self.auth_token: + raise NotLoginException('Must login to push, please login first.') + + self.git_wrapper.config_auth_token(self.model_dir, self.auth_token) + self.git_wrapper.add_user_info(self.model_base_dir, + self.model_repo_name) + + url = self.git_wrapper.get_repo_remote_url(self.model_dir) + self.git_wrapper.pull(self.model_dir) + + self.git_wrapper.add(self.model_dir, all_files=True) + self.git_wrapper.commit(self.model_dir, commit_message) + self.git_wrapper.push( + repo_dir=self.model_dir, + token=self.auth_token, + url=url, + local_branch=local_branch, + remote_branch=remote_branch) + + def tag(self, tag_name: str, message: str, ref: str = MASTER_MODEL_BRANCH): + """Create a new tag. + Args: + tag_name (str): The name of the tag + message (str): The tag message. + ref (str): The tag reference, can be commit id or branch. + """ + if tag_name is None or tag_name == '': + msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.' + raise InvalidParameter(msg) + if message is None or message == '': + msg = 'We use annotated tag, therefore message cannot None or empty.' + self.git_wrapper.tag( + repo_dir=self.model_dir, + tag_name=tag_name, + message=message, + ref=ref) + + def tag_and_push(self, + tag_name: str, + message: str, + ref: str = MASTER_MODEL_BRANCH): + """Create tag and push to remote + + Args: + tag_name (str): The name of the tag + message (str): The tag message. + ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH. + """ + self.tag(tag_name, message, ref) + + self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name) + + +class DatasetRepository: + """A local representation of the dataset (metadata) git repository. + """ + + def __init__(self, + repo_work_dir: str, + dataset_id: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + auth_token: Optional[str] = None, + git_path: Optional[str] = None): + """ + Instantiate a Dataset Repository object by cloning the remote ModelScope dataset repo + Args: + repo_work_dir(`str`): + The dataset repo root directory. + dataset_id: + dataset id in ModelScope from which git clone + revision(`Optional[str]`): + revision of the dataset you want to clone from. Can be any of a branch, tag or commit hash + auth_token(`Optional[str]`): + token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter + as the token is already saved when you login the first time, if None, we will use saved token. + git_path:(`Optional[str]`): + The git command line path, if None, we use 'git' + """ + self.dataset_id = dataset_id + if not repo_work_dir or not isinstance(repo_work_dir, str): + err_msg = 'dataset_work_dir must be provided!' + raise InvalidParameter(err_msg) + self.repo_work_dir = repo_work_dir.rstrip('/') + if not self.repo_work_dir: + err_msg = 'dataset_work_dir can not be root dir!' + raise InvalidParameter(err_msg) + self.repo_base_dir = os.path.dirname(self.repo_work_dir) + self.repo_name = os.path.basename(self.repo_work_dir) + + if not revision: + err_msg = 'a non-default value of revision cannot be empty.' + raise InvalidParameter(err_msg) + self.revision = revision + from modelscope.hub.api import ModelScopeConfig + if auth_token: + self.auth_token = auth_token + else: + self.auth_token = ModelScopeConfig.get_token() + + self.git_wrapper = GitCommandWrapper(git_path) + os.makedirs(self.repo_work_dir, exist_ok=True) + self.repo_url = self._get_repo_url(dataset_id=dataset_id) + + def clone(self) -> str: + # check local repo dir, directory not empty. + if os.listdir(self.repo_work_dir): + remote_url = self._get_remote_url() + remote_url = self.git_wrapper.remove_token_from_url(remote_url) + # no need clone again + if remote_url and remote_url == self.repo_url: + return '' + + logger.info('Cloning repo from {} '.format(self.repo_url)) + self.git_wrapper.clone(self.repo_base_dir, self.auth_token, + self.repo_url, self.repo_name, self.revision) + return self.repo_work_dir + + def push(self, + commit_message: str, + branch: Optional[str] = DEFAULT_DATASET_REVISION, + force: bool = False): + """Push local files to remote, this method will do. + git pull + git add + git commit + git push + Args: + commit_message (str): commit message + branch (Optional[str], optional): which branch to push. + force (Optional[bool]): whether to use forced-push. + """ + if commit_message is None or not isinstance(commit_message, str): + msg = 'commit_message must be provided!' + raise InvalidParameter(msg) + + if not isinstance(force, bool): + raise InvalidParameter('force must be bool') + + if not self.auth_token: + raise NotLoginException('Must login to push, please login first.') + + self.git_wrapper.config_auth_token(self.repo_work_dir, self.auth_token) + self.git_wrapper.add_user_info(self.repo_base_dir, self.repo_name) + + remote_url = self._get_remote_url() + remote_url = self.git_wrapper.remove_token_from_url(remote_url) + + self.git_wrapper.pull(self.repo_work_dir) + self.git_wrapper.add(self.repo_work_dir, all_files=True) + self.git_wrapper.commit(self.repo_work_dir, commit_message) + self.git_wrapper.push( + repo_dir=self.repo_work_dir, + token=self.auth_token, + url=remote_url, + local_branch=branch, + remote_branch=branch) + + def _get_repo_url(self, dataset_id): + return f'{get_endpoint()}/datasets/{dataset_id}.git' + + def _get_remote_url(self): + try: + remote = self.git_wrapper.get_repo_remote_url(self.repo_work_dir) + except GitError: + remote = None + return remote diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py new file mode 100644 index 00000000..4b81de44 --- /dev/null +++ b/modelscope/hub/snapshot_download.py @@ -0,0 +1,141 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +from http.cookiejar import CookieJar +from pathlib import Path +from typing import Dict, Optional, Union + +from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.logger import get_logger +from .constants import FILE_HASH +from .file_download import get_file_download_url, http_get_file +from .utils.caching import ModelFileSystemCache +from .utils.utils import (file_integrity_validation, get_cache_dir, + model_id_to_group_owner_name) + +logger = get_logger() + + +def snapshot_download(model_id: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cache_dir: Union[str, Path, None] = None, + user_agent: Optional[Union[Dict, str]] = None, + local_files_only: Optional[bool] = False, + cookies: Optional[CookieJar] = None) -> str: + """Download all files of a repo. + Downloads a whole snapshot of a repo's files at the specified revision. This + is useful when you want all files from a repo, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a repo but this would require that the + user always has git and git-lfs installed, and properly configured. + Args: + model_id (`str`): + A user or an organization name and a repo name separated by a `/`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. NOTE: currently only branch and tag name is supported + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + user_agent (`str`, `dict`, *optional*): + The user-agent info in the form of a dictionary or a string. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + Returns: + Local folder path (string) of repo snapshot + + + Raises the following errors: + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + """ + + if cache_dir is None: + cache_dir = get_cache_dir() + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + temporary_cache_dir = os.path.join(cache_dir, 'temp') + os.makedirs(temporary_cache_dir, exist_ok=True) + + group_or_owner, name = model_id_to_group_owner_name(model_id) + + cache = ModelFileSystemCache(cache_dir, group_or_owner, name) + if local_files_only: + if len(cache.cached_files) == 0: + raise ValueError( + 'Cannot find the requested files in the cached path and outgoing' + ' traffic has been disabled. To enable model look-ups and downloads' + " online, set 'local_files_only' to False.") + logger.warn('We can not confirm the cached file is for revision: %s' + % revision) + return cache.get_root_location( + ) # we can not confirm the cached file is for snapshot 'revision' + else: + # make headers + headers = { + 'user-agent': + ModelScopeConfig.get_user_agent(user_agent=user_agent, ) + } + _api = HubApi() + if cookies is None: + cookies = ModelScopeConfig.get_cookies() + revision = _api.get_valid_revision( + model_id, revision=revision, cookies=cookies) + + snapshot_header = headers if 'CI_TEST' in os.environ else { + **headers, + **{ + 'Snapshot': 'True' + } + } + model_files = _api.get_model_files( + model_id=model_id, + revision=revision, + recursive=True, + use_cookies=False if cookies is None else cookies, + headers=snapshot_header, + ) + + with tempfile.TemporaryDirectory( + dir=temporary_cache_dir) as temp_cache_dir: + for model_file in model_files: + if model_file['Type'] == 'tree': + continue + # check model_file is exist in cache, if existed, skip download, otherwise download + if cache.exists(model_file): + file_name = os.path.basename(model_file['Name']) + logger.info( + f'File {file_name} already in cache, skip downloading!' + ) + continue + + # get download url + url = get_file_download_url( + model_id=model_id, + file_path=model_file['Path'], + revision=revision) + + # First download to /tmp + http_get_file( + url=url, + local_dir=temp_cache_dir, + file_name=model_file['Name'], + headers=headers, + cookies=cookies) + # check file integrity + temp_file = os.path.join(temp_cache_dir, model_file['Name']) + if FILE_HASH in model_file: + file_integrity_validation(temp_file, model_file[FILE_HASH]) + # put file to cache + cache.put_file(model_file, temp_file) + + return os.path.join(cache.get_root_location()) diff --git a/modelscope/hub/utils/__init__.py b/modelscope/hub/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py new file mode 100644 index 00000000..1acd2e84 --- /dev/null +++ b/modelscope/hub/utils/caching.py @@ -0,0 +1,298 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import hashlib +import os +import pickle +import tempfile +from shutil import move, rmtree + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class FileSystemCache(object): + KEY_FILE_NAME = '.msc' + """Local file cache. + """ + + def __init__( + self, + cache_root_location: str, + **kwargs, + ): + """ + Parameters + ---------- + cache_location: str + The root location to store files. + """ + os.makedirs(cache_root_location, exist_ok=True) + self.cache_root_location = cache_root_location + self.load_cache() + + def get_root_location(self): + return self.cache_root_location + + def load_cache(self): + """Read set of stored blocks from file + Args: + owner(`str`): individual or group username at modelscope, can be empty for official models + name(`str`): name of the model + Returns: + The model details information. + Raises: + NotExistError: If the model is not exist, will throw NotExistError + TODO: Error based error code. + + model_id = {owner}/{name} + + """ + self.cached_files = [] + cache_keys_file_path = os.path.join(self.cache_root_location, + FileSystemCache.KEY_FILE_NAME) + if os.path.exists(cache_keys_file_path): + with open(cache_keys_file_path, 'rb') as f: + self.cached_files = pickle.load(f) + + def save_cached_files(self): + """Save cache metadata.""" + # save new meta to tmp and move to KEY_FILE_NAME + cache_keys_file_path = os.path.join(self.cache_root_location, + FileSystemCache.KEY_FILE_NAME) + # TODO: Sync file write + fd, fn = tempfile.mkstemp() + with open(fd, 'wb') as f: + pickle.dump(self.cached_files, f) + move(fn, cache_keys_file_path) + + def get_file(self, key): + """Check the key is in the cache, if exist, return the file, otherwise return None. + Args: + key(`str`): The cache key. + Returns: + If file exist, return the cached file location, otherwise None. + Raises: + None + + model_id = {owner}/{name} + + """ + pass + + def put_file(self, key, location): + """Put file to the cache, + Args: + key(`str`): The cache key + location(`str`): Location of the file, we will move the file to cache. + Returns: + The cached file path of the file. + Raises: + None + + model_id = {owner}/{name} + + """ + pass + + def remove_key(self, key): + """Remove cache key in index, The file is removed manually + + Args: + key (dict): The cache key. + """ + if key in self.cached_files: + self.cached_files.remove(key) + self.save_cached_files() + + def exists(self, key): + for cache_file in self.cached_files: + if cache_file == key: + return True + + return False + + def clear_cache(self): + """Remove all files and metadat from the cache + + In the case of multiple cache locations, this clears only the last one, + which is assumed to be the read/write one. + """ + rmtree(self.cache_root_location) + self.load_cache() + + def hash_name(self, key): + return hashlib.sha256(key.encode()).hexdigest() + + +class ModelFileSystemCache(FileSystemCache): + """Local cache file layout + cache_root/owner/model_name/|individual cached files + |.mk: file, The cache index file + Save only one version for each file. + """ + + def __init__(self, cache_root, owner, name): + """Put file to the cache + Args: + cache_root(`str`): The modelscope local cache root(default: ~/.modelscope/cache/models/) + owner(`str`): The model owner. + name('str'): The name of the model + branch('str'): The branch of model + tag('str'): The tag of model + Returns: + Raises: + None + + model_id = {owner}/{name} + + """ + super().__init__(os.path.join(cache_root, owner, name)) + + def get_file_by_path(self, file_path): + """Retrieve the cache if there is file match the path. + Args: + file_path (str): The file path in the model. + Returns: + path: the full path of the file. + """ + for cached_file in self.cached_files: + if file_path == cached_file['Path']: + cached_file_path = os.path.join(self.cache_root_location, + cached_file['Path']) + if os.path.exists(cached_file_path): + return cached_file_path + else: + self.remove_key(cached_file) + + return None + + def get_file_by_path_and_commit_id(self, file_path, commit_id): + """Retrieve the cache if there is file match the path. + Args: + file_path (str): The file path in the model. + commit_id (str): The commit id of the file + Returns: + path: the full path of the file. + """ + for cached_file in self.cached_files: + if file_path == cached_file['Path'] and \ + (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): + cached_file_path = os.path.join(self.cache_root_location, + cached_file['Path']) + if os.path.exists(cached_file_path): + return cached_file_path + else: + self.remove_key(cached_file) + + return None + + def get_file_by_info(self, model_file_info): + """Check if exist cache file. + + Args: + model_file_info (ModelFileInfo): The file information of the file. + + Returns: + _type_: _description_ + """ + cache_key = self.__get_cache_key(model_file_info) + for cached_file in self.cached_files: + if cached_file == cache_key: + orig_path = os.path.join(self.cache_root_location, + cached_file['Path']) + if os.path.exists(orig_path): + return orig_path + else: + self.remove_key(cached_file) + break + + return None + + def __get_cache_key(self, model_file_info): + cache_key = { + 'Path': model_file_info['Path'], + 'Revision': model_file_info['Revision'], # commit id + } + return cache_key + + def exists(self, model_file_info): + """Check the file is cached or not. + + Args: + model_file_info (CachedFileInfo): The cached file info + + Returns: + bool: If exists return True otherwise False + """ + key = self.__get_cache_key(model_file_info) + is_exists = False + for cached_key in self.cached_files: + if cached_key['Path'] == key['Path'] and ( + cached_key['Revision'].startswith(key['Revision']) + or key['Revision'].startswith(cached_key['Revision'])): + is_exists = True + break + file_path = os.path.join(self.cache_root_location, + model_file_info['Path']) + if is_exists: + if os.path.exists(file_path): + return True + else: + self.remove_key( + model_file_info) # sameone may manual delete the file + return False + + def remove_if_exists(self, model_file_info): + """We in cache, remove it. + + Args: + model_file_info (ModelFileInfo): The model file information from server. + """ + for cached_file in self.cached_files: + if cached_file['Path'] == model_file_info['Path']: + self.remove_key(cached_file) + file_path = os.path.join(self.cache_root_location, + cached_file['Path']) + if os.path.exists(file_path): + os.remove(file_path) + break + + def put_file(self, model_file_info, model_file_location): + """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. + + Args: + model_file_info (str): The file description returned by get_model_files + sample: + { + "CommitMessage": "add model\n", + "CommittedDate": 1654857567, + "CommitterName": "mulin.lyh", + "IsLFS": false, + "Mode": "100644", + "Name": "resnet18.pth", + "Path": "resnet18.pth", + "Revision": "09b68012b27de0048ba74003690a890af7aff192", + "Size": 46827520, + "Type": "blob" + } + model_file_location (str): The location of the temporary file. + Raises: + NotImplementedError: _description_ + + Returns: + str: The location of the cached file. + """ + self.remove_if_exists(model_file_info) # backup old revision + cache_key = self.__get_cache_key(model_file_info) + cache_full_path = os.path.join( + self.cache_root_location, + cache_key['Path']) # Branch and Tag do not have same name. + cache_file_dir = os.path.dirname(cache_full_path) + if not os.path.exists(cache_file_dir): + os.makedirs(cache_file_dir, exist_ok=True) + # We can't make operation transaction + move(model_file_location, cache_full_path) + self.cached_files.append(cache_key) + self.save_cached_files() + return cache_full_path diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py new file mode 100644 index 00000000..61d560fa --- /dev/null +++ b/modelscope/hub/utils/utils.py @@ -0,0 +1,102 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import hashlib +import os +from datetime import datetime +from typing import Optional + +import requests + +from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, + DEFAULT_MODELSCOPE_GROUP, + MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, + MODELSCOPE_URL_SCHEME) +from modelscope.hub.errors import FileIntegrityError +from modelscope.utils.file_utils import get_default_cache_dir +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def model_id_to_group_owner_name(model_id): + if MODEL_ID_SEPARATOR in model_id: + group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0] + name = model_id.split(MODEL_ID_SEPARATOR)[1] + else: + group_or_owner = DEFAULT_MODELSCOPE_GROUP + name = model_id + return group_or_owner, name + + +def get_cache_dir(model_id: Optional[str] = None): + """ + cache dir precedence: + function parameter > enviroment > ~/.cache/modelscope/hub + """ + default_cache_dir = get_default_cache_dir() + base_path = os.getenv('MODELSCOPE_CACHE', + os.path.join(default_cache_dir, 'hub')) + return base_path if model_id is None else os.path.join( + base_path, model_id + '/') + + +def get_release_datetime(): + if MODELSCOPE_SDK_DEBUG in os.environ: + rt = int(round(datetime.now().timestamp())) + else: + from modelscope import version + rt = int( + round( + datetime.strptime(version.__release_datetime__, + '%Y-%m-%d %H:%M:%S').timestamp())) + return rt + + +def get_endpoint(): + modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', + DEFAULT_MODELSCOPE_DOMAIN) + return MODELSCOPE_URL_SCHEME + modelscope_domain + + +def compute_hash(file_path): + BUFFER_SIZE = 1024 * 64 # 64k buffer size + sha256_hash = hashlib.sha256() + with open(file_path, 'rb') as f: + while True: + data = f.read(BUFFER_SIZE) + if not data: + break + sha256_hash.update(data) + return sha256_hash.hexdigest() + + +def file_integrity_validation(file_path, expected_sha256): + """Validate the file hash is expected, if not, delete the file + + Args: + file_path (str): The file to validate + expected_sha256 (str): The expected sha256 hash + + Raises: + FileIntegrityError: If file_path hash is not expected. + + """ + file_sha256 = compute_hash(file_path) + if not file_sha256 == expected_sha256: + os.remove(file_path) + msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path + logger.error(msg) + raise FileIntegrityError(msg) + + +def create_library_statistics(method: str, name: str, cn_name: Optional[str]): + try: + from modelscope.hub.api import ModelScopeConfig + path = f'{get_endpoint()}/api/v1/statistics/library' + headers = {'user-agent': ModelScopeConfig.get_user_agent()} + params = {'Method': method, 'Name': name, 'CnName': cn_name} + r = requests.post(path, params=params, headers=headers) + r.raise_for_status() + except Exception: + pass + return diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py new file mode 100644 index 00000000..c7c3e729 --- /dev/null +++ b/modelscope/metainfo.py @@ -0,0 +1,505 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + +class Models(object): + """ Names for different models. + + Holds the standard model name to use for identifying different model. + This should be used to register models. + + Model name should only contain model info but not task info. + """ + # tinynas models + tinynas_detection = 'tinynas-detection' + tinynas_damoyolo = 'tinynas-damoyolo' + + # vision models + detection = 'detection' + realtime_object_detection = 'realtime-object-detection' + realtime_video_object_detection = 'realtime-video-object-detection' + scrfd = 'scrfd' + classification_model = 'ClassificationModel' + nafnet = 'nafnet' + csrnet = 'csrnet' + cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' + gpen = 'gpen' + product_retrieval_embedding = 'product-retrieval-embedding' + body_2d_keypoints = 'body-2d-keypoints' + body_3d_keypoints = 'body-3d-keypoints' + crowd_counting = 'HRNetCrowdCounting' + face_2d_keypoints = 'face-2d-keypoints' + panoptic_segmentation = 'swinL-panoptic-segmentation' + image_reid_person = 'passvitb' + image_inpainting = 'FFTInpainting' + video_summarization = 'pgl-video-summarization' + swinL_semantic_segmentation = 'swinL-semantic-segmentation' + vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' + text_driven_segmentation = 'text-driven-segmentation' + resnet50_bert = 'resnet50-bert' + referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' + fer = 'fer' + retinaface = 'retinaface' + shop_segmentation = 'shop-segmentation' + mogface = 'mogface' + mtcnn = 'mtcnn' + ulfd = 'ulfd' + video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'human-wholebody-keypoint' + hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' + image_body_reshaping = 'image-body-reshaping' + + # EasyCV models + yolox = 'YOLOX' + segformer = 'Segformer' + hand_2d_keypoints = 'HRNet-Hand2D-Keypoints' + image_object_detection_auto = 'image-object-detection-auto' + + # nlp models + bert = 'bert' + palm = 'palm-v2' + structbert = 'structbert' + deberta_v2 = 'deberta_v2' + veco = 'veco' + translation = 'csanmt-translation' + space_dst = 'space-dst' + space_intent = 'space-intent' + space_modeling = 'space-modeling' + space_T_en = 'space-T-en' + space_T_cn = 'space-T-cn' + tcrf = 'transformer-crf' + tcrf_wseg = 'transformer-crf-for-word-segmentation' + transformer_softmax = 'transformer-softmax' + lcrf = 'lstm-crf' + lcrf_wseg = 'lstm-crf-for-word-segmentation' + gcnncrf = 'gcnn-crf' + bart = 'bart' + gpt3 = 'gpt3' + gpt_neo = 'gpt-neo' + plug = 'plug' + bert_for_ds = 'bert-for-document-segmentation' + ponet = 'ponet' + T5 = 'T5' + mglm = 'mglm' + bloom = 'bloom' + + # audio models + sambert_hifigan = 'sambert-hifigan' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + kws_kwsbp = 'kws-kwsbp' + generic_asr = 'generic-asr' + + # multi-modal models + ofa = 'ofa' + clip = 'clip-multi-modal-embedding' + gemm = 'gemm-generative-multi-modal' + mplug = 'mplug' + diffusion = 'diffusion-text-to-image-synthesis' + multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' + team = 'team-multi-modal-similarity' + video_clip = 'video-clip-multi-modal-embedding' + + # science models + unifold = 'unifold' + unifold_symmetry = 'unifold-symmetry' + + +class TaskModels(object): + # nlp task + text_classification = 'text-classification' + token_classification = 'token-classification' + information_extraction = 'information-extraction' + fill_mask = 'fill-mask' + feature_extraction = 'feature-extraction' + text_generation = 'text-generation' + + +class Heads(object): + # nlp heads + + # text cls + text_classification = 'text-classification' + # fill mask + fill_mask = 'fill-mask' + bert_mlm = 'bert-mlm' + roberta_mlm = 'roberta-mlm' + # token cls + token_classification = 'token-classification' + # extraction + information_extraction = 'information-extraction' + # text gen + text_generation = 'text-generation' + + +class Pipelines(object): + """ Names for different pipelines. + + Holds the standard pipline name to use for identifying different pipeline. + This should be used to register pipelines. + + For pipeline which support different models and implements the common function, we + should use task name for this pipeline. + For pipeline which suuport only one model, we should use ${Model}-${Task} as its name. + """ + # vision tasks + portrait_matting = 'unet-image-matting' + image_denoise = 'nafnet-image-denoise' + person_image_cartoon = 'unet-person-image-cartoon' + ocr_detection = 'resnet18-ocr-detection' + action_recognition = 'TAdaConv_action-recognition' + animal_recognition = 'resnet101-animal-recognition' + general_recognition = 'resnet101-general-recognition' + cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + hicossl_video_embedding = 'hicossl-s3dg-video_embedding' + body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' + body_3d_keypoints = 'canonical_body-3d-keypoints_video' + hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image' + human_detection = 'resnet18-human-detection' + object_detection = 'vit-object-detection' + easycv_detection = 'easycv-detection' + easycv_segmentation = 'easycv-segmentation' + face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment' + salient_detection = 'u2net-salient-detection' + image_classification = 'image-classification' + face_detection = 'resnet-face-detection-scrfd10gkps' + card_detection = 'resnet-card-detection-scrfd34gkps' + ulfd_face_detection = 'manual-face-detection-ulfd' + facial_expression_recognition = 'vgg19-facial-expression-recognition-fer' + retina_face_detection = 'resnet50-face-detection-retinaface' + mog_face_detection = 'resnet101-face-detection-cvpr22papermogface' + mtcnn_face_detection = 'manual-face-detection-mtcnn' + live_category = 'live-category' + general_image_classification = 'vit-base_image-classification_ImageNet-labels' + daily_image_classification = 'vit-base_image-classification_Dailylife-labels' + image_color_enhance = 'csrnet-image-color-enhance' + virtual_try_on = 'virtual-try-on' + image_colorization = 'unet-image-colorization' + image_style_transfer = 'AAMS-style-transfer' + image_super_resolution = 'rrdb-image-super-resolution' + face_image_generation = 'gan-face-image-generation' + product_retrieval_embedding = 'resnet50-product-retrieval-embedding' + realtime_object_detection = 'cspnet_realtime-object-detection_yolox' + realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo' + face_recognition = 'ir101-face-recognition-cfglint' + image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation' + image2image_translation = 'image-to-image-translation' + live_category = 'live-category' + video_category = 'video-category' + ocr_recognition = 'convnextTiny-ocr-recognition' + image_portrait_enhancement = 'gpen-image-portrait-enhancement' + image_to_image_generation = 'image-to-image-generation' + image_object_detection_auto = 'yolox_image-object-detection-auto' + skin_retouching = 'unet-skin-retouching' + tinynas_classification = 'tinynas-classification' + tinynas_detection = 'tinynas-detection' + crowd_counting = 'hrnet-crowd-counting' + action_detection = 'ResNetC3D-action-detection' + video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' + image_panoptic_segmentation = 'image-panoptic-segmentation' + video_summarization = 'googlenet_pgl_video_summarization' + image_semantic_segmentation = 'image-semantic-segmentation' + image_reid_person = 'passvitb-image-reid-person' + image_inpainting = 'fft-inpainting' + text_driven_segmentation = 'text-driven-segmentation' + movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' + shop_segmentation = 'shop-segmentation' + video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' + pst_action_recognition = 'patchshift-action-recognition' + hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' + image_body_reshaping = 'flow-based-body-reshaping' + referring_video_object_segmentation = 'referring-video-object-segmentation' + + # nlp tasks + automatic_post_editing = 'automatic-post-editing' + translation_quality_estimation = 'translation-quality-estimation' + domain_classification = 'domain-classification' + sentence_similarity = 'sentence-similarity' + word_segmentation = 'word-segmentation' + multilingual_word_segmentation = 'multilingual-word-segmentation' + word_segmentation_thai = 'word-segmentation-thai' + part_of_speech = 'part-of-speech' + named_entity_recognition = 'named-entity-recognition' + named_entity_recognition_thai = 'named-entity-recognition-thai' + named_entity_recognition_viet = 'named-entity-recognition-viet' + text_generation = 'text-generation' + text2text_generation = 'text2text-generation' + sentiment_analysis = 'sentiment-analysis' + sentiment_classification = 'sentiment-classification' + text_classification = 'text-classification' + fill_mask = 'fill-mask' + fill_mask_ponet = 'fill-mask-ponet' + csanmt_translation = 'csanmt-translation' + nli = 'nli' + dialog_intent_prediction = 'dialog-intent-prediction' + dialog_modeling = 'dialog-modeling' + dialog_state_tracking = 'dialog-state-tracking' + zero_shot_classification = 'zero-shot-classification' + text_error_correction = 'text-error-correction' + plug_generation = 'plug-generation' + gpt3_generation = 'gpt3-generation' + faq_question_answering = 'faq-question-answering' + conversational_text_to_sql = 'conversational-text-to-sql' + table_question_answering_pipeline = 'table-question-answering-pipeline' + sentence_embedding = 'sentence-embedding' + text_ranking = 'text-ranking' + relation_extraction = 'relation-extraction' + document_segmentation = 'document-segmentation' + feature_extraction = 'feature-extraction' + mglm_text_summarization = 'mglm-text-summarization' + translation_en_to_de = 'translation_en_to_de' # keep it underscore + translation_en_to_ro = 'translation_en_to_ro' # keep it underscore + translation_en_to_fr = 'translation_en_to_fr' # keep it underscore + token_classification = 'token-classification' + + # audio tasks + sambert_hifigan_tts = 'sambert-hifigan-tts' + speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + kws_kwsbp = 'kws-kwsbp' + asr_inference = 'asr-inference' + + # multi-modal tasks + image_captioning = 'image-captioning' + multi_modal_embedding = 'multi-modal-embedding' + generative_multi_modal_embedding = 'generative-multi-modal-embedding' + visual_question_answering = 'visual-question-answering' + visual_grounding = 'visual-grounding' + visual_entailment = 'visual-entailment' + multi_modal_similarity = 'multi-modal-similarity' + text_to_image_synthesis = 'text-to-image-synthesis' + video_multi_modal_embedding = 'video-multi-modal-embedding' + image_text_retrieval = 'image-text-retrieval' + ofa_ocr_recognition = 'ofa-ocr-recognition' + + # science tasks + protein_structure = 'unifold-protein-structure' + + +class Trainers(object): + """ Names for different trainer. + + Holds the standard trainer name to use for identifying different trainer. + This should be used to register trainers. + + For a general Trainer, you can use EpochBasedTrainer. + For a model specific Trainer, you can use ${ModelName}-${Task}-trainer. + """ + + default = 'trainer' + easycv = 'easycv' + + # multi-modal trainers + clip_multi_modal_embedding = 'clip-multi-modal-embedding' + ofa = 'ofa' + mplug = 'mplug' + + # cv trainers + image_instance_segmentation = 'image-instance-segmentation' + image_portrait_enhancement = 'image-portrait-enhancement' + video_summarization = 'video-summarization' + movie_scene_segmentation = 'movie-scene-segmentation' + face_detection_scrfd = 'face-detection-scrfd' + card_detection_scrfd = 'card-detection-scrfd' + image_inpainting = 'image-inpainting' + referring_video_object_segmentation = 'referring-video-object-segmentation' + image_classification_team = 'image-classification-team' + + # nlp trainers + bert_sentiment_analysis = 'bert-sentiment-analysis' + dialog_modeling_trainer = 'dialog-modeling-trainer' + dialog_intent_trainer = 'dialog-intent-trainer' + nlp_base_trainer = 'nlp-base-trainer' + nlp_veco_trainer = 'nlp-veco-trainer' + nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' + text_generation_trainer = 'text-generation-trainer' + + # audio trainers + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + + +class Preprocessors(object): + """ Names for different preprocessor. + + Holds the standard preprocessor name to use for identifying different preprocessor. + This should be used to register preprocessors. + + For a general preprocessor, just use the function name as preprocessor name such as + resize-image, random-crop + For a model-specific preprocessor, use ${modelname}-${fuction} + """ + + # cv preprocessor + load_image = 'load-image' + image_denoie_preprocessor = 'image-denoise-preprocessor' + image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' + image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' + image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' + video_summarization_preprocessor = 'video-summarization-preprocessor' + movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' + + # nlp preprocessor + sen_sim_tokenizer = 'sen-sim-tokenizer' + cross_encoder_tokenizer = 'cross-encoder-tokenizer' + bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' + text_gen_tokenizer = 'text-gen-tokenizer' + text2text_gen_preprocessor = 'text2text-gen-preprocessor' + text_gen_jieba_tokenizer = 'text-gen-jieba-tokenizer' + text2text_translate_preprocessor = 'text2text-translate-preprocessor' + token_cls_tokenizer = 'token-cls-tokenizer' + ner_tokenizer = 'ner-tokenizer' + thai_ner_tokenizer = 'thai-ner-tokenizer' + viet_ner_tokenizer = 'viet-ner-tokenizer' + nli_tokenizer = 'nli-tokenizer' + sen_cls_tokenizer = 'sen-cls-tokenizer' + dialog_intent_preprocessor = 'dialog-intent-preprocessor' + dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' + dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor' + sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' + zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' + text_error_correction = 'text-error-correction' + sentence_embedding = 'sentence-embedding' + text_ranking = 'text-ranking' + sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' + word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' + thai_wseg_tokenizer = 'thai-wseg-tokenizer' + fill_mask = 'fill-mask' + fill_mask_ponet = 'fill-mask-ponet' + faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' + conversational_text_to_sql = 'conversational-text-to-sql' + table_question_answering_preprocessor = 'table-question-answering-preprocessor' + re_tokenizer = 're-tokenizer' + document_segmentation = 'document-segmentation' + feature_extraction = 'feature-extraction' + mglm_summarization = 'mglm-summarization' + sentence_piece = 'sentence-piece' + + # audio preprocessor + linear_aec_fbank = 'linear-aec-fbank' + text_to_tacotron_symbols = 'text-to-tacotron-symbols' + wav_to_lists = 'wav-to-lists' + wav_to_scp = 'wav-to-scp' + + # multi-modal preprocessor + ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' + clip_preprocessor = 'clip-preprocessor' + mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' + + # science preprocessor + unifold_preprocessor = 'unifold-preprocessor' + + +class Metrics(object): + """ Names for different metrics. + """ + + # accuracy + accuracy = 'accuracy' + multi_average_precision = 'mAP' + audio_noise_metric = 'audio-noise-metric' + + # text gen + BLEU = 'bleu' + + # metrics for image denoise task + image_denoise_metric = 'image-denoise-metric' + + # metric for image instance segmentation task + image_ins_seg_coco_metric = 'image-ins-seg-coco-metric' + # metrics for sequence classification task + seq_cls_metric = 'seq-cls-metric' + # metrics for token-classification task + token_cls_metric = 'token-cls-metric' + # metrics for text-generation task + text_gen_metric = 'text-gen-metric' + # metrics for image-color-enhance task + image_color_enhance_metric = 'image-color-enhance-metric' + # metrics for image-portrait-enhancement task + image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' + video_summarization_metric = 'video-summarization-metric' + # metric for movie-scene-segmentation task + movie_scene_segmentation_metric = 'movie-scene-segmentation-metric' + # metric for inpainting task + image_inpainting_metric = 'image-inpainting-metric' + # metric for ocr + NED = 'ned' + # metric for cross-modal retrieval + inbatch_recall = 'inbatch_recall' + # metric for referring-video-object-segmentation task + referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric' + + +class Optimizers(object): + """ Names for different OPTIMIZER. + + Holds the standard optimizer name to use for identifying different optimizer. + This should be used to register optimizer. + """ + + default = 'optimizer' + + SGD = 'SGD' + + +class Hooks(object): + """ Names for different hooks. + + All kinds of hooks are defined here + """ + # lr + LrSchedulerHook = 'LrSchedulerHook' + PlateauLrSchedulerHook = 'PlateauLrSchedulerHook' + NoneLrSchedulerHook = 'NoneLrSchedulerHook' + + # optimizer + OptimizerHook = 'OptimizerHook' + TorchAMPOptimizerHook = 'TorchAMPOptimizerHook' + ApexAMPOptimizerHook = 'ApexAMPOptimizerHook' + NoneOptimizerHook = 'NoneOptimizerHook' + + # checkpoint + CheckpointHook = 'CheckpointHook' + BestCkptSaverHook = 'BestCkptSaverHook' + + # logger + TextLoggerHook = 'TextLoggerHook' + TensorboardHook = 'TensorboardHook' + + IterTimerHook = 'IterTimerHook' + EvaluationHook = 'EvaluationHook' + + # Compression + SparsityHook = 'SparsityHook' + + # CLIP logit_scale clamp + ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' + + +class LR_Schedulers(object): + """learning rate scheduler is defined here + + """ + LinearWarmup = 'LinearWarmup' + ConstantWarmup = 'ConstantWarmup' + ExponentialWarmup = 'ExponentialWarmup' + + +class Datasets(object): + """ Names for different datasets. + """ + ClsDataset = 'ClsDataset' + Face2dKeypointsDataset = 'FaceKeypointDataset' + HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' + HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' + SegDataset = 'SegDataset' + DetDataset = 'DetDataset' + DetImagesMixDataset = 'DetImagesMixDataset' + PairedDataset = 'PairedDataset' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py new file mode 100644 index 00000000..f106f054 --- /dev/null +++ b/modelscope/metrics/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .audio_noise_metric import AudioNoiseMetric + from .base import Metric + from .builder import METRICS, build_metric, task_default_metrics + from .image_color_enhance_metric import ImageColorEnhanceMetric + from .image_denoise_metric import ImageDenoiseMetric + from .image_instance_segmentation_metric import \ + ImageInstanceSegmentationCOCOMetric + from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric + from .sequence_classification_metric import SequenceClassificationMetric + from .text_generation_metric import TextGenerationMetric + from .token_classification_metric import TokenClassificationMetric + from .video_summarization_metric import VideoSummarizationMetric + from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric + from .accuracy_metric import AccuracyMetric + from .bleu_metric import BleuMetric + from .image_inpainting_metric import ImageInpaintingMetric + from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric + +else: + _import_structure = { + 'audio_noise_metric': ['AudioNoiseMetric'], + 'base': ['Metric'], + 'builder': ['METRICS', 'build_metric', 'task_default_metrics'], + 'image_color_enhance_metric': ['ImageColorEnhanceMetric'], + 'image_denoise_metric': ['ImageDenoiseMetric'], + 'image_instance_segmentation_metric': + ['ImageInstanceSegmentationCOCOMetric'], + 'image_portrait_enhancement_metric': + ['ImagePortraitEnhancementMetric'], + 'sequence_classification_metric': ['SequenceClassificationMetric'], + 'text_generation_metric': ['TextGenerationMetric'], + 'token_classification_metric': ['TokenClassificationMetric'], + 'video_summarization_metric': ['VideoSummarizationMetric'], + 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], + 'image_inpainting_metric': ['ImageInpaintingMetric'], + 'accuracy_metric': ['AccuracyMetric'], + 'bleu_metric': ['BleuMetric'], + 'referring_video_object_segmentation_metric': + ['ReferringVideoObjectSegmentationMetric'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/metrics/accuracy_metric.py b/modelscope/metrics/accuracy_metric.py new file mode 100644 index 00000000..fe040177 --- /dev/null +++ b/modelscope/metrics/accuracy_metric.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.chinese_utils import remove_space_between_chinese_chars +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) +class AccuracyMetric(Metric): + """The metric computation class for classification classes. + + This metric class calculates accuracy for the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = None + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key in outputs and outputs[key] is not None: + eval_results = outputs[key] + break + assert type(ground_truths) == type(eval_results) + for truth in ground_truths: + self.labels.append(truth) + for result in eval_results: + if isinstance(truth, str): + self.preds.append(remove_space_between_chinese_chars(result)) + else: + self.preds.append(result) + + def evaluate(self): + assert len(self.preds) == len(self.labels) + return { + MetricKeys.ACCURACY: (np.asarray([ + pred == ref for pred, ref in zip(self.preds, self.labels) + ])).mean().item() + } diff --git a/modelscope/metrics/audio_noise_metric.py b/modelscope/metrics/audio_noise_metric.py new file mode 100644 index 00000000..8555e95b --- /dev/null +++ b/modelscope/metrics/audio_noise_metric.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +from modelscope.metainfo import Metrics +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.utils.registry import default_group + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.audio_noise_metric) +class AudioNoiseMetric(Metric): + """ + The metric computation class for acoustic noise suppression task. + """ + + def __init__(self): + self.loss = [] + self.amp_loss = [] + self.phase_loss = [] + self.sisnr = [] + + def add(self, outputs: Dict, inputs: Dict): + self.loss.append(outputs['loss'].data.cpu()) + self.amp_loss.append(outputs['amp_loss'].data.cpu()) + self.phase_loss.append(outputs['phase_loss'].data.cpu()) + self.sisnr.append(outputs['sisnr'].data.cpu()) + + def evaluate(self): + avg_loss = sum(self.loss) / len(self.loss) + avg_sisnr = sum(self.sisnr) / len(self.sisnr) + avg_amp = sum(self.amp_loss) / len(self.amp_loss) + avg_phase = sum(self.phase_loss) / len(self.phase_loss) + total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr + return { + 'total_loss': total_loss.item(), + # model use opposite number of sisnr as a calculation shortcut. + # revert it in evaluation result + 'avg_sisnr': -avg_sisnr.item(), + MetricKeys.AVERAGE_LOSS: avg_loss.item() + } diff --git a/modelscope/metrics/base.py b/modelscope/metrics/base.py new file mode 100644 index 00000000..955946b5 --- /dev/null +++ b/modelscope/metrics/base.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import Dict + + +class Metric(ABC): + """The metric base class for computing metrics. + + The subclasses can either compute a single metric like 'accuracy', or compute the + complex metrics for a specific task with or without other Metric subclasses. + """ + + def __init__(self, *args, **kwargs): + pass + + @abstractmethod + def add(self, outputs: Dict, inputs: Dict): + """ Append logits and labels within an eval loop. + + Will be called after every batch finished to gather the model predictions and the labels. + + Args: + outputs: The model prediction outputs. + inputs: The mini batch inputs from the dataloader. + + Returns: None + + """ + pass + + @abstractmethod + def evaluate(self): + """Evaluate the metrics after the eval finished. + + Will be called after the whole validation finished. + + Returns: The actual metric dict with standard names. + + """ + pass diff --git a/modelscope/metrics/bleu_metric.py b/modelscope/metrics/bleu_metric.py new file mode 100644 index 00000000..7c134b6a --- /dev/null +++ b/modelscope/metrics/bleu_metric.py @@ -0,0 +1,42 @@ +from itertools import zip_longest +from typing import Dict + +import sacrebleu + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + +EVAL_BLEU_ORDER = 4 + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU) +class BleuMetric(Metric): + """The metric computation bleu for text generation classes. + + This metric class calculates accuracy for the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False) + self.hyp_name = kwargs.get('hyp_name', 'hyp') + self.ref_name = kwargs.get('ref_name', 'ref') + self.refs = list() + self.hyps = list() + + def add(self, outputs: Dict, inputs: Dict): + self.refs.extend(inputs[self.ref_name]) + self.hyps.extend(outputs[self.hyp_name]) + + def evaluate(self): + if self.eval_tokenized_bleu: + bleu = sacrebleu.corpus_bleu( + self.hyps, list(zip_longest(*self.refs)), tokenize='none') + else: + bleu = sacrebleu.corpus_bleu(self.hyps, + list(zip_longest(*self.refs))) + return { + MetricKeys.BLEU_4: bleu.score, + } diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py new file mode 100644 index 00000000..03d4c324 --- /dev/null +++ b/modelscope/metrics/builder.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict, Mapping, Union + +from modelscope.metainfo import Metrics +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Tasks +from modelscope.utils.registry import Registry, build_from_cfg, default_group + +METRICS = Registry('metrics') + + +class MetricKeys(object): + ACCURACY = 'accuracy' + F1 = 'f1' + PRECISION = 'precision' + RECALL = 'recall' + PSNR = 'psnr' + SSIM = 'ssim' + AVERAGE_LOSS = 'avg_loss' + FScore = 'fscore' + FID = 'fid' + BLEU_1 = 'bleu-1' + BLEU_4 = 'bleu-4' + ROUGE_1 = 'rouge-1' + ROUGE_L = 'rouge-l' + NED = 'ned' # ocr metric + mAP = 'mAP' + BatchAcc = 'inbatch_t2i_recall_at_1' + + +task_default_metrics = { + Tasks.image_segmentation: [Metrics.image_ins_seg_coco_metric], + Tasks.sentence_similarity: [Metrics.seq_cls_metric], + Tasks.nli: [Metrics.seq_cls_metric], + Tasks.sentiment_classification: [Metrics.seq_cls_metric], + Tasks.token_classification: [Metrics.token_cls_metric], + Tasks.text_generation: [Metrics.text_gen_metric], + Tasks.text_classification: [Metrics.seq_cls_metric], + Tasks.image_denoising: [Metrics.image_denoise_metric], + Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], + Tasks.image_portrait_enhancement: + [Metrics.image_portrait_enhancement_metric], + Tasks.video_summarization: [Metrics.video_summarization_metric], + Tasks.image_captioning: [Metrics.accuracy], + Tasks.visual_question_answering: [Metrics.accuracy], + Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric], + Tasks.image_inpainting: [Metrics.image_inpainting_metric], + Tasks.referring_video_object_segmentation: + [Metrics.referring_video_object_segmentation_metric], +} + + +def build_metric(metric_cfg: Union[str, Dict], + field: str = default_group, + default_args: dict = None): + """ Build metric given metric_name and field. + + Args: + metric_name (str | dict): The metric name or metric config dict. + field (str, optional): The field of this metric, default value: 'default' for all fields. + default_args (dict, optional): Default initialization arguments. + """ + if isinstance(metric_cfg, Mapping): + assert 'type' in metric_cfg + else: + metric_cfg = ConfigDict({'type': metric_cfg}) + return build_from_cfg( + metric_cfg, METRICS, group_key=field, default_args=default_args) diff --git a/modelscope/metrics/ciderD/__init__.py b/modelscope/metrics/ciderD/__init__.py new file mode 100755 index 00000000..3f7d85bb --- /dev/null +++ b/modelscope/metrics/ciderD/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/modelscope/metrics/ciderD/ciderD.py b/modelscope/metrics/ciderD/ciderD.py new file mode 100755 index 00000000..05c7eb23 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD.py @@ -0,0 +1,57 @@ +# Filename: ciderD.py +# +# Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin +from __future__ import absolute_import, division, print_function + +from .ciderD_scorer import CiderScorer + + +class CiderD: + """ + Main Class to compute the CIDEr metric + + """ + + def __init__(self, n=4, sigma=6.0, df='corpus'): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + # set which where to compute document frequencies from + self._df = df + self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df) + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ # noqa + + # clear all the previous hypos and refs + tmp_cider_scorer = self.cider_scorer.copy_empty() + tmp_cider_scorer.clear() + for res_id in res: + + hypo = res_id['caption'] + ref = gts[res_id['image_id']] + + # Sanity check. + assert (type(hypo) is list) + assert (len(hypo) == 1) + assert (type(ref) is list) + assert (len(ref) > 0) + tmp_cider_scorer += (hypo[0], ref) + + (score, scores) = tmp_cider_scorer.compute_score() + + return score, scores + + def method(self): + return 'CIDEr-D' diff --git a/modelscope/metrics/ciderD/ciderD_scorer.py b/modelscope/metrics/ciderD/ciderD_scorer.py new file mode 100755 index 00000000..4157ec11 --- /dev/null +++ b/modelscope/metrics/ciderD/ciderD_scorer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam +from __future__ import absolute_import, division, print_function +import copy +import math +import os +import pdb +from collections import defaultdict + +import numpy as np +import six +from six.moves import cPickle + + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i:i + k]) + counts[ngram] += 1 + return counts + + +def cook_refs(refs, n=4): # lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. + :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + + +class CiderScorer(object): + """CIDEr scorer. + """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def copy_empty(self): + new = CiderScorer(df_mode='corpus', n=self.n, sigma=self.sigma) + new.df_mode = self.df_mode + new.ref_len = self.ref_len + new.document_frequency = self.document_frequency + return new + + def __init__(self, df_mode='corpus', test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.df_mode = df_mode + self.ref_len = None + if self.df_mode != 'corpus': + pkl_file = cPickle.load( + open(df_mode, 'rb'), + **(dict(encoding='latin1') if six.PY3 else {})) + self.ref_len = np.log(float(pkl_file['ref_len'])) + self.document_frequency = pkl_file['document_frequency'] + else: + self.document_frequency = None + self.cook_append(test, refs) + + def clear(self): + self.crefs = [] + self.ctest = [] + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) # N.B.: -1 + else: + self.ctest.append( + None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len( + self.ctest), 'refs/test mismatch! %d<>%d' % (len( + self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + # avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + + def compute_doc_freq(self): + """ + Compute term frequency for reference data. + This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + """ + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ + ngram for ref in refs for (ngram, count) in ref.items() + ]): # noqa + self.document_frequency[ngram] += 1 + + def compute_cider(self): + + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram, term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram) - 1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq) * (self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. + :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram, count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], + vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n] * norm_ref[n]) + + assert (not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2)) + return val + + # compute log reference length + if self.df_mode == 'corpus': + self.ref_len = np.log(float(len(self.crefs))) + # elif self.df_mode == "coco-val-df": + # if coco option selected, use length of coco-val set + # self.ref_len = np.log(float(40504)) + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + if self.df_mode == 'corpus': + self.document_frequency = defaultdict(float) + self.compute_doc_freq() + # assert to check document frequency + assert (len(self.ctest) >= max(self.document_frequency.values())) + # import json for now and write the corresponding files + # compute cider score + score = self.compute_cider() + # debug + # print score + return np.mean(np.array(score)), np.array(score) diff --git a/modelscope/metrics/image_color_enhance_metric.py b/modelscope/metrics/image_color_enhance_metric.py new file mode 100644 index 00000000..b3744975 --- /dev/null +++ b/modelscope/metrics/image_color_enhance_metric.py @@ -0,0 +1,258 @@ +# The code is modified based on BasicSR metrics: +# https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/metrics + +from typing import Dict + +import cv2 +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" + ) + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, + 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) + tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ssim_map = tmp1 / tmp2 + + return ssim_map.mean() + + +def calculate_psnr(img, + img2, + crop_border, + input_order='HWC', + test_y_channel=False, + **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +def calculate_ssim(img, + img2, + crop_border, + input_order='HWC', + test_y_channel=False, + **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_color_enhance_metric) +class ImageColorEnhanceMetric(Metric): + """The metric computation class for image color enhance classes. + """ + + def __init__(self): + self.preds = [] + self.targets = [] + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['target'] + eval_results = outputs['pred'] + self.preds.extend(eval_results) + self.targets.extend(ground_truths) + + def evaluate(self): + psnrs = [ + calculate_psnr(pred, target, 2, test_y_channel=False) + for pred, target in zip(self.preds, self.targets) + ] + ssims = [ + calculate_ssim(pred, target, 2, test_y_channel=False) + for pred, target in zip(self.preds, self.targets) + ] + return { + MetricKeys.PSNR: sum(psnrs) / len(psnrs), + MetricKeys.SSIM: sum(ssims) / len(ssims) + } diff --git a/modelscope/metrics/image_denoise_metric.py b/modelscope/metrics/image_denoise_metric.py new file mode 100644 index 00000000..1692f299 --- /dev/null +++ b/modelscope/metrics/image_denoise_metric.py @@ -0,0 +1,272 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Alibaba, Inc. and its affiliates. +# ------------------------------------------------------------------------ +# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py +# ------------------------------------------------------------------------ +from typing import Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_denoise_metric) +class ImageDenoiseMetric(Metric): + """The metric computation class for image denoise classes. + """ + pred_name = 'pred' + label_name = 'target' + + def __init__(self): + super(ImageDenoiseMetric, self).__init__() + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs[ImageDenoiseMetric.label_name] + eval_results = outputs[ImageDenoiseMetric.pred_name] + self.preds.append(eval_results) + self.labels.append(ground_truths) + + def evaluate(self): + psnr_list, ssim_list = [], [] + for (pred, label) in zip(self.preds, self.labels): + psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) + ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) + return { + MetricKeys.PSNR: np.mean(psnr_list), + MetricKeys.SSIM: np.mean(ssim_list) + } + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" + ) + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def calculate_psnr(img1, img2, crop_border, input_order='HWC'): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Args: + img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + def _psnr(img1, img2): + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + max_value = 1. if img1.max() <= 1 else 255. + return 20. * np.log10(max_value / np.sqrt(mse)) + + return _psnr(img1, img2) + + +def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): + """Calculate SSIM (structural similarity). + Ref: + Image quality assessment: From error visibility to structural similarity + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + For three-channel images, SSIM is calculated for each channel and then + averaged. + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + def _cal_ssim(img1, img2): + ssims = [] + + max_value = 1 if img1.max() <= 1 else 255 + with torch.no_grad(): + final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( + img1, img2, max_value) + ssims.append(final_ssim) + + return np.array(ssims).mean() + + return _cal_ssim(img1, img2) + + +def _ssim(img, img2, max_value): + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + Returns: + float: SSIM result. + """ + + c1 = (0.01 * max_value)**2 + c2 = (0.03 * max_value)**2 + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, + 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) + tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ssim_map = tmp1 / tmp2 + return ssim_map.mean() + + +def _3d_gaussian_calculator(img, conv3d): + out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + return out + + +def _generate_3d_gaussian_kernel(): + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + kernel_3 = cv2.getGaussianKernel(11, 1.5) + kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) + conv3d = torch.nn.Conv3d( + 1, + 1, (11, 11, 11), + stride=1, + padding=(5, 5, 5), + bias=False, + padding_mode='replicate') + conv3d.weight.requires_grad = False + conv3d.weight[0, 0, :, :, :] = kernel + return conv3d + + +def _ssim_3d(img1, img2, max_value): + assert len(img1.shape) == 3 and len(img2.shape) == 3 + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + Returns: + float: ssim result. + """ + C1 = (0.01 * max_value)**2 + C2 = (0.03 * max_value)**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = _generate_3d_gaussian_kernel().cuda() + + img1 = torch.tensor(img1).float().cuda() + img2 = torch.tensor(img2).float().cuda() + + mu1 = _3d_gaussian_calculator(img1, kernel) + mu2 = _3d_gaussian_calculator(img2, kernel) + + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq + sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq + sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) + tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ssim_map = tmp1 / tmp2 + return float(ssim_map.mean()) diff --git a/modelscope/metrics/image_inpainting_metric.py b/modelscope/metrics/image_inpainting_metric.py new file mode 100644 index 00000000..954d4ca2 --- /dev/null +++ b/modelscope/metrics/image_inpainting_metric.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import linalg + +from modelscope.metainfo import Metrics +from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +def fid_calculate_activation_statistics(act): + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): + mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) + mu2, sigma2 = fid_calculate_activation_statistics(activations_target) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) + - 2 * tr_covmean) + + +class FIDScore(torch.nn.Module): + + def __init__(self, dims=2048, eps=1e-6): + super().__init__() + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + + self.activations_pred.append(activations_pred.detach().cpu()) + self.activations_target.append(activations_target.detach().cpu()) + + def get_value(self): + activations_pred, activations_target = (self.activations_pred, + self.activations_target) + activations_pred = torch.cat(activations_pred).cpu().numpy() + activations_target = torch.cat(activations_target).cpu().numpy() + + total_distance = calculate_frechet_distance( + activations_pred, activations_target, eps=self.eps) + + self.reset() + return total_distance + + def reset(self): + self.activations_pred = [] + self.activations_target = [] + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + assert False, \ + 'We should not have got here, because Inception always scales inputs to 299x299' + activations = activations.squeeze(-1).squeeze(-1) + return activations + + +class SSIM(torch.nn.Module): + """SSIM. Modified from: + https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + """ + + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.register_buffer('window', + self._create_window(window_size, self.channel)) + + def forward(self, img1, img2): + assert len(img1.shape) == 4 + + channel = img1.size()[1] + + if channel == self.channel and self.window.data.type( + ) == img1.data.type(): + window = self.window + else: + window = self._create_window(self.window_size, channel) + + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, + self.size_average) + + def _gaussian(self, window_size, sigma): + gauss = torch.Tensor([ + np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + def _create_window(self, window_size, channel): + _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm( + _1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(channel, 1, window_size, + window_size).contiguous() + + def _ssim(self, + img1, + img2, + window, + window_size, + channel, + size_average=True): + mu1 = F.conv2d( + img1, window, padding=(window_size // 2), groups=channel) + mu2 = F.conv2d( + img2, window, padding=(window_size // 2), groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=(window_size // 2), + groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=(window_size // 2), + groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=(window_size // 2), + groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + + return ssim_map.mean(1).mean(1).mean(1) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + return + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_inpainting_metric) +class ImageInpaintingMetric(Metric): + """The metric computation class for image inpainting classes. + """ + + def __init__(self): + self.preds = [] + self.targets = [] + self.SSIM = SSIM(window_size=11, size_average=False).eval() + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.FID = FIDScore().to(device) + + def add(self, outputs: Dict, inputs: Dict): + pred = outputs['inpainted'] + target = inputs['image'] + self.preds.append(torch_nested_detach(pred)) + self.targets.append(torch_nested_detach(target)) + + def evaluate(self): + ssim_list = [] + for (pred, target) in zip(self.preds, self.targets): + ssim_list.append(self.SSIM(pred, target)) + self.FID(pred, target) + ssim_list = torch_nested_numpify(ssim_list) + fid = self.FID.get_value() + return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid} diff --git a/modelscope/metrics/image_instance_segmentation_metric.py b/modelscope/metrics/image_instance_segmentation_metric.py new file mode 100644 index 00000000..86a19d13 --- /dev/null +++ b/modelscope/metrics/image_instance_segmentation_metric.py @@ -0,0 +1,314 @@ +# Part of the implementation is borrowed and modified from MMDetection, publicly available at +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py +import os.path as osp +import tempfile +from collections import OrderedDict +from typing import Any, Dict + +import numpy as np +import pycocotools.mask as mask_util +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from modelscope.fileio import dump, load +from modelscope.metainfo import Metrics +from modelscope.metrics import METRICS, Metric +from modelscope.utils.registry import default_group + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_ins_seg_coco_metric) +class ImageInstanceSegmentationCOCOMetric(Metric): + """The metric computation class for COCO-style image instance segmentation. + """ + + def __init__(self): + self.ann_file = None + self.classes = None + self.metrics = ['bbox', 'segm'] + self.proposal_nums = (100, 300, 1000) + self.iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.results = [] + + def add(self, outputs: Dict[str, Any], inputs: Dict[str, Any]): + result = outputs['eval_result'] + # encode mask results + if isinstance(result[0], tuple): + result = [(bbox_results, encode_mask_results(mask_results)) + for bbox_results, mask_results in result] + self.results.extend(result) + if self.ann_file is None: + self.ann_file = outputs['img_metas'][0]['ann_file'] + self.classes = outputs['img_metas'][0]['classes'] + + def evaluate(self): + cocoGt = COCO(self.ann_file) + self.cat_ids = cocoGt.getCatIds(catNms=self.classes) + self.img_ids = cocoGt.getImgIds() + + result_files, tmp_dir = self.format_results(self.results, self.img_ids) + + eval_results = OrderedDict() + for metric in self.metrics: + iou_type = metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in predictions: + x.pop('bbox') + cocoDt = cocoGt.loadRes(predictions) + except IndexError: + print('The testing results of the whole dataset is empty.') + break + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + cocoEval.params.catIds = self.cat_ids + cocoEval.params.imgIds = self.img_ids + cocoEval.params.maxDets = list(self.proposal_nums) + cocoEval.params.iouThrs = self.iou_thrs + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}') + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results[f'{metric}_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + if tmp_dir is not None: + tmp_dir.cleanup() + return eval_results + + def format_results(self, results, img_ids, jsonfile_prefix=None, **kwargs): + """Format the results to json (standard format for COCO evaluation). + + Args: + results (list[tuple | numpy.ndarray]): Testing results of the + dataset. + data_infos(list[tuple | numpy.ndarray]): data information + jsonfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a dict containing \ + the json filepaths, tmp_dir is the temporal directory created \ + for saving json files when jsonfile_prefix is not specified. + """ + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(img_ids), ( + 'The length of results is not equal to the dataset len: {} != {}'. + format(len(results), len(img_ids))) + + if jsonfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + jsonfile_prefix = osp.join(tmp_dir.name, 'results') + else: + tmp_dir = None + result_files = self.results2json(results, jsonfile_prefix) + return result_files, tmp_dir + + def xyxy2xywh(self, bbox): + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. + + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. + """ + + _bbox = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def _proposal2json(self, results): + """Convert proposal results to COCO json style.""" + json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + bboxes = results[idx] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = 1 + json_results.append(data) + return json_results + + def _det2json(self, results): + """Convert detection results to COCO json style.""" + json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + result = results[idx] + for label in range(len(result)): + # Here we skip invalid predicted labels, as we use the fixed num_classes of 80 (COCO) + # (assuming the num class of input dataset is no more than 80). + # Recommended manually set `num_classes=${your test dataset class num}` in the + # configuration.json in practice. + if label >= len(self.classes): + break + bboxes = result[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + json_results.append(data) + return json_results + + def _segm2json(self, results): + """Convert instance segmentation results to COCO json style.""" + bbox_json_results = [] + segm_json_results = [] + for idx in range(len(self.img_ids)): + img_id = self.img_ids[idx] + det, seg = results[idx] + for label in range(len(det)): + # Here we skip invalid predicted labels, as we use the fixed num_classes of 80 (COCO) + # (assuming the num class of input dataset is no more than 80). + # Recommended manually set `num_classes=${your test dataset class num}` in the + # configuration.json in practice. + if label >= len(self.classes): + break + # bbox results + bboxes = det[label] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(bboxes[i][4]) + data['category_id'] = self.cat_ids[label] + bbox_json_results.append(data) + + # segm results + # some detectors use different scores for bbox and mask + if isinstance(seg, tuple): + segms = seg[0][label] + mask_score = seg[1][label] + else: + segms = seg[label] + mask_score = [bbox[4] for bbox in bboxes] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(mask_score[i]) + data['category_id'] = self.cat_ids[label] + if isinstance(segms[i]['counts'], bytes): + segms[i]['counts'] = segms[i]['counts'].decode() + data['segmentation'] = segms[i] + segm_json_results.append(data) + return bbox_json_results, segm_json_results + + def results2json(self, results, outfile_prefix): + """Dump the detection results to a COCO style json file. + + There are 3 types of results: proposals, bbox predictions, mask + predictions, and they have different data types. This method will + automatically recognize the type, and dump them to json files. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.bbox.json", "somepath/xxx.segm.json", + "somepath/xxx.proposal.json". + + Returns: + dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \ + values are corresponding filenames. + """ + result_files = dict() + if isinstance(results[0], list): + json_results = self._det2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + dump(json_results, result_files['bbox']) + elif isinstance(results[0], tuple): + json_results = self._segm2json(results) + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + result_files['segm'] = f'{outfile_prefix}.segm.json' + dump(json_results[0], result_files['bbox']) + dump(json_results[1], result_files['segm']) + elif isinstance(results[0], np.ndarray): + json_results = self._proposal2json(results) + result_files['proposal'] = f'{outfile_prefix}.proposal.json' + dump(json_results, result_files['proposal']) + else: + raise TypeError('invalid type of results') + return result_files + + +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list | tuple[list]): bitmap mask results. + In mask scoring rcnn, mask_results is a tuple of (segm_results, + segm_cls_score). + + Returns: + list | tuple: RLE encoded mask. + """ + if isinstance(mask_results, tuple): # mask scoring + cls_segms, cls_mask_scores = mask_results + else: + cls_segms = mask_results + num_classes = len(cls_segms) + encoded_mask_results = [[] for _ in range(num_classes)] + for i in range(len(cls_segms)): + for cls_segm in cls_segms[i]: + encoded_mask_results[i].append( + mask_util.encode( + np.array( + cls_segm[:, :, np.newaxis], order='F', + dtype='uint8'))[0]) # encoded with RLE + if isinstance(mask_results, tuple): + return encoded_mask_results, cls_mask_scores + else: + return encoded_mask_results diff --git a/modelscope/metrics/image_portrait_enhancement_metric.py b/modelscope/metrics/image_portrait_enhancement_metric.py new file mode 100644 index 00000000..7d94aade --- /dev/null +++ b/modelscope/metrics/image_portrait_enhancement_metric.py @@ -0,0 +1,51 @@ +# Part of the implementation is borrowed and modified from BasicSR, publicly available at +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py +from typing import Dict + +import cv2 +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def calculate_psnr(img, img2): + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.image_portrait_enhancement_metric) +class ImagePortraitEnhancementMetric(Metric): + """The metric for image-portrait-enhancement task. + """ + + def __init__(self): + self.preds = [] + self.targets = [] + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['target'] + eval_results = outputs['pred'] + + self.preds.extend(eval_results) + self.targets.extend(ground_truths) + + def evaluate(self): + psnrs = [ + calculate_psnr(pred, target) + for pred, target in zip(self.preds, self.targets) + ] + + return {MetricKeys.PSNR: sum(psnrs) / len(psnrs)} diff --git a/modelscope/metrics/inbatch_recall_metric.py b/modelscope/metrics/inbatch_recall_metric.py new file mode 100644 index 00000000..d098a883 --- /dev/null +++ b/modelscope/metrics/inbatch_recall_metric.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np +import torch + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.inbatch_recall) +class InbatchRecallMetric(Metric): + """The metric computation class for in-batch retrieval classes. + + This metric class calculates in-batch image recall@1 for each input batch. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inbatch_t2i_hitcnts = [] + self.batch_sizes = [] + + def add(self, outputs: Dict, inputs: Dict): + image_features = outputs[OutputKeys.IMG_EMBEDDING] + text_features = outputs[OutputKeys.TEXT_EMBEDDING] + + assert type(image_features) == torch.Tensor and type( + text_features) == torch.Tensor + + with torch.no_grad(): + logits_per_image = image_features @ text_features.t() + logits_per_text = logits_per_image.t() + batch_size = logits_per_image.shape[0] + + ground_truth = torch.arange(batch_size).long() + ground_truth = ground_truth.to(image_features.device) + + inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth + ).sum().float().item() + + self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt) + self.batch_sizes.append(batch_size) + + def evaluate(self): + assert len(self.inbatch_t2i_hitcnts) == len( + self.batch_sizes) and len(self.batch_sizes) > 0 + return { + MetricKeys.BatchAcc: + sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes) + } diff --git a/modelscope/metrics/map_metric.py b/modelscope/metrics/map_metric.py new file mode 100644 index 00000000..aac76f22 --- /dev/null +++ b/modelscope/metrics/map_metric.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.multi_average_precision) +class AveragePrecisionMetric(Metric): + """The metric computation class for multi avarage precision classes. + + This metric class calculates multi avarage precision for the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + self.thresh = kwargs.get('threshold', 0.5) + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[label_name] + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key in outputs and outputs[key] is not None: + eval_results = outputs[key] + break + assert type(ground_truths) == type(eval_results) + for truth in ground_truths: + self.labels.append(truth) + for result in eval_results: + if isinstance(truth, str): + self.preds.append(result.strip().replace(' ', '')) + else: + self.preds.append(result) + + def evaluate(self): + assert len(self.preds) == len(self.labels) + scores = self._calculate_ap_score(self.preds, self.labels, self.thresh) + return {MetricKeys.mAP: scores.mean().item()} + + def _calculate_ap_score(self, preds, labels, thresh=0.5): + hyps = np.array(preds) + refs = np.array(labels) + a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]) + b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:]) + interacts = np.concatenate([a, b], axis=1) + area_predictions = (hyps[:, 2] - hyps[:, 0]) * ( + hyps[:, 3] - hyps[:, 1]) + area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1]) + interacts_w = interacts[:, 2] - interacts[:, 0] + interacts_h = interacts[:, 3] - interacts[:, 1] + area_interacts = interacts_w * interacts_h + ious = area_interacts / ( + area_predictions + area_targets - area_interacts + 1e-6) + return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0) diff --git a/modelscope/metrics/movie_scene_segmentation_metric.py b/modelscope/metrics/movie_scene_segmentation_metric.py new file mode 100644 index 00000000..65725b6f --- /dev/null +++ b/modelscope/metrics/movie_scene_segmentation_metric.py @@ -0,0 +1,54 @@ +# The implementation here is modified based on BaSSL, +# originally Apache 2.0 License and publicly available at https://github.com/kakaobrain/bassl +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.movie_scene_segmentation_metric) +class MovieSceneSegmentationMetric(Metric): + """The metric computation class for movie scene segmentation classes. + """ + + def __init__(self): + self.preds = [] + self.labels = [] + self.eps = 1e-5 + + def add(self, outputs: Dict, inputs: Dict): + preds = outputs['pred'] + labels = inputs['label'] + self.preds.extend(preds) + self.labels.extend(labels) + + def evaluate(self): + gts = np.array(torch_nested_numpify(torch_nested_detach(self.labels))) + prob = np.array(torch_nested_numpify(torch_nested_detach(self.preds))) + + gt_one = gts == 1 + gt_zero = gts == 0 + pred_one = prob == 1 + pred_zero = prob == 0 + + tp = (gt_one * pred_one).sum() + fp = (gt_zero * pred_one).sum() + fn = (gt_one * pred_zero).sum() + + precision = 100.0 * tp / (tp + fp + self.eps) + recall = 100.0 * tp / (tp + fn + self.eps) + f1 = 2 * precision * recall / (precision + recall) + + return { + MetricKeys.F1: f1, + MetricKeys.RECALL: recall, + MetricKeys.PRECISION: precision + } diff --git a/modelscope/metrics/ned_metric.py b/modelscope/metrics/ned_metric.py new file mode 100644 index 00000000..e87bb2c4 --- /dev/null +++ b/modelscope/metrics/ned_metric.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module(group_key=default_group, module_name=Metrics.NED) +class NedMetric(Metric): + """The ned metric computation class for classification classes. + + This metric class calculates the levenshtein distance between sentences for the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[label_name] + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key in outputs and outputs[key] is not None: + eval_results = outputs[key] + break + assert type(ground_truths) == type(eval_results) + if isinstance(ground_truths, list): + self.preds.extend(eval_results) + self.labels.extend(ground_truths) + elif isinstance(ground_truths, np.ndarray): + self.preds.extend(eval_results.tolist()) + self.labels.extend(ground_truths.tolist()) + else: + raise Exception('only support list or np.ndarray') + + def evaluate(self): + assert len(self.preds) == len(self.labels) + return { + MetricKeys.NED: (np.asarray([ + 1.0 - NedMetric._distance(pred, ref) + for pred, ref in zip(self.preds, self.labels) + ])).mean().item() + } + + @staticmethod + def _distance(pred, ref): + if pred is None or ref is None: + raise TypeError('Argument (pred or ref) is NoneType.') + if pred == ref: + return 0.0 + if len(pred) == 0: + return len(ref) + if len(ref) == 0: + return len(pred) + m_len = max(len(pred), len(ref)) + if m_len == 0: + return 0.0 + + def levenshtein(s0, s1): + v0 = [0] * (len(s1) + 1) + v1 = [0] * (len(s1) + 1) + + for i in range(len(v0)): + v0[i] = i + + for i in range(len(s0)): + v1[0] = i + 1 + for j in range(len(s1)): + cost = 1 + if s0[i] == s1[j]: + cost = 0 + v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost) + v0, v1 = v1, v0 + return v0[len(s1)] + + return levenshtein(pred, ref) / m_len diff --git a/modelscope/metrics/referring_video_object_segmentation_metric.py b/modelscope/metrics/referring_video_object_segmentation_metric.py new file mode 100644 index 00000000..5a0af30b --- /dev/null +++ b/modelscope/metrics/referring_video_object_segmentation_metric.py @@ -0,0 +1,108 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR +from typing import Dict + +import numpy as np +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from pycocotools.mask import decode +from tqdm import tqdm + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.referring_video_object_segmentation_metric) +class ReferringVideoObjectSegmentationMetric(Metric): + """The metric computation class for movie scene segmentation classes. + """ + + def __init__(self, + ann_file=None, + calculate_precision_and_iou_metrics=True): + self.ann_file = ann_file + self.calculate_precision_and_iou_metrics = calculate_precision_and_iou_metrics + self.preds = [] + + def add(self, outputs: Dict, inputs: Dict): + preds_batch = outputs['pred'] + self.preds.extend(preds_batch) + + def evaluate(self): + coco_gt = COCO(self.ann_file) + coco_pred = coco_gt.loadRes(self.preds) + coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') + coco_eval.params.useCats = 0 + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + ap_labels = [ + 'mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', + 'AP 0.5:0.95 M', 'AP 0.5:0.95 L' + ] + ap_metrics = coco_eval.stats[:6] + eval_metrics = {la: m for la, m in zip(ap_labels, ap_metrics)} + if self.calculate_precision_and_iou_metrics: + precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics( + coco_gt, coco_pred) + eval_metrics.update({ + f'P@{k}': m + for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k) + }) + eval_metrics.update({ + 'overall_iou': overall_iou, + 'mean_iou': mean_iou + }) + + return eval_metrics + + +def compute_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): + outputs = outputs.int() + intersection = (outputs & labels).float().sum( + (1, 2)) # Will be zero if Truth=0 or Prediction=0 + union = (outputs | labels).float().sum( + (1, 2)) # Will be zero if both are 0 + iou = (intersection + EPS) / (union + EPS + ) # EPS is used to avoid division by zero + return iou, intersection, union + + +def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): + print('evaluating precision@k & iou metrics...') + counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} + total_intersection_area = 0 + total_union_area = 0 + ious_list = [] + for instance in tqdm(coco_gt.imgs.keys() + ): # each image_id contains exactly one instance + gt_annot = coco_gt.imgToAnns[instance][0] + gt_mask = decode(gt_annot['segmentation']) + pred_annots = coco_pred.imgToAnns[instance] + pred_annot = sorted( + pred_annots, + key=lambda a: a['score'])[-1] # choose pred with highest score + pred_mask = decode(pred_annot['segmentation']) + iou, intersection, union = compute_iou( + torch.tensor(pred_mask).unsqueeze(0), + torch.tensor(gt_mask).unsqueeze(0)) + iou, intersection, union = iou.item(), intersection.item(), union.item( + ) + for iou_threshold in counters_by_iou.keys(): + if iou > iou_threshold: + counters_by_iou[iou_threshold] += 1 + total_intersection_area += intersection + total_union_area += union + ious_list.append(iou) + num_samples = len(ious_list) + precision_at_k = np.array(list(counters_by_iou.values())) / num_samples + overall_iou = total_intersection_area / total_union_area + mean_iou = np.mean(ious_list) + return precision_at_k, overall_iou, mean_iou diff --git a/modelscope/metrics/sequence_classification_metric.py b/modelscope/metrics/sequence_classification_metric.py new file mode 100644 index 00000000..1fe1c329 --- /dev/null +++ b/modelscope/metrics/sequence_classification_metric.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict + +import numpy as np +from sklearn.metrics import accuracy_score, f1_score + +from modelscope.metainfo import Metrics +from modelscope.outputs import OutputKeys +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.seq_cls_metric) +class SequenceClassificationMetric(Metric): + """The metric computation class for sequence classification tasks. + + This metric class calculates accuracy of the whole input batches. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[OutputKeys.LOGITS] + self.preds.append( + torch_nested_numpify(torch_nested_detach(eval_results))) + self.labels.append( + torch_nested_numpify(torch_nested_detach(ground_truths))) + + def evaluate(self): + preds = np.concatenate(self.preds, axis=0) + labels = np.concatenate(self.labels, axis=0) + preds = np.argmax(preds, axis=1) + return { + MetricKeys.ACCURACY: + accuracy_score(labels, preds), + MetricKeys.F1: + f1_score( + labels, + preds, + average='micro' if any([label > 1 + for label in labels]) else None), + } diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py new file mode 100644 index 00000000..08df5235 --- /dev/null +++ b/modelscope/metrics/text_generation_metric.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict, Iterable, List + +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu +from rouge import Rouge + +from modelscope.metainfo import Metrics +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.utils.chinese_utils import rebuild_chinese_str +from modelscope.utils.registry import default_group + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.text_gen_metric) +class TextGenerationMetric(Metric): + """The metric computation class for text generation classes. + + This metric class calculates F1 of the rouge scores for the whole evaluation dataset. + """ + + def __init__(self): + self.preds: List[str] = [] + self.tgts: List[str] = [] + self.rouge = Rouge() + + def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): + ground_truths = inputs['tgts'] + eval_results = outputs['preds'] + for truth in ground_truths: + self.tgts.append(rebuild_chinese_str(truth)) + for result in eval_results: + self.preds.append(rebuild_chinese_str(result)) + + def _check(self, pred: str, tgt: str) -> bool: + + def remove_useless(string: str) -> str: + return string.replace(' ', '').replace('.', '') + + return remove_useless(pred) and remove_useless(tgt) + + def evaluate(self): + assert self.preds, 'preds in TextGenerationMetric must not be empty!' + tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts) + if self._check(pred, tgt)] + preds, tgts = zip(*tmp) + + def mean(iter: Iterable) -> float: + return sum(iter) / len(self.preds) + + rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) + rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) + rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) + + pred_list = [each.strip().split(' ') for each in self.preds] + tgt_list = [[each.strip().split(' ')] for each in self.tgts] + bleu_1 = corpus_bleu( + tgt_list, + pred_list, + weights=(1, 0, 0, 0), + smoothing_function=SmoothingFunction().method3) + bleu_4 = corpus_bleu( + tgt_list, + pred_list, + smoothing_function=SmoothingFunction().method3) + return { + MetricKeys.ROUGE_1: rouge_1, + MetricKeys.ROUGE_L: rouge_l, + MetricKeys.BLEU_1: bleu_1, + MetricKeys.BLEU_4: bleu_4 + } diff --git a/modelscope/metrics/token_classification_metric.py b/modelscope/metrics/token_classification_metric.py new file mode 100644 index 00000000..f8595fc1 --- /dev/null +++ b/modelscope/metrics/token_classification_metric.py @@ -0,0 +1,134 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import importlib +from typing import Dict, List, Optional, Union + +import numpy as np + +from modelscope.outputs import OutputKeys +from ..metainfo import Metrics +from ..utils.registry import default_group +from ..utils.tensor_utils import torch_nested_detach, torch_nested_numpify +from .base import Metric +from .builder import METRICS, MetricKeys + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.token_cls_metric) +class TokenClassificationMetric(Metric): + """The metric computation class for token-classification task. + + This metric class uses seqeval to calculate the scores. + + Args: + return_entity_level_metrics (bool, *optional*): + Whether to return every label's detail metrics, default False. + """ + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[OutputKeys.LOGITS] + self.preds.append( + torch_nested_numpify(torch_nested_detach(eval_results))) + self.labels.append( + torch_nested_numpify(torch_nested_detach(ground_truths))) + + def __init__(self, + return_entity_level_metrics=False, + label2id=None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.return_entity_level_metrics = return_entity_level_metrics + self.preds = [] + self.labels = [] + self.label2id = label2id + + def evaluate(self): + label2id = self.label2id + if label2id is None: + assert hasattr(self, 'trainer') + label2id = self.trainer.label2id + + self.id2label = {id: label for label, id in label2id.items()} + self.preds = np.concatenate(self.preds, axis=0) + self.labels = np.concatenate(self.labels, axis=0) + predictions = np.argmax(self.preds, axis=-1) + + true_predictions = [[ + self.id2label[p] for (p, lb) in zip(prediction, label) + if lb != -100 + ] for prediction, label in zip(predictions, self.labels)] + true_labels = [[ + self.id2label[lb] for (p, lb) in zip(prediction, label) + if lb != -100 + ] for prediction, label in zip(predictions, self.labels)] + + results = self._compute( + predictions=true_predictions, references=true_labels) + if self.return_entity_level_metrics: + final_results = {} + for key, value in results.items(): + if isinstance(value, dict): + for n, v in value.items(): + final_results[f'{key}_{n}'] = v + else: + final_results[key] = value + return final_results + else: + return { + MetricKeys.PRECISION: results[MetricKeys.PRECISION], + MetricKeys.RECALL: results[MetricKeys.RECALL], + MetricKeys.F1: results[MetricKeys.F1], + MetricKeys.ACCURACY: results[MetricKeys.ACCURACY], + } + + @staticmethod + def _compute( + predictions, + references, + suffix: bool = False, + scheme: Optional[str] = None, + mode: Optional[str] = None, + sample_weight: Optional[List[int]] = None, + zero_division: Union[str, int] = 'warn', + ): + from seqeval.metrics import accuracy_score, classification_report + if scheme is not None: + try: + scheme_module = importlib.import_module('seqeval.scheme') + scheme = getattr(scheme_module, scheme) + except AttributeError: + raise ValueError( + f'Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}' + ) + report = classification_report( + y_true=references, + y_pred=predictions, + suffix=suffix, + output_dict=True, + scheme=scheme, + mode=mode, + sample_weight=sample_weight, + zero_division=zero_division, + ) + report.pop('macro avg') + report.pop('weighted avg') + overall_score = report.pop('micro avg') + + scores = { + type_name: { + MetricKeys.PRECISION: score['precision'], + MetricKeys.RECALL: score['recall'], + MetricKeys.F1: score['f1-score'], + 'number': score['support'], + } + for type_name, score in report.items() + } + scores[MetricKeys.PRECISION] = overall_score['precision'] + scores[MetricKeys.RECALL] = overall_score['recall'] + scores[MetricKeys.F1] = overall_score['f1-score'] + scores[MetricKeys.ACCURACY] = accuracy_score( + y_true=references, y_pred=predictions) + return scores diff --git a/modelscope/metrics/video_summarization_metric.py b/modelscope/metrics/video_summarization_metric.py new file mode 100644 index 00000000..40580382 --- /dev/null +++ b/modelscope/metrics/video_summarization_metric.py @@ -0,0 +1,81 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.models.cv.video_summarization.summarizer import \ + generate_summary +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def evaluate_summary(predicted_summary, user_summary, eval_method): + """ Compare the predicted summary with the user defined one(s). + + :param ndarray predicted_summary: The generated summary from our model. + :param ndarray user_summary: The user defined ground truth summaries (or summary). + :param str eval_method: The proposed evaluation method; either 'max' (SumMe) or 'avg' (TVSum). + :return: The reduced fscore based on the eval_method + """ + max_len = max(len(predicted_summary), user_summary.shape[1]) + S = np.zeros(max_len, dtype=int) + G = np.zeros(max_len, dtype=int) + S[:len(predicted_summary)] = predicted_summary + + f_scores = [] + for user in range(user_summary.shape[0]): + G[:user_summary.shape[1]] = user_summary[user] + overlapped = S & G + + # Compute precision, recall, f-score + precision = sum(overlapped) / sum(S) + recall = sum(overlapped) / sum(G) + if precision + recall == 0: + f_scores.append(0) + else: + f_score = 2 * precision * recall * 100 / (precision + recall) + f_scores.append(f_score) + + if eval_method == 'max': + return max(f_scores) + else: + return sum(f_scores) / len(f_scores) + + +def calculate_f_score(outputs: Dict, inputs: Dict): + scores = outputs['scores'] + scores = scores.squeeze(0).cpu().numpy().tolist() + user_summary = inputs['user_summary'].cpu().numpy()[0] + sb = inputs['change_points'].cpu().numpy()[0] + n_frames = inputs['n_frames'].cpu().numpy()[0] + positions = inputs['positions'].cpu().numpy()[0] + summary = generate_summary([sb], [scores], [n_frames], [positions])[0] + f_score = evaluate_summary(summary, user_summary, 'avg') + return f_score + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.video_summarization_metric) +class VideoSummarizationMetric(Metric): + """The metric for video summarization task. + """ + + def __init__(self): + self.inputs = [] + self.outputs = [] + + def add(self, outputs: Dict, inputs: Dict): + self.outputs.append(outputs) + self.inputs.append(inputs) + + def evaluate(self): + f_scores = [ + calculate_f_score(output, input) + for output, input in zip(self.outputs, self.inputs) + ] + + return {MetricKeys.FScore: sum(f_scores) / len(f_scores)} diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py new file mode 100644 index 00000000..e7cb2adc --- /dev/null +++ b/modelscope/models/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.error import (AUDIO_IMPORT_ERROR, + TENSORFLOW_IMPORT_WARNING) +from modelscope.utils.import_utils import is_torch_available +from . import audio, cv, multi_modal, nlp +from .base import Head, Model +from .builder import BACKBONES, HEADS, MODELS, build_model + +if is_torch_available(): + from .base import TorchModel, TorchHead diff --git a/modelscope/models/audio/__init__.py b/modelscope/models/audio/__init__.py new file mode 100644 index 00000000..07798cf4 --- /dev/null +++ b/modelscope/models/audio/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from . import ans, asr, kws, tts diff --git a/modelscope/models/audio/aec/__init__.py b/modelscope/models/audio/aec/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/aec/layers/__init__.py b/modelscope/models/audio/aec/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/aec/layers/activations.py b/modelscope/models/audio/aec/layers/activations.py new file mode 100644 index 00000000..f78ad4b5 --- /dev/null +++ b/modelscope/models/audio/aec/layers/activations.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch.nn as nn + +from .layer_base import LayerBase + + +class RectifiedLinear(LayerBase): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + + def forward(self, input): + return self.relu(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class LogSoftmax(LayerBase): + + def __init__(self, input_dim, output_dim): + super(LogSoftmax, self).__init__() + self.dim = input_dim + self.ls = nn.LogSoftmax() + + def forward(self, input): + return self.ls(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class Sigmoid(LayerBase): + + def __init__(self, input_dim, output_dim): + super(Sigmoid, self).__init__() + self.dim = input_dim + self.sig = nn.Sigmoid() + + def forward(self, input): + return self.sig(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr diff --git a/modelscope/models/audio/aec/layers/affine_transform.py b/modelscope/models/audio/aec/layers/affine_transform.py new file mode 100644 index 00000000..2de8a03f --- /dev/null +++ b/modelscope/models/audio/aec/layers/affine_transform.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch as th +import torch.nn as nn + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class AffineTransform(LayerBase): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, input): + return self.linear(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1 1 0\n' + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def to_raw_nnet(self, fid): + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + x.tofile(fid) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + x.tofile(fid) + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('AffineTransform format error for ') + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception( + 'AffineTransform format error for ') + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception('AffineTransform format error for ') + instr, lr = output + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('AffineTransform format error for parsing matrix') + instr, mat = output + + print(mat.shape) + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('AffineTransform format error for parsing matrix') + instr, mat = output + mat = np.squeeze(mat) + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + return instr diff --git a/modelscope/models/audio/aec/layers/deep_fsmn.py b/modelscope/models/audio/aec/layers/deep_fsmn.py new file mode 100644 index 00000000..1582b908 --- /dev/null +++ b/modelscope/models/audio/aec/layers/deep_fsmn.py @@ -0,0 +1,180 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class DeepFsmn(LayerBase): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + rorder=None, + hidden_size=None, + layer_norm=False, + dropout=0): + super(DeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.hidden_size = hidden_size + self.layer_norm = layer_norm + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.LayerNorm(hidden_size) + self.drop1 = nn.Dropout(p=dropout) + self.drop2 = nn.Dropout(p=dropout) + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], [1, 1], + groups=output_dim, + bias=False) + self.conv2 = nn.Conv2d( + output_dim, + output_dim, [rorder, 1], [1, 1], + groups=output_dim, + bias=False) + + def forward(self, input): + + f1 = F.relu(self.linear(input)) + + f1 = self.drop1(f1) + if self.layer_norm: + f1 = self.norm(f1) + + p1 = self.project(f1) + + x = th.unsqueeze(p1, 1) + + x_per = x.permute(0, 3, 2, 1) + + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + yr = F.pad(x_per, [0, 0, 0, self.rorder]) + yr = yr[:, :, 1:, :] + + out = x_per + self.conv1(y) + self.conv2(yr) + out = self.drop2(out) + + out1 = out.permute(0, 3, 2, 1) + + return input + out1.squeeze() + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n'\ + % (self.output_dim, self.input_dim) + re_str += ' %d %d %d %d 0\n'\ + % (1, self.hidden_size, self.lorder, 1) + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lr = output + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, hiddensize = output + self.hidden_size = int(hiddensize) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lorder = output + self.lorder = int(lorder) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lstride = output + self.lstride = lstride + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat1 = np.fliplr(mat.T).copy() + self.conv1 = nn.Conv2d( + self.output_dim, + self.output_dim, [self.lorder, 1], [1, 1], + groups=self.output_dim, + bias=False) + mat_th = th.from_numpy(mat1).type(th.FloatTensor) + mat_th = mat_th.unsqueeze(1) + mat_th = mat_th.unsqueeze(3) + self.conv1.weight = th.nn.Parameter(mat_th) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + + self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False) + self.linear = nn.Linear(self.input_dim, self.hidden_size) + + self.project.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + return instr diff --git a/modelscope/models/audio/aec/layers/layer_base.py b/modelscope/models/audio/aec/layers/layer_base.py new file mode 100644 index 00000000..7c39e5be --- /dev/null +++ b/modelscope/models/audio/aec/layers/layer_base.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import abc +import re + +import numpy as np +import torch.nn as nn + + +def expect_token_number(instr, token): + first_token = re.match(r'^\s*' + token, instr) + if first_token is None: + return None + instr = instr[first_token.end():] + lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr) + if lr is None: + return None + return instr[lr.end():], lr.groups()[0] + + +def expect_kaldi_matrix(instr): + pos2 = instr.find('[', 0) + pos3 = instr.find(']', pos2) + mat = [] + for stt in instr[pos2 + 1:pos3].split('\n'): + tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ') + if tmp_mat.size > 0: + mat.append(tmp_mat) + return instr[pos3 + 1:], np.array(mat) + + +def to_kaldi_matrix(np_mat): + """ + function that transform as str numpy mat to standard kaldi str matrix + :param np_mat: numpy mat + :return: str + """ + np.set_printoptions(threshold=np.inf, linewidth=np.nan, suppress=True) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +class LayerBase(nn.Module, metaclass=abc.ABCMeta): + + def __init__(self): + super(LayerBase, self).__init__() + + @abc.abstractmethod + def to_kaldi_nnet(self): + pass diff --git a/modelscope/models/audio/aec/layers/uni_deep_fsmn.py b/modelscope/models/audio/aec/layers/uni_deep_fsmn.py new file mode 100644 index 00000000..a276db05 --- /dev/null +++ b/modelscope/models/audio/aec/layers/uni_deep_fsmn.py @@ -0,0 +1,484 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .layer_base import (LayerBase, expect_kaldi_matrix, expect_token_number, + to_kaldi_matrix) + + +class SepConv(nn.Module): + + def __init__(self, + in_channels, + filters, + out_channels, + kernel_size=(5, 2), + dilation=(1, 1)): + """ :param kernel_size (time, frequency) + + """ + super(SepConv, self).__init__() + # depthwise + pointwise + self.dconv = nn.Conv2d( + in_channels, + in_channels * filters, + kernel_size, + dilation=dilation, + groups=in_channels) + self.pconv = nn.Conv2d( + in_channels * filters, out_channels, kernel_size=1) + self.padding = dilation[0] * (kernel_size[0] - 1) + + def forward(self, input): + ''' input: [B, C, T, F] + ''' + x = F.pad(input, [0, 0, self.padding, 0]) + x = self.dconv(x) + x = self.pconv(x) + return x + + +class Conv2d(nn.Module): + + def __init__(self, + input_dim, + output_dim, + lorder=20, + rorder=0, + groups=1, + bias=False, + skip_connect=True): + super(Conv2d, self).__init__() + self.lorder = lorder + self.conv = nn.Conv2d( + input_dim, output_dim, [lorder, 1], groups=groups, bias=bias) + self.rorder = rorder + if self.rorder: + self.conv2 = nn.Conv2d( + input_dim, output_dim, [rorder, 1], groups=groups, bias=bias) + self.skip_connect = skip_connect + + def forward(self, input): + # [B, 1, T, F] + x = th.unsqueeze(input, 1) + # [B, F, T, 1] + x_per = x.permute(0, 3, 2, 1) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + out = self.conv(y) + if self.rorder: + yr = F.pad(x_per, [0, 0, 0, self.rorder]) + yr = yr[:, :, 1:, :] + out += self.conv2(yr) + out = out.permute(0, 3, 2, 1).squeeze(1) + if self.skip_connect: + out = out + input + return out + + +class SelfAttLayer(nn.Module): + + def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): + super(SelfAttLayer, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.att = nn.Linear(input_dim, lorder, bias=False) + + def forward(self, input): + + f1 = F.relu(self.linear(input)) + + p1 = self.project(f1) + + x = th.unsqueeze(p1, 1) + + x_per = x.permute(0, 3, 2, 1) + + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + # z [B, F, T, lorder] + z = x_per + for i in range(1, self.lorder): + z = th.cat([z, y[:, :, self.lorder - 1 - i:-i, :]], axis=-1) + + # [B, T, lorder] + att = F.softmax(self.att(input), dim=-1) + att = th.unsqueeze(att, 1) + z = th.sum(z * att, axis=-1) + + out1 = z.permute(0, 2, 1) + + return input + out1 + + +class TFFsmn(nn.Module): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(TFFsmn, self).__init__() + + self.skip_connect = skip_connect + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.Identity() + if layer_norm: + self.norm = nn.LayerNorm(input_dim) + self.act = nn.ReLU() + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + dorder = 5 + self.conv2 = nn.Conv2d(1, 1, [dorder, 1], bias=False) + self.padding_freq = dorder - 1 + + def forward(self, input): + return self.compute1(input) + + def compute1(self, input): + ''' linear-dconv-relu(norm)-linear-dconv + ''' + x = self.linear(input) + # [B, 1, F, T] + x = th.unsqueeze(x, 1).permute(0, 1, 3, 2) + z = F.pad(x, [0, 0, self.padding_freq, 0]) + z = self.conv2(z) + x + x = z.permute(0, 3, 2, 1).squeeze(-1) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + + return input + out + + +class CNNFsmn(nn.Module): + ''' use cnn to reduce parameters + ''' + + def __init__(self, + input_dim, + output_dim, + lorder=None, + hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(CNNFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.skip_connect = skip_connect + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + self.act = nn.ReLU() + kernel_size = (3, 8) + stride = (1, 4) + self.conv = nn.Sequential( + nn.ConstantPad2d((stride[1], 0, kernel_size[0] - 1, 0), 0), + nn.Conv2d(1, stride[1], kernel_size=kernel_size, stride=stride)) + + self.dconv = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + + def forward(self, input): + return self.compute2(input) + + def compute1(self, input): + ''' linear-relu(norm)-conv2d-relu?-dconv + ''' + # [B, T, F] + x = self.linear(input) + x = self.act(x) + x = th.unsqueeze(x, 1) + x = self.conv(x) + # [B, C, T, F] -> [B, 1, T, F] + b, c, t, f = x.shape + x = x.view([b, 1, t, -1]) + x = x.permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.dconv(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + return input + out + + def compute2(self, input): + ''' conv2d-relu-linear-relu?-dconv + ''' + x = th.unsqueeze(input, 1) + x = self.conv(x) + x = self.act(x) + # [B, C, T, F] -> [B, T, F] + b, c, t, f = x.shape + x = x.view([b, t, -1]) + x = self.linear(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.dconv(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + return input + out + + +class UniDeepFsmn(LayerBase): + + def __init__(self, + input_dim, + output_dim, + lorder=None, + hidden_size=None, + dilation=1, + layer_norm=False, + dropout=0, + skip_connect=True): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.skip_connect = skip_connect + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + self.norm = nn.Identity() + if layer_norm: + self.norm = nn.LayerNorm(input_dim) + self.act = nn.ReLU() + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], + dilation=[dilation, 1], + groups=output_dim, + bias=False) + self.padding_left = dilation * (lorder - 1) + + def forward(self, input): + return self.compute1(input) + + def compute1(self, input): + ''' linear-relu(norm)-linear-dconv + ''' + # [B, T, F] + x = self.linear(input) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + # [B, F, T+lorder-1, 1] + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + + return input + out + + def compute2(self, input): + ''' linear-dconv-linear-relu(norm) + ''' + x = self.project(input) + x = th.unsqueeze(x, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + x = self.linear(out) + x = self.act(x) + x = self.norm(x) + + return input + x + + def compute3(self, input): + ''' dconv-linear-relu(norm)-linear + ''' + x = th.unsqueeze(input, 1).permute(0, 3, 2, 1) + y = F.pad(x, [0, 0, self.padding_left, 0]) + out = self.conv1(y) + if self.skip_connect: + out = out + x + out = out.permute(0, 3, 2, 1).squeeze() + x = self.linear(out) + x = self.act(x) + x = self.norm(x) + x = self.project(x) + + return input + x + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' \ + % (self.output_dim, self.input_dim) + re_str += ' %d %d %d %d 0\n' \ + % (1, self.hidden_size, self.lorder, 1) + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def to_raw_nnet(self, fid): + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + x.tofile(fid) + + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + x.tofile(fid) + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + x.tofile(fid) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + x.tofile(fid) + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lr = output + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, hiddensize = output + self.hidden_size = int(hiddensize) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lorder = output + self.lorder = int(lorder) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + instr, lstride = output + self.lstride = lstride + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error for ') + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat1 = np.fliplr(mat.T).copy() + + self.conv1 = nn.Conv2d( + self.output_dim, + self.output_dim, [self.lorder, 1], [1, 1], + groups=self.output_dim, + bias=False) + + mat_th = th.from_numpy(mat1).type(th.FloatTensor) + mat_th = mat_th.unsqueeze(1) + mat_th = mat_th.unsqueeze(3) + self.conv1.weight = th.nn.Parameter(mat_th) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + + self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False) + self.linear = nn.Linear(self.input_dim, self.hidden_size) + + self.project.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error for parsing matrix') + instr, mat = output + mat = np.squeeze(mat) + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + return instr diff --git a/modelscope/models/audio/aec/network/__init__.py b/modelscope/models/audio/aec/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/aec/network/loss.py b/modelscope/models/audio/aec/network/loss.py new file mode 100644 index 00000000..1f20072a --- /dev/null +++ b/modelscope/models/audio/aec/network/loss.py @@ -0,0 +1,396 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn.functional as F + +from .modulation_loss import (GaborSTRFConv, MelScale, + ModulationDomainLossModule) + +EPS = 1e-8 + + +def compute_mask(mixed_spec, clean_spec, mask_type='psmiam', clip=1): + ''' + stft: (batch, ..., 2) or complex(batch, ...) + y = x + n + ''' + if torch.is_complex(mixed_spec): + yr, yi = mixed_spec.real, mixed_spec.imag + else: + yr, yi = mixed_spec[..., 0], mixed_spec[..., 1] + if torch.is_complex(clean_spec): + xr, xi = clean_spec.real, clean_spec.imag + else: + xr, xi = clean_spec[..., 0], clean_spec[..., 1] + + if mask_type == 'iam': + ymag = torch.sqrt(yr**2 + yi**2) + xmag = torch.sqrt(xr**2 + xi**2) + iam = xmag / (ymag + EPS) + return torch.clamp(iam, 0, 1) + + elif mask_type == 'psm': + ypow = yr**2 + yi**2 + psm = (xr * yr + xi * yi) / (ypow + EPS) + return torch.clamp(psm, 0, 1) + + elif mask_type == 'psmiam': + ypow = yr**2 + yi**2 + psm = (xr * yr + xi * yi) / (ypow + EPS) + ymag = torch.sqrt(yr**2 + yi**2) + xmag = torch.sqrt(xr**2 + xi**2) + iam = xmag / (ymag + EPS) + psmiam = psm * iam + return torch.clamp(psmiam, 0, 1) + + elif mask_type == 'crm': + ypow = yr**2 + yi**2 + mr = (xr * yr + xi * yi) / (ypow + EPS) + mi = (xi * yr - xr * yi) / (ypow + EPS) + mr = torch.clamp(mr, -clip, clip) + mi = torch.clamp(mi, -clip, clip) + return mr, mi + + +def energy_vad(spec, + thdhigh=320 * 600 * 600 * 2, + thdlow=320 * 300 * 300 * 2, + int16=True): + ''' + energy based vad should be accurate enough + spec: (batch, bins, frames, 2) + returns (batch, frames) + ''' + energy = torch.sum(spec[..., 0]**2 + spec[..., 1]**2, dim=1) + vad = energy > thdhigh + idx = torch.logical_and(vad == 0, energy > thdlow) + vad[idx] = 0.5 + return vad + + +def modulation_loss_init(n_fft): + gabor_strf_parameters = torch.load( + './network/gabor_strf_parameters.pt')['state_dict'] + gabor_modulation_kernels = GaborSTRFConv(supn=30, supk=30, nkern=60) + gabor_modulation_kernels.load_state_dict(gabor_strf_parameters) + + modulation_loss_module = ModulationDomainLossModule( + gabor_modulation_kernels.eval()) + for param in modulation_loss_module.parameters(): + param.requires_grad = False + + stft2mel = MelScale( + n_mels=80, sample_rate=16000, n_stft=n_fft // 2 + 1).cuda() + + return modulation_loss_module, stft2mel + + +def mask_loss_function( + loss_func='psm_loss', + loss_type='mse', # ['mse', 'mae', 'comb'] + mask_type='psmiam', + use_mod_loss=False, + use_wav2vec_loss=False, + n_fft=640, + hop_length=320, + EPS=1e-8, + weight=None): + if weight is not None: + print(f'Use loss weight: {weight}') + winlen = n_fft + window = torch.hamming_window(winlen, periodic=False) + + def stft(x, return_complex=False): + # returns [batch, bins, frames, 2] + return torch.stft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + return_complex=return_complex) + + def istft(x, slen): + return torch.istft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + length=slen) + + def mask_loss(targets, masks, nframes): + ''' [Batch, Time, Frequency] + ''' + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks = masks * mask_for_loss + targets = targets * mask_for_loss + + if weight is None: + alpha = 1 + else: # for aec ST + alpha = weight - targets + + if loss_type == 'mse': + loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2)) + elif loss_type == 'mae': + loss = torch.sum(alpha * torch.abs(targets - masks)) + else: # mse(mask), mae(mask) approx 1:2 + loss = 0.5 * torch.sum(alpha * torch.pow(targets - masks, 2) + + 0.1 * alpha * torch.abs(targets - masks)) + loss /= torch.sum(nframes) + return loss + + def spectrum_loss(targets, spec, nframes): + ''' [Batch, Time, Frequency, 2] + ''' + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + xr = spec[..., 0] * mask_for_loss + xi = spec[..., 1] * mask_for_loss + yr = targets[..., 0] * mask_for_loss + yi = targets[..., 1] * mask_for_loss + xmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) * mask_for_loss + ymag = torch.sqrt(targets[..., 0]**2 + + targets[..., 1]**2) * mask_for_loss + + loss1 = torch.sum(torch.pow(xr - yr, 2) + torch.pow(xi - yi, 2)) + loss2 = torch.sum(torch.pow(xmag - ymag, 2)) + + loss = (loss1 + loss2) / torch.sum(nframes) + return loss + + def sa_loss_dlen(mixed, clean, masks, nframes): + yspec = stft(mixed).permute([0, 2, 1, 3]) / 32768 + xspec = stft(clean).permute([0, 2, 1, 3]) / 32768 + with torch.no_grad(): + mask_for_loss = torch.ones_like(xspec[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + emag = ((yspec[..., 0]**2 + yspec[..., 1]**2)**0.15) * (masks**0.3) + xmag = (xspec[..., 0]**2 + xspec[..., 1]**2)**0.15 + emag = emag * mask_for_loss + xmag = xmag * mask_for_loss + + loss = torch.sum(torch.pow(emag - xmag, 2)) / torch.sum(nframes) + return loss + + def psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed) + clean_spec = stft(clean) + targets = compute_mask(mixed_spec, clean_spec, mask_type) + # [B, T, F] + targets = targets.permute(0, 2, 1) + + loss = mask_loss(targets, masks, nframes) + + if subtask is not None: + vadtargets = energy_vad(clean_spec) + with torch.no_grad(): + mask_for_loss = torch.ones_like(targets[:, :, 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:] = 0 + subtask = subtask[:, :, 0] * mask_for_loss + vadtargets = vadtargets * mask_for_loss + + loss_vad = F.binary_cross_entropy(subtask, vadtargets) + return loss + loss_vad + return loss + + def modulation_loss(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed, True) + clean_spec = stft(clean, True) + enhanced_mag = torch.abs(mixed_spec) + clean_mag = torch.abs(clean_spec) + with torch.no_grad(): + mask_for_loss = torch.ones_like(clean_mag) + for idx, num in enumerate(nframes): + mask_for_loss[idx, :, num:] = 0 + clean_mag = clean_mag * mask_for_loss + enhanced_mag = enhanced_mag * mask_for_loss * masks.permute([0, 2, 1]) + + # Covert to log-mel representation + # (B,T,#mel_channels) + clean_log_mel = torch.log( + torch.transpose(stft2mel(clean_mag**2), 2, 1) + 1e-8) + enhanced_log_mel = torch.log( + torch.transpose(stft2mel(enhanced_mag**2), 2, 1) + 1e-8) + + alpha = compute_mask(mixed_spec, clean_spec, mask_type) + alpha = alpha.permute(0, 2, 1) + loss = 0.05 * modulation_loss_module(enhanced_log_mel, clean_log_mel, + alpha) + loss2 = psm_vad_loss_dlen(mixed, clean, masks, nframes, subtask) + # print(loss.item(), loss2.item()) #approx 1:4 + loss = loss + loss2 + return loss + + def wav2vec_loss(mixed, clean, masks, nframes, subtask=None): + mixed /= 32768 + clean /= 32768 + mixed_spec = stft(mixed) + with torch.no_grad(): + mask_for_loss = torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + est_clean = istft(estimate, clean.shape[1]) + loss = wav2vec_loss_module(est_clean, clean) + return loss + + def sisdr_loss_dlen(mixed, + clean, + masks, + nframes, + subtask=None, + zero_mean=True): + mixed_spec = stft(mixed) + with torch.no_grad(): + mask_for_loss = torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + est_clean = istft(estimate, clean.shape[1]) + flen = min(clean.shape[1], est_clean.shape[1]) + clean = clean[:, :flen] + est_clean = est_clean[:, :flen] + + # follow asteroid/losses/sdr.py + if zero_mean: + clean = clean - torch.mean(clean, dim=1, keepdim=True) + est_clean = est_clean - torch.mean(est_clean, dim=1, keepdim=True) + + dot = torch.sum(est_clean * clean, dim=1, keepdim=True) + s_clean_energy = torch.sum(clean**2, dim=1, keepdim=True) + EPS + scaled_clean = dot * clean / s_clean_energy + e_noise = est_clean - scaled_clean + + # [batch] + sisdr = torch.sum( + scaled_clean**2, dim=1) / ( + torch.sum(e_noise**2, dim=1) + EPS) + sisdr = -10 * torch.log10(sisdr + EPS) + loss = sisdr.mean() + return loss + + def sisdr_freq_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed) + clean_spec = stft(clean) + with torch.no_grad(): + mask_for_loss = torch.ones_like(masks) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + masks_est = masks * mask_for_loss + + estimate = mixed_spec * masks_est.permute([0, 2, 1]).unsqueeze(3) + + dot_real = estimate[..., 0] * clean_spec[..., 0] + \ + estimate[..., 1] * clean_spec[..., 1] + dot_imag = estimate[..., 0] * clean_spec[..., 1] - \ + estimate[..., 1] * clean_spec[..., 0] + dot = torch.cat([dot_real.unsqueeze(3), dot_imag.unsqueeze(3)], dim=-1) + s_clean_energy = clean_spec[..., 0] ** 2 + \ + clean_spec[..., 1] ** 2 + EPS + scaled_clean = dot * clean_spec / s_clean_energy.unsqueeze(3) + e_noise = estimate - scaled_clean + + # [batch] + scaled_clean_energy = torch.sum( + scaled_clean[..., 0]**2 + scaled_clean[..., 1]**2, dim=1) + e_noise_energy = torch.sum( + e_noise[..., 0]**2 + e_noise[..., 1]**2, dim=1) + sisdr = torch.sum( + scaled_clean_energy, dim=1) / ( + torch.sum(e_noise_energy, dim=1) + EPS) + sisdr = -10 * torch.log10(sisdr + EPS) + loss = sisdr.mean() + return loss + + def crm_loss_dlen(mixed, clean, masks, nframes, subtask=None): + mixed_spec = stft(mixed).permute([0, 2, 1, 3]) + clean_spec = stft(clean).permute([0, 2, 1, 3]) + mixed_spec = mixed_spec / 32768 + clean_spec = clean_spec / 32768 + tgt_mr, tgt_mi = compute_mask(mixed_spec, clean_spec, mask_type='crm') + + D = int(masks.shape[2] / 2) + with torch.no_grad(): + mask_for_loss = torch.ones_like(clean_spec[..., 0]) + for idx, num in enumerate(nframes): + mask_for_loss[idx, num:, :] = 0 + mr = masks[..., :D] * mask_for_loss + mi = masks[..., D:] * mask_for_loss + tgt_mr = tgt_mr * mask_for_loss + tgt_mi = tgt_mi * mask_for_loss + + if weight is None: + alpha = 1 + else: + alpha = weight - tgt_mr + # signal approximation + yr = mixed_spec[..., 0] + yi = mixed_spec[..., 1] + loss1 = torch.sum(alpha * torch.pow((mr * yr - mi * yi) - clean_spec[..., 0], 2)) \ + + torch.sum(alpha * torch.pow((mr * yi + mi * yr) - clean_spec[..., 1], 2)) + # mask approximation + loss2 = torch.sum(alpha * torch.pow(mr - tgt_mr, 2)) \ + + torch.sum(alpha * torch.pow(mi - tgt_mi, 2)) + loss = 0.5 * (loss1 + loss2) / torch.sum(nframes) + return loss + + def crm_miso_loss_dlen(mixed, clean, masks, nframes): + return crm_loss_dlen(mixed[..., 0], clean[..., 0], masks, nframes) + + def mimo_loss_dlen(mixed, clean, masks, nframes): + chs = mixed.shape[-1] + D = masks.shape[2] // chs + loss = psm_vad_loss_dlen(mixed[..., 0], clean[..., 0], masks[..., :D], + nframes) + for ch in range(1, chs): + loss1 = psm_vad_loss_dlen(mixed[..., ch], clean[..., ch], + masks[..., ch * D:ch * D + D], nframes) + loss = loss + loss1 + return loss / chs + + def spec_loss_dlen(mixed, clean, spec, nframes): + clean_spec = stft(clean).permute([0, 2, 1, 3]) + clean_spec = clean_spec / 32768 + + D = spec.shape[2] // 2 + spec_est = torch.cat([spec[..., :D, None], spec[..., D:, None]], + dim=-1) + loss = spectrum_loss(clean_spec, spec_est, nframes) + return loss + + if loss_func == 'psm_vad_loss_dlen': + return psm_vad_loss_dlen + elif loss_func == 'sisdr_loss_dlen': + return sisdr_loss_dlen + elif loss_func == 'sisdr_freq_loss_dlen': + return sisdr_freq_loss_dlen + elif loss_func == 'crm_loss_dlen': + return crm_loss_dlen + elif loss_func == 'modulation_loss': + return modulation_loss + elif loss_func == 'wav2vec_loss': + return wav2vec_loss + elif loss_func == 'mimo_loss_dlen': + return mimo_loss_dlen + elif loss_func == 'spec_loss_dlen': + return spec_loss_dlen + elif loss_func == 'sa_loss_dlen': + return sa_loss_dlen + else: + print('error loss func') + return None diff --git a/modelscope/models/audio/aec/network/modulation_loss.py b/modelscope/models/audio/aec/network/modulation_loss.py new file mode 100644 index 00000000..3017b5c6 --- /dev/null +++ b/modelscope/models/audio/aec/network/modulation_loss.py @@ -0,0 +1,250 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchaudio.transforms import MelScale + + +class ModulationDomainLossModule(torch.nn.Module): + """Modulation-domain loss function developed in [1] for supervised speech enhancement + + In our paper, we used the gabor-based STRF kernels as the modulation kernels and used the log-mel spectrogram + as the input spectrogram representation. + Specific parameter details are in the paper and in the example below + + Parameters + ---------- + modulation_kernels: nn.Module + Differentiable module that transforms a spectrogram representation to the modulation domain + + modulation_domain = modulation_kernels(input_tf_representation) + Input Spectrogram representation (B, T, F) ---> |(M) modulation_kernels|--->Modulation Domain(B, M, T', F') + + norm: boolean + Normalizes the modulation domain representation to be 0 mean across time + + [1] T. Vuong, Y. Xia, and R. M. Stern, “A modulation-domain lossfor neural-network-based real-time + speech enhancement” + Accepted ICASSP 2021, https://arxiv.org/abs/2102.07330 + + + """ + + def __init__(self, modulation_kernels, norm=True): + super(ModulationDomainLossModule, self).__init__() + + self.modulation_kernels = modulation_kernels + self.mse = nn.MSELoss(reduce=False) + self.norm = norm + + def forward(self, enhanced_spect, clean_spect, weight=None): + """Calculate modulation-domain loss + Args: + enhanced_spect (Tensor): spectrogram representation of enhanced signal (B, #frames, #freq_channels). + clean_spect (Tensor): spectrogram representation of clean ground-truth signal (B, #frames, #freq_channels). + Returns: + Tensor: Modulation-domain loss value. + """ + + clean_mod = self.modulation_kernels(clean_spect) + enhanced_mod = self.modulation_kernels(enhanced_spect) + + if self.norm: + mean_clean_mod = torch.mean(clean_mod, dim=2) + mean_enhanced_mod = torch.mean(enhanced_mod, dim=2) + + clean_mod = clean_mod - mean_clean_mod.unsqueeze(2) + enhanced_mod = enhanced_mod - mean_enhanced_mod.unsqueeze(2) + + if weight is None: + alpha = 1 + else: # TF-mask weight + alpha = 1 + torch.sum(weight, dim=-1, keepdim=True).unsqueeze(1) + mod_mse_loss = self.mse(enhanced_mod, clean_mod) * alpha + mod_mse_loss = torch.mean( + torch.sum(mod_mse_loss, dim=(1, 2, 3)) + / torch.sum(clean_mod**2, dim=(1, 2, 3))) + + return mod_mse_loss + + +class ModulationDomainNCCLossModule(torch.nn.Module): + """Modulation-domain loss function developed in [1] for supervised speech enhancement + + # Speech Intelligibility Prediction Using Spectro-Temporal Modulation Analysis - based off of this + + In our paper, we used the gabor-based STRF kernels as the modulation kernels and used the log-mel spectrogram + as the input spectrogram representation. + Specific parameter details are in the paper and in the example below + + Parameters + ---------- + modulation_kernels: nn.Module + Differentiable module that transforms a spectrogram representation to the modulation domain + + modulation_domain = modulation_kernels(input_tf_representation) + Input Spectrogram representation(B, T, F) --- (M) modulation_kernels---> Modulation Domain(B, M, T', F') + + [1] + + """ + + def __init__(self, modulation_kernels): + super(ModulationDomainNCCLossModule, self).__init__() + + self.modulation_kernels = modulation_kernels + self.mse = nn.MSELoss(reduce=False) + + def forward(self, enhanced_spect, clean_spect): + """Calculate modulation-domain loss + Args: + enhanced_spect (Tensor): spectrogram representation of enhanced signal (B, #frames, #freq_channels). + clean_spect (Tensor): spectrogram representation of clean ground-truth signal (B, #frames, #freq_channels). + Returns: + Tensor: Modulation-domain loss value. + """ + + clean_mod = self.modulation_kernels(clean_spect) + enhanced_mod = self.modulation_kernels(enhanced_spect) + mean_clean_mod = torch.mean(clean_mod, dim=2) + mean_enhanced_mod = torch.mean(enhanced_mod, dim=2) + + normalized_clean = clean_mod - mean_clean_mod.unsqueeze(2) + normalized_enhanced = enhanced_mod - mean_enhanced_mod.unsqueeze(2) + + inner_product = torch.sum( + normalized_clean * normalized_enhanced, dim=2) + normalized_denom = (torch.sum( + normalized_clean * normalized_clean, dim=2))**.5 * (torch.sum( + normalized_enhanced * normalized_enhanced, dim=2))**.5 + + ncc = inner_product / normalized_denom + mod_mse_loss = torch.mean((ncc - 1.0)**2) + + return mod_mse_loss + + +class GaborSTRFConv(nn.Module): + """Gabor-STRF-based cross-correlation kernel.""" + + def __init__(self, + supn, + supk, + nkern, + rates=None, + scales=None, + norm_strf=True, + real_only=False): + """Instantiate a Gabor-based STRF convolution layer. + Parameters + ---------- + supn: int + Time support in number of frames. Also the window length. + supk: int + Frequency support in number of channels. Also the window length. + nkern: int + Number of kernels, each with a learnable rate and scale. + rates: list of float, None + Initial values for temporal modulation. + scales: list of float, None + Initial values for spectral modulation. + norm_strf: Boolean + Normalize STRF kernels to be unit length + real_only: Boolean + If True, nkern REAL gabor-STRF kernels + If False, nkern//2 REAL and nkern//2 IMAGINARY gabor-STRF kernels + """ + super(GaborSTRFConv, self).__init__() + self.numN = supn + self.numK = supk + self.numKern = nkern + self.real_only = real_only + self.norm_strf = norm_strf + + if not real_only: + nkern = nkern // 2 + + if supk % 2 == 0: # force odd number + supk += 1 + self.supk = torch.arange(supk, dtype=torch.float32) + if supn % 2 == 0: # force odd number + supn += 1 + self.supn = torch.arange(supn, dtype=self.supk.dtype) + self.padding = (supn // 2, supk // 2) + # Set up learnable parameters + # for param in (rates, scales): + # assert (not param) or len(param) == nkern + if not rates: + + rates = torch.rand(nkern) * math.pi / 2.0 + + if not scales: + + scales = (torch.rand(nkern) * 2.0 - 1.0) * math.pi / 2.0 + + self.rates_ = nn.Parameter(torch.Tensor(rates)) + self.scales_ = nn.Parameter(torch.Tensor(scales)) + + def strfs(self): + """Make STRFs using the current parameters.""" + + if self.supn.device != self.rates_.device: # for first run + self.supn = self.supn.to(self.rates_.device) + self.supk = self.supk.to(self.rates_.device) + n0, k0 = self.padding + + nwind = .5 - .5 * \ + torch.cos(2 * math.pi * (self.supn + 1) / (len(self.supn) + 1)) + kwind = .5 - .5 * \ + torch.cos(2 * math.pi * (self.supk + 1) / (len(self.supk) + 1)) + + new_wind = torch.matmul((nwind).unsqueeze(-1), (kwind).unsqueeze(0)) + + n_n_0 = self.supn - n0 + k_k_0 = self.supk - k0 + n_mult = torch.matmul( + n_n_0.unsqueeze(1), + torch.ones((1, len(self.supk))).type(torch.FloatTensor).to( + self.rates_.device)) + k_mult = torch.matmul( + torch.ones((len(self.supn), + 1)).type(torch.FloatTensor).to(self.rates_.device), + k_k_0.unsqueeze(0)) + + inside = self.rates_.unsqueeze(1).unsqueeze( + 1) * n_mult + self.scales_.unsqueeze(1).unsqueeze(1) * k_mult + real_strf = torch.cos(inside) * new_wind.unsqueeze(0) + + if self.real_only: + final_strf = real_strf + + else: + imag_strf = torch.sin(inside) * new_wind.unsqueeze(0) + final_strf = torch.cat([real_strf, imag_strf], dim=0) + + if self.norm_strf: + final_strf = final_strf / (torch.sum( + final_strf**2, dim=(1, 2)).unsqueeze(1).unsqueeze(2))**.5 + + return final_strf + + def forward(self, sigspec): + """Forward pass a batch of (real) spectra [Batch x Time x Frequency].""" + if len(sigspec.shape) == 2: # expand batch dimension if single eg + sigspec = sigspec.unsqueeze(0) + strfs = self.strfs().unsqueeze(1).type_as(sigspec) + out = F.conv2d(sigspec.unsqueeze(1), strfs, padding=self.padding) + return out + + def __repr__(self): + """Gabor filter""" + report = """ + +++++ Gabor Filter Kernels [{}], supn[{}], supk[{}] real only [{}] norm strf [{}] +++++ + + """.format(self.numKern, self.numN, self.numK, self.real_only, + self.norm_strf) + + return report diff --git a/modelscope/models/audio/aec/network/se_net.py b/modelscope/models/audio/aec/network/se_net.py new file mode 100644 index 00000000..40639605 --- /dev/null +++ b/modelscope/models/audio/aec/network/se_net.py @@ -0,0 +1,487 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.audio.aec.layers.activations import (RectifiedLinear, + Sigmoid) +from modelscope.models.audio.aec.layers.affine_transform import AffineTransform +from modelscope.models.audio.aec.layers.deep_fsmn import DeepFsmn +from modelscope.models.audio.aec.layers.uni_deep_fsmn import (Conv2d, + UniDeepFsmn) + + +class MaskNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=128, + hidden_dim2=None, + lorder=20, + rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(MaskNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + if hidden_dim2 is None: + hidden_dim2 = hidden_dim + + if rorder == 0: + repeats = [ + UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim2, + dilation=dilation, + layer_norm=layer_norm, + dropout=dropout) for i in range(layers) + ] + else: + repeats = [ + DeepFsmn( + hidden_dim, + hidden_dim, + lorder, + rorder, + hidden_dim2, + layer_norm=layer_norm, + dropout=dropout) for i in range(layers) + ] + self.deepfsmn = nn.Sequential(*repeats) + + self.linear2 = AffineTransform(hidden_dim, outdim) + + self.crm = crm + if self.crm: + self.sig = nn.Tanh() + else: + self.sig = Sigmoid(outdim, outdim) + + self.vad = vad + if self.vad: + self.linear3 = AffineTransform(hidden_dim, 1) + + self.layers = layers + self.linearout = linearout + if self.linearout and self.vad: + print('Warning: not supported nnet') + + def forward(self, feat, ctl=None): + x1 = self.linear1(feat) + x2 = self.relu(x1) + if ctl is not None: + ctl = min(ctl, self.layers - 1) + for i in range(ctl): + x2 = self.deepfsmn[i](x2) + mask = self.sig(self.linear2(x2)) + if self.vad: + vad = torch.sigmoid(self.linear3(x2)) + return mask, vad + else: + return mask + x3 = self.deepfsmn(x2) + if self.linearout: + return self.linear2(x3) + mask = self.sig(self.linear2(x3)) + if self.vad: + vad = torch.sigmoid(self.linear3(x3)) + return mask, vad + else: + return mask + + def to_kaldi_nnet(self): + re_str = '' + re_str += '\n' + re_str += self.linear1.to_kaldi_nnet() + re_str += self.relu.to_kaldi_nnet() + for dfsmn in self.deepfsmn: + re_str += dfsmn.to_kaldi_nnet() + re_str += self.linear2.to_kaldi_nnet() + re_str += self.sig.to_kaldi_nnet() + re_str += '\n' + + return re_str + + def to_raw_nnet(self, fid): + self.linear1.to_raw_nnet(fid) + for dfsmn in self.deepfsmn: + dfsmn.to_raw_nnet(fid) + self.linear2.to_raw_nnet(fid) + + +class StageNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + layers2=6, + hidden_dim=128, + lorder=20, + rorder=0, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(StageNet, self).__init__() + + self.stage1 = nn.ModuleList() + self.stage2 = nn.ModuleList() + layer = nn.Sequential(nn.Linear(indim, hidden_dim), nn.ReLU()) + self.stage1.append(layer) + for i in range(layers): + layer = UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim, + layer_norm=layer_norm, + dropout=dropout) + self.stage1.append(layer) + layer = nn.Sequential(nn.Linear(hidden_dim, 321), nn.Sigmoid()) + self.stage1.append(layer) + # stage2 + layer = nn.Sequential(nn.Linear(321 + indim, hidden_dim), nn.ReLU()) + self.stage2.append(layer) + for i in range(layers2): + layer = UniDeepFsmn( + hidden_dim, + hidden_dim, + lorder, + hidden_dim, + layer_norm=layer_norm, + dropout=dropout) + self.stage2.append(layer) + layer = nn.Sequential( + nn.Linear(hidden_dim, outdim), + nn.Sigmoid() if not crm else nn.Tanh()) + self.stage2.append(layer) + self.crm = crm + self.vad = vad + self.linearout = linearout + self.window = torch.hamming_window(640, periodic=False).cuda() + self.freezed = False + + def freeze(self): + if not self.freezed: + for param in self.stage1.parameters(): + param.requires_grad = False + self.freezed = True + print('freezed stage1') + + def forward(self, feat, mixture, ctl=None): + if ctl == 'off': + x = feat + for i in range(len(self.stage1)): + x = self.stage1[i](x) + return x + else: + self.freeze() + x = feat + for i in range(len(self.stage1)): + x = self.stage1[i](x) + + spec = torch.stft( + mixture / 32768, + 640, + 320, + 640, + self.window, + center=False, + return_complex=True) + spec = torch.view_as_real(spec).permute([0, 2, 1, 3]) + specmag = torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2) + est = x * specmag + y = torch.cat([est, feat], dim=-1) + for i in range(len(self.stage2)): + y = self.stage2[i](y) + return y + + +class Unet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + dims=[256] * 4, + lorder=20, + rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(Unet, self).__init__() + + self.linear1 = AffineTransform(indim, dims[0]) + self.relu = RectifiedLinear(dims[0], dims[0]) + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + for i in range(len(dims) - 1): + layer = nn.Sequential( + nn.Linear(dims[i], dims[i + 1]), nn.ReLU(), + nn.Linear(dims[i + 1], dims[i + 1], bias=False), + Conv2d( + dims[i + 1], + dims[i + 1], + lorder, + groups=dims[i + 1], + skip_connect=True)) + self.encoder.append(layer) + for i in range(len(dims) - 1, 0, -1): + layer = nn.Sequential( + nn.Linear(dims[i] * 2, dims[i - 1]), nn.ReLU(), + nn.Linear(dims[i - 1], dims[i - 1], bias=False), + Conv2d( + dims[i - 1], + dims[i - 1], + lorder, + groups=dims[i - 1], + skip_connect=True)) + self.decoder.append(layer) + self.tf = nn.ModuleList() + for i in range(layers - 2 * (len(dims) - 1)): + layer = nn.Sequential( + nn.Linear(dims[-1], dims[-1]), nn.ReLU(), + nn.Linear(dims[-1], dims[-1], bias=False), + Conv2d( + dims[-1], + dims[-1], + lorder, + groups=dims[-1], + skip_connect=True)) + self.tf.append(layer) + + self.linear2 = AffineTransform(dims[0], outdim) + self.crm = crm + self.act = nn.Tanh() if self.crm else nn.Sigmoid() + self.vad = False + self.layers = layers + self.linearout = linearout + + def forward(self, x, ctl=None): + x = self.linear1(x) + x = self.relu(x) + + encoder_out = [] + for i in range(len(self.encoder)): + x = self.encoder[i](x) + encoder_out.append(x) + for i in range(len(self.tf)): + x = self.tf[i](x) + for i in range(len(self.decoder)): + x = torch.cat([x, encoder_out[-1 - i]], dim=-1) + x = self.decoder[i](x) + + x = self.linear2(x) + if self.linearout: + return x + return self.act(x) + + +class BranchNet(nn.Module): + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=256, + lorder=20, + rorder=0, + dilation=1, + layer_norm=False, + dropout=0, + crm=False, + vad=False, + linearout=False): + super(BranchNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + + self.convs = nn.ModuleList() + self.deepfsmn = nn.ModuleList() + self.FREQ = nn.ModuleList() + self.TIME = nn.ModuleList() + self.br1 = nn.ModuleList() + self.br2 = nn.ModuleList() + for i in range(layers): + ''' + layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim, bias=False), + Conv2d(hidden_dim, hidden_dim, lorder, + groups=hidden_dim, skip_connect=True) + ) + self.deepfsmn.append(layer) + ''' + layer = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + self.FREQ.append(layer) + ''' + layer = nn.GRU(hidden_dim, hidden_dim, + batch_first=True, + bidirectional=False) + self.TIME.append(layer) + + layer = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim//2, bias=False), + Conv2d(hidden_dim//2, hidden_dim//2, lorder, + groups=hidden_dim//2, skip_connect=True) + ) + self.br1.append(layer) + layer = nn.GRU(hidden_dim, hidden_dim//2, + batch_first=True, + bidirectional=False) + self.br2.append(layer) + ''' + + self.linear2 = AffineTransform(hidden_dim, outdim) + self.crm = crm + self.act = nn.Tanh() if self.crm else nn.Sigmoid() + self.vad = False + self.layers = layers + self.linearout = linearout + + def forward(self, x, ctl=None): + return self.forward_branch(x) + + def forward_sepconv(self, x): + x = torch.unsqueeze(x, 1) + for i in range(len(self.convs)): + x = self.convs[i](x) + x = F.relu(x) + B, C, H, W = x.shape + x = x.permute(0, 2, 1, 3) + x = torch.reshape(x, [B, H, C * W]) + x = self.linear1(x) + x = self.relu(x) + for i in range(self.layers): + x = self.deepfsmn[i](x) + x + x = self.linear2(x) + return self.act(x) + + def forward_branch(self, x): + x = self.linear1(x) + x = self.relu(x) + for i in range(self.layers): + z = self.FREQ[i](x) + x = z + x + x = self.linear2(x) + if self.linearout: + return x + return self.act(x) + + +class TACNet(nn.Module): + ''' transform average concatenate for ad hoc dr + ''' + + def __init__(self, + indim, + outdim, + layers=9, + hidden_dim=128, + lorder=20, + rorder=0, + crm=False, + vad=False, + linearout=False): + super(TACNet, self).__init__() + + self.linear1 = AffineTransform(indim, hidden_dim) + self.relu = RectifiedLinear(hidden_dim, hidden_dim) + + if rorder == 0: + repeats = [ + UniDeepFsmn(hidden_dim, hidden_dim, lorder, hidden_dim) + for i in range(layers) + ] + else: + repeats = [ + DeepFsmn(hidden_dim, hidden_dim, lorder, rorder, hidden_dim) + for i in range(layers) + ] + self.deepfsmn = nn.Sequential(*repeats) + + self.ch_transform = nn.ModuleList([]) + self.ch_average = nn.ModuleList([]) + self.ch_concat = nn.ModuleList([]) + for i in range(layers): + self.ch_transform.append( + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())) + self.ch_average.append( + nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())) + self.ch_concat.append( + nn.Sequential( + nn.Linear(hidden_dim * 2, hidden_dim), nn.PReLU())) + + self.linear2 = AffineTransform(hidden_dim, outdim) + + self.crm = crm + if self.crm: + self.sig = nn.Tanh() + else: + self.sig = Sigmoid(outdim, outdim) + + self.vad = vad + if self.vad: + self.linear3 = AffineTransform(hidden_dim, 1) + + self.layers = layers + self.linearout = linearout + if self.linearout and self.vad: + print('Warning: not supported nnet') + + def forward(self, feat, ctl=None): + B, T, F = feat.shape + # assume 4ch + ch = 4 + zlist = [] + for c in range(ch): + z = self.linear1(feat[..., c * (F // 4):(c + 1) * (F // 4)]) + z = self.relu(z) + zlist.append(z) + for i in range(self.layers): + # forward + for c in range(ch): + zlist[c] = self.deepfsmn[i](zlist[c]) + + # transform + olist = [] + for c in range(ch): + z = self.ch_transform[i](zlist[c]) + olist.append(z) + # average + avg = 0 + for c in range(ch): + avg = avg + olist[c] + avg = avg / ch + avg = self.ch_average[i](avg) + # concate + for c in range(ch): + tac = torch.cat([olist[c], avg], dim=-1) + tac = self.ch_concat[i](tac) + zlist[c] = zlist[c] + tac + + for c in range(ch): + zlist[c] = self.sig(self.linear2(zlist[c])) + mask = torch.cat(zlist, dim=-1) + return mask + + def to_kaldi_nnet(self): + pass diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py new file mode 100644 index 00000000..afcdf314 --- /dev/null +++ b/modelscope/models/audio/ans/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .frcrn import FRCRNDecorator + +else: + _import_structure = { + 'frcrn': ['FRCRNDecorator'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py new file mode 100644 index 00000000..beaa3187 --- /dev/null +++ b/modelscope/models/audio/ans/complex_nn.py @@ -0,0 +1,255 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# The implementation of class ComplexConv2d, ComplexConvTranspose2d and +# ComplexBatchNorm2d here is modified based on Jongho Choi(sweetcocoa@snu.ac.kr +# / Seoul National Univ., ESTsoft ) and publicly available at +# https://github.com/sweetcocoa/DeepComplexUNetPyTorch + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UniDeepFsmn(nn.Module): + + def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], [1, 1], + groups=output_dim, + bias=False) + + def forward(self, input): + r""" + + Args: + input: torch with shape: batch (b) x sequence(T) x feature (h) + + Returns: + batch (b) x channel (c) x sequence(T) x feature (h) + """ + f1 = F.relu(self.linear(input)) + + p1 = self.project(f1) + + x = torch.unsqueeze(p1, 1) + # x: batch (b) x channel (c) x sequence(T) x feature (h) + x_per = x.permute(0, 3, 2, 1) + # x_per: batch (b) x feature (h) x sequence(T) x channel (c) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + out = x_per + self.conv1(y) + + out1 = out.permute(0, 3, 2, 1) + # out1: batch (b) x channel (c) x sequence(T) x feature (h) + return input + out1.squeeze() + + +class ComplexUniDeepFsmn(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn, self).__init__() + + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + + Returns: + [batch, feature, sequence, 2], eg: [6, 99, 1024, 2] + """ + # + b, c, h, T, d = x.size() + x = torch.reshape(x, (b, c * h, T, d)) + # x: [b,h,T,2], [6, 256, 106, 2] + x = torch.transpose(x, 1, 2) + # x: [b,T,h,2], [6, 106, 256, 2] + + real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # GRU output: [99, 6, 128] + real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1) + imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1) + # output: [b,T,h,2], [99, 6, 1024, 2] + output = torch.stack((real, imaginary), dim=-1) + + # output: [b,h,T,2], [6, 99, 1024, 2] + output = torch.transpose(output, 1, 2) + output = torch.reshape(output, (b, c, h, T, d)) + + return output + + +class ComplexUniDeepFsmn_L1(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn_L1, self).__init__() + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + """ + b, c, h, T, d = x.size() + # x : [b,T,h,c,2] + x = torch.transpose(x, 1, 3) + x = torch.reshape(x, (b * T, h, c, d)) + + real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # output: [b*T,h,c,2], [6*106, h, 256, 2] + output = torch.stack((real, imaginary), dim=-1) + + output = torch.reshape(output, (b, T, h, c, d)) + output = torch.transpose(output, 1, 3) + return output + + +class ComplexConv2d(nn.Module): + # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.conv_re = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + self.conv_im = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + + def forward(self, x): + r""" + + Args: + x: torch with shape: [batch,channel,axis1,axis2,2] + """ + real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1]) + imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexConvTranspose2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.tconv_re = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + self.tconv_im = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + + def forward(self, x): # shpae of x : [batch,channel,axis1,axis2,2] + real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1]) + imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexBatchNorm2d(nn.Module): + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + **kwargs): + super().__init__() + self.bn_re = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + self.bn_im = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + + def forward(self, x): + real = self.bn_re(x[..., 0]) + imag = self.bn_im(x[..., 1]) + output = torch.stack((real, imag), dim=-1) + return output diff --git a/modelscope/models/audio/ans/conv_stft.py b/modelscope/models/audio/ans/conv_stft.py new file mode 100644 index 00000000..4b393a4c --- /dev/null +++ b/modelscope/models/audio/ans/conv_stft.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window + + +def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): + if win_type == 'None' or win_type is None: + window = np.ones(win_len) + else: + window = get_window(win_type, win_len, fftbins=True)**0.5 + + N = fft_len + fourier_basis = np.fft.rfft(np.eye(N))[:win_len] + real_kernel = np.real(fourier_basis) + imag_kernel = np.imag(fourier_basis) + kernel = np.concatenate([real_kernel, imag_kernel], 1).T + + if invers: + kernel = np.linalg.pinv(kernel).T + + kernel = kernel * window + kernel = kernel[:, None, :] + return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy( + window[None, :, None].astype(np.float32)) + + +class ConvSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConvSTFT, self).__init__() + + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + + kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.stride = win_inc + self.win_len = win_len + self.dim = self.fft_len + + def forward(self, inputs): + if inputs.dim() == 2: + inputs = torch.unsqueeze(inputs, 1) + + outputs = F.conv1d(inputs, self.weight, stride=self.stride) + + if self.feature_type == 'complex': + return outputs + else: + dim = self.dim // 2 + 1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + mags = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag, real) + return mags, phase + + +class ConviSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConviSTFT, self).__init__() + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + kernel, window = init_kernels( + win_len, win_inc, self.fft_len, win_type, invers=True) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.win_type = win_type + self.win_len = win_len + self.win_inc = win_inc + self.stride = win_inc + self.dim = self.fft_len + self.register_buffer('window', window) + self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) + + def forward(self, inputs, phase=None): + """ + Args: + inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) + phase: [B, N//2+1, T] (if not none) + """ + + if phase is not None: + real = inputs * torch.cos(phase) + imag = inputs * torch.sin(phase) + inputs = torch.cat([real, imag], 1) + outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) + + # this is from torch-stft: https://github.com/pseeth/torch-stft + t = self.window.repeat(1, 1, inputs.size(-1))**2 + coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) + outputs = outputs / (coff + 1e-8) + return outputs diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py new file mode 100644 index 00000000..b74fc273 --- /dev/null +++ b/modelscope/models/audio/ans/frcrn.py @@ -0,0 +1,280 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .conv_stft import ConviSTFT, ConvSTFT +from .unet import UNet + + +@MODELS.register_module( + Tasks.acoustic_noise_suppression, + module_name=Models.speech_frcrn_ans_cirm_16k) +class FRCRNDecorator(TorchModel): + r""" A decorator of FRCRN for integrating into modelscope framework """ + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the frcrn model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.model = FRCRN(*args, **kwargs) + model_bin_file = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load( + model_bin_file, map_location=torch.device('cpu')) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + # the new trained model by user is based on FRCRNDecorator + self.load_state_dict(checkpoint['state_dict']) + else: + # The released model on Modelscope is based on FRCRN + self.model.load_state_dict(checkpoint, strict=False) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + result_list = self.model.forward(inputs['noisy']) + output = { + 'spec_l1': result_list[0], + 'wav_l1': result_list[1], + 'mask_l1': result_list[2], + 'spec_l2': result_list[3], + 'wav_l2': result_list[4], + 'mask_l2': result_list[5] + } + if 'clean' in inputs: + mix_result = self.model.loss( + inputs['noisy'], inputs['clean'], result_list, mode='Mix') + output.update(mix_result) + sisnr_result = self.model.loss( + inputs['noisy'], inputs['clean'], result_list, mode='SiSNR') + output.update(sisnr_result) + # logger hooker will use items under 'log_vars' + output['log_vars'] = {k: mix_result[k].item() for k in mix_result} + output['log_vars'].update( + {k: sisnr_result[k].item() + for k in sisnr_result}) + return output + + +class FRCRN(nn.Module): + r""" Frequency Recurrent CRN """ + + def __init__(self, + complex, + model_complexity, + model_depth, + log_amp, + padding_mode, + win_len=400, + win_inc=100, + fft_len=512, + win_type='hanning', + **kwargs): + r""" + Args: + complex: Whether to use complex networks. + model_complexity: define the model complexity with the number of layers + model_depth: Only two options are available : 10, 20 + log_amp: Whether to use log amplitude to estimate signals + padding_mode: Encoder's convolution filter. 'zeros', 'reflect' + win_len: length of window used for defining one frame of sample points + win_inc: length of window shifting (equivalent to hop_size) + fft_len: number of Short Time Fourier Transform (STFT) points + win_type: windowing type used in STFT, eg. 'hanning', 'hamming' + """ + super().__init__() + self.feat_dim = fft_len // 2 + 1 + + self.win_len = win_len + self.win_inc = win_inc + self.fft_len = fft_len + self.win_type = win_type + + fix = True + self.stft = ConvSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.istft = ConviSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.unet = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + self.unet2 = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + + def forward(self, inputs): + out_list = [] + # [B, D*2, T] + cmp_spec = self.stft(inputs) + # [B, 1, D*2, T] + cmp_spec = torch.unsqueeze(cmp_spec, 1) + + # to [B, 2, D, T] real_part/imag_part + cmp_spec = torch.cat([ + cmp_spec[:, :, :self.feat_dim, :], + cmp_spec[:, :, self.feat_dim:, :], + ], 1) + + # [B, 2, D, T] + cmp_spec = torch.unsqueeze(cmp_spec, 4) + # [B, 1, D, T, 2] + cmp_spec = torch.transpose(cmp_spec, 1, 4) + unet1_out = self.unet(cmp_spec) + cmp_mask1 = torch.tanh(unet1_out) + unet2_out = self.unet2(unet1_out) + cmp_mask2 = torch.tanh(unet2_out) + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + cmp_mask2 = cmp_mask2 + cmp_mask1 + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + return out_list + + def apply_mask(self, cmp_spec, cmp_mask): + est_spec = torch.cat([ + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0] + - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1], + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1] + + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0] + ], 1) + est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1) + cmp_mask = torch.squeeze(cmp_mask, 1) + cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1) + + est_wav = self.istft(est_spec) + est_wav = torch.squeeze(est_wav, 1) + return est_spec, est_wav, cmp_mask + + def get_params(self, weight_decay=0.0): + # add L2 penalty + weights, biases = [], [] + for name, param in self.named_parameters(): + if 'bias' in name: + biases += [param] + else: + weights += [param] + params = [{ + 'params': weights, + 'weight_decay': weight_decay, + }, { + 'params': biases, + 'weight_decay': 0.0, + }] + return params + + def loss(self, noisy, labels, out_list, mode='Mix'): + if mode == 'SiSNR': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + loss = self.loss_1layer(noisy, est_spec, est_wav, labels, + est_mask, mode) + return dict(sisnr=loss) + + elif mode == 'Mix': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + amp_loss, phase_loss, SiSNR_loss = self.loss_1layer( + noisy, est_spec, est_wav, labels, est_mask, mode) + loss = amp_loss + phase_loss + SiSNR_loss + return dict(loss=loss, amp_loss=amp_loss, phase_loss=phase_loss) + + def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'): + r""" Compute the loss by mode + mode == 'Mix' + est: [B, F*2, T] + labels: [B, F*2,T] + mode == 'SiSNR' + est: [B, T] + labels: [B, T] + """ + if mode == 'SiSNR': + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + return -si_snr(est_wav, labels) + elif mode == 'Mix': + + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + SiSNR_loss = -si_snr(est_wav, labels) + + b, d, t = est.size() + S = self.stft(labels) + Sr = S[:, :self.feat_dim, :] + Si = S[:, self.feat_dim:, :] + Y = self.stft(noisy) + Yr = Y[:, :self.feat_dim, :] + Yi = Y[:, self.feat_dim:, :] + Y_pow = Yr**2 + Yi**2 + gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8), + (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1) + gth_mask[gth_mask > 2] = 1 + gth_mask[gth_mask < -2] = -1 + amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :], + cmp_mask[:, :self.feat_dim, :]) * d + phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :], + cmp_mask[:, self.feat_dim:, :]) * d + return amp_loss, phase_loss, SiSNR_loss + + +def l2_norm(s1, s2): + norm = torch.sum(s1 * s2, -1, keepdim=True) + return norm + + +def si_snr(s1, s2, eps=1e-8): + s1_s2_norm = l2_norm(s1, s2) + s2_s2_norm = l2_norm(s2, s2) + s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 + e_nosie = s1 - s_target + target_norm = l2_norm(s_target, s_target) + noise_norm = l2_norm(e_nosie, e_nosie) + snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps) + return torch.mean(snr) diff --git a/modelscope/models/audio/ans/se_module_complex.py b/modelscope/models/audio/ans/se_module_complex.py new file mode 100644 index 00000000..b58eb6ba --- /dev/null +++ b/modelscope/models/audio/ans/se_module_complex.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +from torch import nn + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc_r = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + self.fc_i = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _, _ = x.size() + x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c) + x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c) + y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view( + b, c, 1, 1, 1) + y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view( + b, c, 1, 1, 1) + y = torch.cat([y_r, y_i], 4) + return x * y diff --git a/modelscope/models/audio/ans/unet.py b/modelscope/models/audio/ans/unet.py new file mode 100644 index 00000000..7b4df1e9 --- /dev/null +++ b/modelscope/models/audio/ans/unet.py @@ -0,0 +1,276 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# The implementation here is modified based on +# Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft ) +# and publicly available at +# https://github.com/sweetcocoa/DeepComplexUNetPyTorch + +import torch +import torch.nn as nn + +from . import complex_nn +from .se_module_complex import SELayer + + +class Encoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=None, + complex=False, + padding_mode='zeros'): + super().__init__() + if padding is None: + padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding + + if complex: + conv = complex_nn.ComplexConv2d + bn = complex_nn.ComplexBatchNorm2d + else: + conv = nn.Conv2d + bn = nn.BatchNorm2d + + self.conv = conv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=(0, 0), + complex=False): + super().__init__() + if complex: + tconv = complex_nn.ComplexConvTranspose2d + bn = complex_nn.ComplexBatchNorm2d + else: + tconv = nn.ConvTranspose2d + bn = nn.BatchNorm2d + + self.transconv = tconv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.transconv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class UNet(nn.Module): + + def __init__(self, + input_channels=1, + complex=False, + model_complexity=45, + model_depth=20, + padding_mode='zeros'): + super().__init__() + + if complex: + model_complexity = int(model_complexity // 1.414) + + self.set_size( + model_complexity=model_complexity, + input_channels=input_channels, + model_depth=model_depth) + self.encoders = [] + self.model_length = model_depth // 2 + self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128) + self.se_layers_enc = [] + self.fsmn_enc = [] + for i in range(self.model_length): + fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_enc{}'.format(i), fsmn_enc) + self.fsmn_enc.append(fsmn_enc) + module = Encoder( + self.enc_channels[i], + self.enc_channels[i + 1], + kernel_size=self.enc_kernel_sizes[i], + stride=self.enc_strides[i], + padding=self.enc_paddings[i], + complex=complex, + padding_mode=padding_mode) + self.add_module('encoder{}'.format(i), module) + self.encoders.append(module) + se_layer_enc = SELayer(self.enc_channels[i + 1], 8) + self.add_module('se_layer_enc{}'.format(i), se_layer_enc) + self.se_layers_enc.append(se_layer_enc) + self.decoders = [] + self.fsmn_dec = [] + self.se_layers_dec = [] + for i in range(self.model_length): + fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_dec{}'.format(i), fsmn_dec) + self.fsmn_dec.append(fsmn_dec) + module = Decoder( + self.dec_channels[i] * 2, + self.dec_channels[i + 1], + kernel_size=self.dec_kernel_sizes[i], + stride=self.dec_strides[i], + padding=self.dec_paddings[i], + complex=complex) + self.add_module('decoder{}'.format(i), module) + self.decoders.append(module) + if i < self.model_length - 1: + se_layer_dec = SELayer(self.dec_channels[i + 1], 8) + self.add_module('se_layer_dec{}'.format(i), se_layer_dec) + self.se_layers_dec.append(se_layer_dec) + if complex: + conv = complex_nn.ComplexConv2d + else: + conv = nn.Conv2d + + linear = conv(self.dec_channels[-1], 1, 1) + + self.add_module('linear', linear) + self.complex = complex + self.padding_mode = padding_mode + + self.decoders = nn.ModuleList(self.decoders) + self.encoders = nn.ModuleList(self.encoders) + self.se_layers_enc = nn.ModuleList(self.se_layers_enc) + self.se_layers_dec = nn.ModuleList(self.se_layers_dec) + self.fsmn_enc = nn.ModuleList(self.fsmn_enc) + self.fsmn_dec = nn.ModuleList(self.fsmn_dec) + + def forward(self, inputs): + x = inputs + # go down + xs = [] + xs_se = [] + xs_se.append(x) + for i, encoder in enumerate(self.encoders): + xs.append(x) + if i > 0: + x = self.fsmn_enc[i](x) + x = encoder(x) + xs_se.append(self.se_layers_enc[i](x)) + # xs : x0=input x1 ... x9 + x = self.fsmn(x) + + p = x + for i, decoder in enumerate(self.decoders): + p = decoder(p) + if i < self.model_length - 1: + p = self.fsmn_dec[i](p) + if i == self.model_length - 1: + break + if i < self.model_length - 2: + p = self.se_layers_dec[i](p) + p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1) + + # cmp_spec: [12, 1, 513, 64, 2] + cmp_spec = self.linear(p) + return cmp_spec + + def set_size(self, model_complexity, model_depth=20, input_channels=1): + + if model_depth == 14: + self.enc_channels = [ + input_channels, 128, 128, 128, 128, 128, 128, 128 + ] + self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), + (5, 2), (2, 2)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1] + self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), + (5, 2), (5, 2)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + + elif model_depth == 10: + self.enc_channels = [ + input_channels, + 16, + 32, + 64, + 128, + 256, + ] + self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + self.dec_channels = [128, 128, 64, 32, 16, 1] + self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + elif model_depth == 20: + self.enc_channels = [ + input_channels, model_complexity, model_complexity, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, 128 + ] + + self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3), + (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)] + + self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1), + (2, 2), (2, 1), (2, 2), (2, 1)] + + self.enc_paddings = [ + (3, 0), + (0, 3), + None, # (0, 2), + None, + None, # (3,1), + None, # (3,1), + None, # (1,2), + None, + None, + None + ] + + self.dec_channels = [ + 0, model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2 + ] + + self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3), + (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)] + + self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), + (2, 1), (2, 2), (1, 1), (1, 1)] + + self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1), + (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)] + else: + raise ValueError('Unknown model depth : {}'.format(model_depth)) diff --git a/modelscope/models/audio/asr/__init__.py b/modelscope/models/audio/asr/__init__.py new file mode 100644 index 00000000..d6ca2c4e --- /dev/null +++ b/modelscope/models/audio/asr/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .generic_automatic_speech_recognition import GenericAutomaticSpeechRecognition + +else: + _import_structure = { + 'generic_automatic_speech_recognition': + ['GenericAutomaticSpeechRecognition'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py new file mode 100644 index 00000000..aebc6751 --- /dev/null +++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Frameworks, Tasks + +__all__ = ['GenericAutomaticSpeechRecognition'] + + +@MODELS.register_module( + Tasks.auto_speech_recognition, module_name=Models.generic_asr) +class GenericAutomaticSpeechRecognition(Model): + + def __init__(self, model_dir: str, am_model_name: str, + model_config: Dict[str, Any], *args, **kwargs): + """initialize the info of model. + + Args: + model_dir (str): the model path. + am_model_name (str): the am model name from configuration.json + model_config (Dict[str, Any]): the detail config about model from configuration.json + """ + super().__init__(model_dir, am_model_name, model_config, *args, + **kwargs) + self.model_cfg = { + # the recognition model dir path + 'model_workspace': model_dir, + # the am model name + 'am_model': am_model_name, + # the am model file path + 'am_model_path': os.path.join(model_dir, am_model_name), + # the recognition model config dict + 'model_config': model_config + } + + def forward(self) -> Dict[str, Any]: + """preload model and return the info of the model + """ + if self.model_cfg['model_config']['type'] == Frameworks.tf: + from easyasr import asr_inference_paraformer_tf + if hasattr(asr_inference_paraformer_tf, 'preload'): + model_workspace = self.model_cfg['model_workspace'] + model_path = os.path.join(model_workspace, + self.model_cfg['am_model']) + vocab_path = os.path.join( + model_workspace, + self.model_cfg['model_config']['vocab_file']) + sampled_ids = 'seq2seq/sampled_ids' + sampled_lengths = 'seq2seq/sampled_lengths' + if 'sampled_ids' in self.model_cfg['model_config']: + sampled_ids = self.model_cfg['model_config']['sampled_ids'] + if 'sampled_lengths' in self.model_cfg['model_config']: + sampled_lengths = self.model_cfg['model_config'][ + 'sampled_lengths'] + asr_inference_paraformer_tf.preload( + ngpu=1, + asr_model_file=model_path, + vocab_file=vocab_path, + sampled_ids=sampled_ids, + sampled_lengths=sampled_lengths) + + return self.model_cfg diff --git a/modelscope/models/audio/kws/__init__.py b/modelscope/models/audio/kws/__init__.py new file mode 100644 index 00000000..dd183fe5 --- /dev/null +++ b/modelscope/models/audio/kws/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .generic_key_word_spotting import GenericKeyWordSpotting + from .farfield.model import FSMNSeleNetV2Decorator + +else: + _import_structure = { + 'generic_key_word_spotting': ['GenericKeyWordSpotting'], + 'farfield.model': ['FSMNSeleNetV2Decorator'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/audio/kws/farfield/__init__.py b/modelscope/models/audio/kws/farfield/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/kws/farfield/fsmn.py b/modelscope/models/audio/kws/farfield/fsmn.py new file mode 100644 index 00000000..e06d7911 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/fsmn.py @@ -0,0 +1,497 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32, + printNeonMatrix, printNeonVector) + +DEBUG = False + + +def to_kaldi_matrix(np_mat): + """ function that transform as str numpy mat to standard kaldi str matrix + + Args: + np_mat: numpy mat + + Returns: str + """ + np.set_printoptions(threshold=np.inf, linewidth=np.nan) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +def print_tensor(torch_tensor): + """ print torch tensor for debug + + Args: + torch_tensor: a tensor + """ + re_str = '' + x = torch_tensor.detach().squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + print(re_str) + + +class LinearTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(LinearTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim, bias=False) + + self.debug = False + self.dataout = None + + def forward(self, input): + output = self.linear(input) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + printNeonMatrix(self.linear.weight) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class AffineTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear = nn.Linear(input_dim, output_dim) + + self.debug = False + self.dataout = None + + def forward(self, input): + output = self.linear(input) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + printNeonMatrix(self.linear.weight) + printNeonVector(self.linear.bias) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1 1 0\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class Fsmn(nn.Module): + """ + FSMN implementation. + """ + + def __init__(self, + input_dim, + output_dim, + lorder=None, + rorder=None, + lstride=None, + rstride=None): + super(Fsmn, self).__init__() + + self.dim = input_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.conv_left = nn.Conv2d( + self.dim, + self.dim, (lorder, 1), + dilation=(lstride, 1), + groups=self.dim, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + self.dim, + self.dim, (rorder, 1), + dilation=(rstride, 1), + groups=self.dim, + bias=False) + else: + self.conv_right = None + + self.debug = False + self.dataout = None + + def forward(self, input): + x = torch.unsqueeze(input, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + out = x_per + self.conv_left(y_left) + self.conv_right(y_right) + else: + out = x_per + self.conv_left(y_left) + + out1 = out.permute(0, 3, 2, 1) + output = out1.squeeze(1) + + if self.debug: + self.dataout = output + + return output + + def print_model(self): + tmpw = self.conv_left.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + if self.conv_right is not None: + tmpw = self.conv_right.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += ' %d %d %d %d %d 0\n' % ( + 1, self.lorder, self.rorder, self.lstride, self.rstride) + + lfiters = self.state_dict()['conv_left.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + + if self.conv_right is not None: + rfiters = self.state_dict()['conv_right.weight'] + x = (rfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + re_str += '\n' + + return re_str + + +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + + def forward(self, input): + return self.relu(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += '\n' + return re_str + + +class FSMNNet(nn.Module): + """ + FSMN net for keyword spotting + """ + + def __init__(self, + input_dim=200, + linear_dim=128, + proj_dim=128, + lorder=10, + rorder=1, + num_syn=5, + fsmn_layers=4): + """ + Args: + input_dim: input dimension + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + num_syn: output dimension + fsmn_layers: no. of sequential fsmn layers + """ + super(FSMNNet, self).__init__() + + self.input_dim = input_dim + self.linear_dim = linear_dim + self.proj_dim = proj_dim + self.lorder = lorder + self.rorder = rorder + self.num_syn = num_syn + self.fsmn_layers = fsmn_layers + + self.linear1 = AffineTransform(input_dim, linear_dim) + self.relu = RectifiedLinear(linear_dim, linear_dim) + + self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder, + fsmn_layers) + + self.linear2 = AffineTransform(linear_dim, num_syn) + + @staticmethod + def _build_repeats(linear_dim=136, + proj_dim=68, + lorder=3, + rorder=2, + fsmn_layers=5): + repeats = [ + nn.Sequential( + LinearTransform(linear_dim, proj_dim), + Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1), + AffineTransform(proj_dim, linear_dim), + RectifiedLinear(linear_dim, linear_dim)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) + + def forward(self, input): + x1 = self.linear1(input) + x2 = self.relu(x1) + x3 = self.fsmn(x2) + x4 = self.linear2(x3) + return x4 + + def print_model(self): + self.linear1.print_model() + + for layer in self.fsmn: + layer[0].print_model() + layer[1].print_model() + layer[2].print_model() + + self.linear2.print_model() + + def print_header(self): + # + # write total header + # + header = [0.0] * HEADER_BLOCK_SIZE * 4 + # numins + header[0] = 0.0 + # numouts + header[1] = 0.0 + # dimins + header[2] = self.input_dim + # dimouts + header[3] = self.num_syn + # numlayers + header[4] = 3 + + # + # write each layer's header + # + hidx = 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_RELU.value) + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_SEQUENTIAL_FSMN.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder + header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder + header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers + header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_SOFTMAX.value) + + for h in header: + print(f32ToI32(h)) + + def to_kaldi_nnet(self): + re_str = '' + re_str += '\n' + re_str += self.linear1.to_kaldi_nnet() + re_str += self.relu.to_kaldi_nnet() + + for fsmn in self.fsmn: + re_str += fsmn[0].to_kaldi_nnet() + re_str += fsmn[1].to_kaldi_nnet() + re_str += fsmn[2].to_kaldi_nnet() + re_str += fsmn[3].to_kaldi_nnet() + + re_str += self.linear2.to_kaldi_nnet() + re_str += ' %d %d\n' % (self.num_syn, self.num_syn) + re_str += '\n' + re_str += '\n' + + return re_str + + +class DFSMN(nn.Module): + """ + One deep fsmn layer + """ + + def __init__(self, + dimproj=64, + dimlinear=128, + lorder=20, + rorder=1, + lstride=1, + rstride=1): + """ + Args: + dimproj: projection dimension, input and output dimension of memory blocks + dimlinear: dimension of mapping layer + lorder: left order + rorder: right order + lstride: left stride + rstride: right stride + """ + super(DFSMN, self).__init__() + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.expand = AffineTransform(dimproj, dimlinear) + self.shrink = LinearTransform(dimlinear, dimproj) + + self.conv_left = nn.Conv2d( + dimproj, + dimproj, (lorder, 1), + dilation=(lstride, 1), + groups=dimproj, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + dimproj, + dimproj, (rorder, 1), + dilation=(rstride, 1), + groups=dimproj, + bias=False) + else: + self.conv_right = None + + def forward(self, input): + f1 = F.relu(self.expand(input)) + p1 = self.shrink(f1) + + x = torch.unsqueeze(p1, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + out = x_per + self.conv_left(y_left) + self.conv_right(y_right) + else: + out = x_per + self.conv_left(y_left) + + out1 = out.permute(0, 3, 2, 1) + output = input + out1.squeeze(1) + + return output + + def print_model(self): + self.expand.print_model() + self.shrink.print_model() + + tmpw = self.conv_left.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + if self.conv_right is not None: + tmpw = self.conv_right.weight + tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) + for j in range(tmpw.shape[0]): + tmpwm[:, j] = tmpw[j, 0, :, 0] + + printNeonMatrix(tmpwm) + + +def build_dfsmn_repeats(linear_dim=128, + proj_dim=64, + lorder=20, + rorder=1, + fsmn_layers=6): + """ + build stacked dfsmn layers + Args: + linear_dim: + proj_dim: + lorder: + rorder: + fsmn_layers: + + Returns: + + """ + repeats = [ + nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) diff --git a/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py b/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py new file mode 100644 index 00000000..8af16cc9 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/fsmn_sele_v2.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear +from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32 + + +class FSMNUnit(nn.Module): + """ A multi-channel fsmn unit + + """ + + def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1): + """ + Args: + dimlinear: input / output dimension + dimproj: fsmn input / output dimension + lorder: left ofder + rorder: right order + """ + super(FSMNUnit, self).__init__() + + self.shrink = LinearTransform(dimlinear, dimproj) + self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1) + self.expand = AffineTransform(dimproj, dimlinear) + + self.debug = False + self.dataout = None + + ''' + batch, time, channel, feature + ''' + + def forward(self, input): + if torch.cuda.is_available(): + out = torch.zeros(input.shape).cuda() + else: + out = torch.zeros(input.shape) + + for n in range(input.shape[2]): + out1 = self.shrink(input[:, :, n, :]) + out2 = self.fsmn(out1) + out[:, :, n, :] = F.relu(self.expand(out2)) + + if self.debug: + self.dataout = out + + return out + + def print_model(self): + self.shrink.print_model() + self.fsmn.print_model() + self.expand.print_model() + + def to_kaldi_nnet(self): + re_str = self.shrink.to_kaldi_nnet() + re_str += self.fsmn.to_kaldi_nnet() + re_str += self.expand.to_kaldi_nnet() + + relu = RectifiedLinear(self.expand.linear.out_features, + self.expand.linear.out_features) + re_str += relu.to_kaldi_nnet() + + return re_str + + +class FSMNSeleNetV2(nn.Module): + """ FSMN model with channel selection. + """ + + def __init__(self, + input_dim=120, + linear_dim=128, + proj_dim=64, + lorder=20, + rorder=1, + num_syn=5, + fsmn_layers=5, + sele_layer=0): + """ + Args: + input_dim: input dimension + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + num_syn: output dimension + fsmn_layers: no. of fsmn units + sele_layer: channel selection layer index + """ + super(FSMNSeleNetV2, self).__init__() + + self.sele_layer = sele_layer + + self.featmap = AffineTransform(input_dim, linear_dim) + + self.mem = [] + for i in range(fsmn_layers): + unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder) + self.mem.append(unit) + self.add_module('mem_{:d}'.format(i), unit) + + self.decision = AffineTransform(linear_dim, num_syn) + + def forward(self, input): + # multi-channel feature mapping + if torch.cuda.is_available(): + x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], + self.featmap.linear.out_features).cuda() + else: + x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], + self.featmap.linear.out_features) + + for n in range(input.shape[2]): + x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :])) + + for i, unit in enumerate(self.mem): + y = unit(x) + + # perform channel selection + if i == self.sele_layer: + pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1)) + y = pool(y) + + x = y + + # remove channel dimension + y = torch.squeeze(y, -2) + z = self.decision(y) + + return z + + def print_model(self): + self.featmap.print_model() + + for unit in self.mem: + unit.print_model() + + self.decision.print_model() + + def print_header(self): + ''' + get FSMN params + ''' + input_dim = self.featmap.linear.in_features + linear_dim = self.featmap.linear.out_features + proj_dim = self.mem[0].shrink.linear.out_features + lorder = self.mem[0].fsmn.conv_left.kernel_size[0] + rorder = 0 + if self.mem[0].fsmn.conv_right is not None: + rorder = self.mem[0].fsmn.conv_right.kernel_size[0] + + num_syn = self.decision.linear.out_features + fsmn_layers = len(self.mem) + + # no. of output channels, 0.0 means the same as numins + # numouts = 0.0 + numouts = 1.0 + + # + # write total header + # + header = [0.0] * HEADER_BLOCK_SIZE * 4 + # numins + header[0] = 0.0 + # numouts + header[1] = numouts + # dimins + header[2] = input_dim + # dimouts + header[3] = num_syn + # numlayers + header[4] = 3 + + # + # write each layer's header + # + hidx = 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_RELU.value) + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_SEQUENTIAL_FSMN.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 + header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim + header[HEADER_BLOCK_SIZE * hidx + 4] = lorder + header[HEADER_BLOCK_SIZE * hidx + 5] = rorder + header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers + if numouts == 1.0: + header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer) + else: + header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 + hidx += 1 + + header[HEADER_BLOCK_SIZE * hidx + 0] = float( + LayerType.LAYER_DENSE.value) + header[HEADER_BLOCK_SIZE * hidx + 1] = numouts + header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim + header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn + header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 + header[HEADER_BLOCK_SIZE * hidx + 5] = float( + ActivationType.ACTIVATION_SOFTMAX.value) + + for h in header: + print(f32ToI32(h)) + + def to_kaldi_nnet(self): + re_str = '\n' + + re_str = self.featmap.to_kaldi_nnet() + + relu = RectifiedLinear(self.featmap.linear.out_features, + self.featmap.linear.out_features) + re_str += relu.to_kaldi_nnet() + + for unit in self.mem: + re_str += unit.to_kaldi_nnet() + + re_str += self.decision.to_kaldi_nnet() + + re_str += ' %d %d\n' % (self.decision.linear.out_features, + self.decision.linear.out_features) + re_str += '\n' + re_str += '\n' + + return re_str diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py new file mode 100644 index 00000000..af1c0a27 --- /dev/null +++ b/modelscope/models/audio/kws/farfield/model.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +from typing import Dict, Optional + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.constant import Tasks +from .fsmn_sele_v2 import FSMNSeleNetV2 + + +@MODELS.register_module( + Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield) +class FSMNSeleNetV2Decorator(TorchModel): + r""" A decorator of FSMNSeleNetV2 for integrating into modelscope framework """ + + MODEL_TXT = 'model.txt' + SC_CONFIG = 'sound_connect.conf' + + def __init__(self, + model_dir: str, + training: Optional[bool] = False, + *args, + **kwargs): + """initialize the dfsmn model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if training: + self.model = FSMNSeleNetV2(*args, **kwargs) + else: + sc_config_file = os.path.join(model_dir, self.SC_CONFIG) + model_txt_file = os.path.join(model_dir, self.MODEL_TXT) + self.tmp_dir = tempfile.TemporaryDirectory() + new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG) + + self._sc = None + if os.path.exists(model_txt_file): + conf_dict = dict(mode=56542, kws_model=model_txt_file) + update_conf(sc_config_file, new_config_file, conf_dict) + import py_sound_connect + self._sc = py_sound_connect.SoundConnect(new_config_file) + self.size_in = self._sc.bytesPerBlockIn() + self.size_out = self._sc.bytesPerBlockOut() + else: + raise Exception( + f'Invalid model directory! Failed to load model file: {model_txt_file}.' + ) + + def __del__(self): + self.tmp_dir.cleanup() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + return self.model.forward(input) + + def forward_decode(self, data: bytes): + result = {'pcm': self._sc.process(data, self.size_out)} + state = self._sc.kwsState() + if state == 2: + result['kws'] = { + 'keyword': + self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()), + 'offset': self._sc.kwsKeywordOffset(), + 'length': self._sc.kwsKeywordLength(), + 'confidence': self._sc.kwsConfidence() + } + return result diff --git a/modelscope/models/audio/kws/farfield/model_def.py b/modelscope/models/audio/kws/farfield/model_def.py new file mode 100644 index 00000000..be9cca2c --- /dev/null +++ b/modelscope/models/audio/kws/farfield/model_def.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import struct +from enum import Enum + +HEADER_BLOCK_SIZE = 10 + + +class LayerType(Enum): + LAYER_DENSE = 1 + LAYER_GRU = 2 + LAYER_ATTENTION = 3 + LAYER_FSMN = 4 + LAYER_SEQUENTIAL_FSMN = 5 + LAYER_FSMN_SELE = 6 + LAYER_GRU_ATTENTION = 7 + LAYER_DFSMN = 8 + + +class ActivationType(Enum): + ACTIVATION_NONE = 0 + ACTIVATION_RELU = 1 + ACTIVATION_TANH = 2 + ACTIVATION_SIGMOID = 3 + ACTIVATION_SOFTMAX = 4 + ACTIVATION_LOGSOFTMAX = 5 + + +def f32ToI32(f): + """ + print layer + """ + bs = struct.pack('f', f) + + ba = bytearray() + ba.append(bs[0]) + ba.append(bs[1]) + ba.append(bs[2]) + ba.append(bs[3]) + + return struct.unpack('i', ba)[0] + + +def printNeonMatrix(w): + """ + print matrix with neon padding + """ + numrows, numcols = w.shape + numnecols = math.ceil(numcols / 4) + + for i in range(numrows): + for j in range(numcols): + print(f32ToI32(w[i, j])) + + for j in range(numnecols * 4 - numcols): + print(0) + + +def printNeonVector(b): + """ + print vector with neon padding + """ + size = b.shape[0] + nesize = math.ceil(size / 4) + + for i in range(size): + print(f32ToI32(b[i])) + + for i in range(nesize * 4 - size): + print(0) + + +def printDense(layer): + """ + save dense layer + """ + statedict = layer.state_dict() + printNeonMatrix(statedict['weight']) + printNeonVector(statedict['bias']) + + +def printGRU(layer): + """ + save gru layer + """ + statedict = layer.state_dict() + weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']] + bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']] + numins, numouts = weight[0].shape + numins = numins // 3 + + # output input weights + w_rx = weight[0][:numins, :] + w_zx = weight[0][numins:numins * 2, :] + w_x = weight[0][numins * 2:, :] + printNeonMatrix(w_zx) + printNeonMatrix(w_rx) + printNeonMatrix(w_x) + + # output recurrent weights + w_rh = weight[1][:numins, :] + w_zh = weight[1][numins:numins * 2, :] + w_h = weight[1][numins * 2:, :] + printNeonMatrix(w_zh) + printNeonMatrix(w_rh) + printNeonMatrix(w_h) + + # output input bias + b_rx = bias[0][:numins] + b_zx = bias[0][numins:numins * 2] + b_x = bias[0][numins * 2:] + printNeonVector(b_zx) + printNeonVector(b_rx) + printNeonVector(b_x) + + # output recurrent bias + b_rh = bias[1][:numins] + b_zh = bias[1][numins:numins * 2] + b_h = bias[1][numins * 2:] + printNeonVector(b_zh) + printNeonVector(b_rh) + printNeonVector(b_h) diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py new file mode 100644 index 00000000..2f70327d --- /dev/null +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['GenericKeyWordSpotting'] + + +@MODELS.register_module(Tasks.keyword_spotting, module_name=Models.kws_kwsbp) +class GenericKeyWordSpotting(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the info of model. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.model_cfg = { + 'model_workspace': model_dir, + 'config_path': os.path.join(model_dir, 'config.yaml') + } + + def forward(self) -> Dict[str, Any]: + """return the info of the model + """ + return self.model_cfg diff --git a/modelscope/models/audio/tts/__init__.py b/modelscope/models/audio/tts/__init__.py new file mode 100644 index 00000000..8f1af95c --- /dev/null +++ b/modelscope/models/audio/tts/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sambert_hifi import SambertHifigan + +else: + _import_structure = { + 'sambert_hifi': ['SambertHifigan'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/audio/tts/models/__init__.py b/modelscope/models/audio/tts/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/models/datasets/__init__.py b/modelscope/models/audio/tts/models/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py b/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py new file mode 100644 index 00000000..cc47d0c4 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/kantts_data4fs.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os + +import json +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from modelscope.utils.logger import get_logger +from .units import KanTtsLinguisticUnit + +logger = get_logger() + + +class KanTtsText2MelDataset(Dataset): + + def __init__(self, metadata_filename, config_filename, cache=False): + super(KanTtsText2MelDataset, self).__init__() + + self.cache = cache + + with open(config_filename) as f: + self._config = json.loads(f.read()) + + # Load metadata: + self._datadir = os.path.dirname(metadata_filename) + with open(metadata_filename, encoding='utf-8') as f: + self._metadata = [line.strip().split('|') for line in f] + self._length_lst = [int(x[2]) for x in self._metadata] + hours = sum( + self._length_lst) * self._config['audio']['frame_shift_ms'] / ( + 3600 * 1000) + + logger.info('Loaded metadata for %d examples (%.2f hours)' % + (len(self._metadata), hours)) + logger.info('Minimum length: %d, Maximum length: %d' % + (min(self._length_lst), max(self._length_lst))) + + self.ling_unit = KanTtsLinguisticUnit(config_filename) + self.pad_executor = KanTtsText2MelPad() + + self.r = self._config['am']['outputs_per_step'] + self.num_mels = self._config['am']['num_mels'] + + if 'adv' in self._config: + self.feat_window = self._config['adv']['random_window'] + else: + self.feat_window = None + logger.info(self.feat_window) + + self.data_cache = [ + self.cache_load(i) for i in tqdm(range(self.__len__())) + ] if self.cache else [] + + def get_frames_lst(self): + return self._length_lst + + def __getitem__(self, index): + if self.cache: + sample = self.data_cache[index] + return sample + + return self.cache_load(index) + + def cache_load(self, index): + sample = {} + + meta = self._metadata[index] + + sample['utt_id'] = meta[0] + + sample['mel_target'] = np.load(os.path.join( + self._datadir, meta[1]))[:, :self.num_mels] + sample['output_length'] = len(sample['mel_target']) + + lfeat_symbol = meta[3] + sample['ling'] = self.ling_unit.encode_symbol_sequence(lfeat_symbol) + + sample['duration'] = np.load(os.path.join(self._datadir, meta[4])) + + sample['pitch_contour'] = np.load(os.path.join(self._datadir, meta[5])) + + sample['energy_contour'] = np.load( + os.path.join(self._datadir, meta[6])) + + return sample + + def __len__(self): + return len(self._metadata) + + def collate_fn(self, batch): + data_dict = {} + + max_input_length = max((len(x['ling'][0]) for x in batch)) + + # pure linguistic info: sy|tone|syllable_flag|word_segment + + # sy + lfeat_type = self.ling_unit._lfeat_type_list[0] + inputs_sy = self.pad_executor._prepare_scalar_inputs( + [x['ling'][0] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + # tone + lfeat_type = self.ling_unit._lfeat_type_list[1] + inputs_tone = self.pad_executor._prepare_scalar_inputs( + [x['ling'][1] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # syllable_flag + lfeat_type = self.ling_unit._lfeat_type_list[2] + inputs_syllable_flag = self.pad_executor._prepare_scalar_inputs( + [x['ling'][2] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # word_segment + lfeat_type = self.ling_unit._lfeat_type_list[3] + inputs_ws = self.pad_executor._prepare_scalar_inputs( + [x['ling'][3] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # emotion category + lfeat_type = self.ling_unit._lfeat_type_list[4] + data_dict['input_emotions'] = self.pad_executor._prepare_scalar_inputs( + [x['ling'][4] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + # speaker category + lfeat_type = self.ling_unit._lfeat_type_list[5] + data_dict['input_speakers'] = self.pad_executor._prepare_scalar_inputs( + [x['ling'][5] for x in batch], max_input_length, + self.ling_unit._sub_unit_pad[lfeat_type]).long() + + data_dict['input_lings'] = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2) + + data_dict['valid_input_lengths'] = torch.as_tensor( + [len(x['ling'][0]) - 1 for x in batch], dtype=torch.long + ) # There is one '~' in the last of symbol sequence. We put length-1 for calculation. + + data_dict['valid_output_lengths'] = torch.as_tensor( + [x['output_length'] for x in batch], dtype=torch.long) + max_output_length = torch.max(data_dict['valid_output_lengths']).item() + max_output_round_length = self.pad_executor._round_up( + max_output_length, self.r) + + if self.feat_window is not None: + active_feat_len = np.minimum(max_output_round_length, + self.feat_window) + if active_feat_len < self.feat_window: + max_output_round_length = self.pad_executor._round_up( + self.feat_window, self.r) + active_feat_len = self.feat_window + + max_offsets = [x['output_length'] - active_feat_len for x in batch] + feat_offsets = [ + np.random.randint(0, np.maximum(1, offset)) + for offset in max_offsets + ] + feat_offsets = torch.from_numpy( + np.asarray(feat_offsets, dtype=np.int32)).long() + data_dict['feat_offsets'] = feat_offsets + + data_dict['mel_targets'] = self.pad_executor._prepare_targets( + [x['mel_target'] for x in batch], max_output_round_length, 0.0) + data_dict['durations'] = self.pad_executor._prepare_durations( + [x['duration'] for x in batch], max_input_length, + max_output_round_length) + + data_dict['pitch_contours'] = self.pad_executor._prepare_scalar_inputs( + [x['pitch_contour'] for x in batch], max_input_length, + 0.0).float() + data_dict[ + 'energy_contours'] = self.pad_executor._prepare_scalar_inputs( + [x['energy_contour'] for x in batch], max_input_length, + 0.0).float() + + data_dict['utt_ids'] = [x['utt_id'] for x in batch] + + return data_dict + + +class KanTtsText2MelPad(object): + + def __init__(self): + super(KanTtsText2MelPad, self).__init__() + pass + + def _pad1D(self, x, length, pad): + return np.pad( + x, (0, length - x.shape[0]), mode='constant', constant_values=pad) + + def _pad2D(self, x, length, pad): + return np.pad( + x, [(0, length - x.shape[0]), (0, 0)], + mode='constant', + constant_values=pad) + + def _pad_durations(self, duration, max_in_len, max_out_len): + framenum = np.sum(duration) + symbolnum = duration.shape[0] + if framenum < max_out_len: + padframenum = max_out_len - framenum + duration = np.insert( + duration, symbolnum, values=padframenum, axis=0) + duration = np.insert( + duration, + symbolnum + 1, + values=[0] * (max_in_len - symbolnum - 1), + axis=0) + else: + if symbolnum < max_in_len: + duration = np.insert( + duration, + symbolnum, + values=[0] * (max_in_len - symbolnum), + axis=0) + return duration + + def _round_up(self, x, multiple): + remainder = x % multiple + return x if remainder == 0 else x + multiple - remainder + + def _prepare_scalar_inputs(self, inputs, max_len, pad): + return torch.from_numpy( + np.stack([self._pad1D(x, max_len, pad) for x in inputs])) + + def _prepare_targets(self, targets, max_len, pad): + return torch.from_numpy( + np.stack([self._pad2D(t, max_len, pad) for t in targets])).float() + + def _prepare_durations(self, durations, max_in_len, max_out_len): + return torch.from_numpy( + np.stack([ + self._pad_durations(t, max_in_len, max_out_len) + for t in durations + ])).long() diff --git a/modelscope/models/audio/tts/models/datasets/samplers.py b/modelscope/models/audio/tts/models/datasets/samplers.py new file mode 100644 index 00000000..0657fa8a --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/samplers.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import random + +import torch +from torch import distributed as dist +from torch.utils.data import Sampler + + +class LenSortGroupPoolSampler(Sampler): + + def __init__(self, data_source, length_lst, group_size): + super(LenSortGroupPoolSampler, self).__init__(data_source) + + self.data_source = data_source + self.length_lst = length_lst + self.group_size = group_size + + self.num = len(self.length_lst) + self.buckets = self.num // group_size + + def __iter__(self): + + def getkey(item): + return item[1] + + random_lst = torch.randperm(self.num).tolist() + random_len_lst = [(i, self.length_lst[i]) for i in random_lst] + + # Bucket examples based on similar output sequence length for efficiency: + groups = [ + random_len_lst[i:i + self.group_size] + for i in range(0, self.num, self.group_size) + ] + if (self.num % self.group_size): + groups.append(random_len_lst[self.buckets * self.group_size:-1]) + + indices = [] + + for group in groups: + group.sort(key=getkey, reverse=True) + for item in group: + indices.append(item[0]) + + return iter(indices) + + def __len__(self): + return len(self.data_source) + + +class DistributedLenSortGroupPoolSampler(Sampler): + + def __init__(self, + dataset, + length_lst, + group_size, + num_replicas=None, + rank=None, + shuffle=True): + super(DistributedLenSortGroupPoolSampler, self).__init__(dataset) + + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + 'modelscope error: Requires distributed package to be available' + ) + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + 'modelscope error: Requires distributed package to be available' + ) + rank = dist.get_rank() + self.dataset = dataset + self.length_lst = length_lst + self.group_size = group_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.buckets = self.num_samples // group_size + self.shuffle = shuffle + + def __iter__(self): + + def getkey(item): + return item[1] + + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + random_len_lst = [(i, self.length_lst[i]) for i in indices] + + # Bucket examples based on similar output sequence length for efficiency: + groups = [ + random_len_lst[i:i + self.group_size] + for i in range(0, self.num_samples, self.group_size) + ] + if (self.num_samples % self.group_size): + groups.append(random_len_lst[self.buckets * self.group_size:-1]) + + new_indices = [] + + for group in groups: + group.sort(key=getkey, reverse=True) + for item in group: + new_indices.append(item[0]) + + return iter(new_indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/modelscope/models/audio/tts/models/datasets/units/__init__.py b/modelscope/models/audio/tts/models/datasets/units/__init__.py new file mode 100644 index 00000000..4d03df04 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .ling_unit import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/datasets/units/cleaners.py b/modelscope/models/audio/tts/models/datasets/units/cleaners.py new file mode 100644 index 00000000..07d4fbdb --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/cleaners.py @@ -0,0 +1,88 @@ +# from https://github.com/keithito/tacotron +# Cleaners are transformations that run over the input text at both training and eval time. +# +# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +# hyperparameter. Some cleaners are English-specific. You'll typically want to use: +# 1. "english_cleaners" for English text +# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using +# the Unidecode library (https://pypi.python.org/pypi/Unidecode) +# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update +# the symbols in symbols.py to match your data). + +import re + +from unidecode import unidecode + +from .numbers import normalize_numbers + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) + for x in [('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), ]] # yapf:disable + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/modelscope/models/audio/tts/models/datasets/units/ling_unit.py b/modelscope/models/audio/tts/models/datasets/units/ling_unit.py new file mode 100644 index 00000000..3c211cc7 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/ling_unit.py @@ -0,0 +1,395 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import abc +import codecs +import os +import re +import shutil + +import json +import numpy as np + +from . import cleaners as cleaners + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception( + 'modelscope error: configuration cleaner unknown: %s' % name) + text = cleaner(text) + return text + + +class LinguisticBaseUnit(abc.ABC): + + def set_config_params(self, config_params): + self.config_params = config_params + + def save(self, config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +class KanTtsLinguisticUnit(LinguisticBaseUnit): + + def __init__(self, config, path, has_mask=True): + super(KanTtsLinguisticUnit, self).__init__() + + # special symbol + self._pad = '_' + self._eos = '~' + self._mask = '@[MASK]' + self._has_mask = has_mask + self._unit_config = config + self._path = path + + self._cleaner_names = [ + x.strip() for x in self._unit_config['cleaners'].split(',') + ] + self._lfeat_type_list = self._unit_config['lfeat_type_list'].strip( + ).split(',') + + self.build() + + def get_unit_size(self): + ling_unit_size = {} + ling_unit_size['sy'] = len(self.sy) + ling_unit_size['tone'] = len(self.tone) + ling_unit_size['syllable_flag'] = len(self.syllable_flag) + ling_unit_size['word_segment'] = len(self.word_segment) + + if 'emo_category' in self._lfeat_type_list: + ling_unit_size['emotion'] = len(self.emo_category) + if 'speaker_category' in self._lfeat_type_list: + ling_unit_size['speaker'] = len(self.speaker) + + return ling_unit_size + + def build(self): + + self._sub_unit_dim = {} + self._sub_unit_pad = {} + # sy sub-unit + _characters = '' + + _ch_symbols = [] + + sy_path = os.path.join(self._path, self._unit_config['sy']) + f = codecs.open(sy_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_symbols.append(line) + + _arpabet = ['@' + s for s in _ch_symbols] + + # Export all symbols: + self.sy = list(_characters) + _arpabet + [self._pad, self._eos] + if self._has_mask: + self.sy.append(self._mask) + self._sy_to_id = {s: i for i, s in enumerate(self.sy)} + self._id_to_sy = {i: s for i, s in enumerate(self.sy)} + self._sub_unit_dim['sy'] = len(self.sy) + self._sub_unit_pad['sy'] = self._sy_to_id['_'] + + # tone sub-unit + _characters = '' + + _ch_tones = [] + + tone_path = os.path.join(self._path, self._unit_config['tone']) + f = codecs.open(tone_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_tones.append(line) + + # Export all tones: + self.tone = list(_characters) + _ch_tones + [self._pad, self._eos] + if self._has_mask: + self.tone.append(self._mask) + self._tone_to_id = {s: i for i, s in enumerate(self.tone)} + self._id_to_tone = {i: s for i, s in enumerate(self.tone)} + self._sub_unit_dim['tone'] = len(self.tone) + self._sub_unit_pad['tone'] = self._tone_to_id['_'] + + # syllable flag sub-unit + _characters = '' + + _ch_syllable_flags = [] + + sy_flag_path = os.path.join(self._path, + self._unit_config['syllable_flag']) + f = codecs.open(sy_flag_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_syllable_flags.append(line) + + # Export all syllable_flags: + self.syllable_flag = list(_characters) + _ch_syllable_flags + [ + self._pad, self._eos + ] + if self._has_mask: + self.syllable_flag.append(self._mask) + self._syllable_flag_to_id = { + s: i + for i, s in enumerate(self.syllable_flag) + } + self._id_to_syllable_flag = { + i: s + for i, s in enumerate(self.syllable_flag) + } + self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag) + self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id['_'] + + # word segment sub-unit + _characters = '' + + _ch_word_segments = [] + + ws_path = os.path.join(self._path, self._unit_config['word_segment']) + f = codecs.open(ws_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_word_segments.append(line) + + # Export all syllable_flags: + self.word_segment = list(_characters) + _ch_word_segments + [ + self._pad, self._eos + ] + if self._has_mask: + self.word_segment.append(self._mask) + self._word_segment_to_id = { + s: i + for i, s in enumerate(self.word_segment) + } + self._id_to_word_segment = { + i: s + for i, s in enumerate(self.word_segment) + } + self._sub_unit_dim['word_segment'] = len(self.word_segment) + self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_'] + + if 'emo_category' in self._lfeat_type_list: + # emotion category sub-unit + _characters = '' + + _ch_emo_types = [] + + emo_path = os.path.join(self._path, + self._unit_config['emo_category']) + f = codecs.open(emo_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_emo_types.append(line) + + self.emo_category = list(_characters) + _ch_emo_types + [ + self._pad, self._eos + ] + if self._has_mask: + self.emo_category.append(self._mask) + self._emo_category_to_id = { + s: i + for i, s in enumerate(self.emo_category) + } + self._id_to_emo_category = { + i: s + for i, s in enumerate(self.emo_category) + } + self._sub_unit_dim['emo_category'] = len(self.emo_category) + self._sub_unit_pad['emo_category'] = self._emo_category_to_id['_'] + + if 'speaker_category' in self._lfeat_type_list: + # speaker category sub-unit + _characters = '' + + _ch_speakers = [] + + speaker_path = os.path.join(self._path, + self._unit_config['speaker_category']) + f = codecs.open(speaker_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_speakers.append(line) + + # Export all syllable_flags: + self.speaker = list(_characters) + _ch_speakers + [ + self._pad, self._eos + ] + if self._has_mask: + self.speaker.append(self._mask) + self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)} + self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)} + self._sub_unit_dim['speaker_category'] = len(self._speaker_to_id) + self._sub_unit_pad['speaker_category'] = self._speaker_to_id['_'] + + def encode_symbol_sequence(self, lfeat_symbol): + lfeat_symbol = lfeat_symbol.strip().split(' ') + + lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list)) + for this_lfeat_symbol in lfeat_symbol: + this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( + '$') + index = 0 + while index < len(lfeat_symbol_separate): + lfeat_symbol_separate[index] = lfeat_symbol_separate[ + index] + this_lfeat_symbol[index] + ' ' + index = index + 1 + + input_and_label_data = [] + index = 0 + while index < len(self._lfeat_type_list): + sequence = self.encode_sub_unit( + lfeat_symbol_separate[index].strip(), + self._lfeat_type_list[index]) + sequence_array = np.asarray(sequence, dtype=np.int32) + input_and_label_data.append(sequence_array) + index = index + 1 + + return input_and_label_data + + def decode_symbol_sequence(self, sequence): + result = [] + for i, lfeat_type in enumerate(self._lfeat_type_list): + s = '' + sequence_item = sequence[i].tolist() + if lfeat_type == 'sy': + s = self.decode_sy(sequence_item) + elif lfeat_type == 'tone': + s = self.decode_tone(sequence_item) + elif lfeat_type == 'syllable_flag': + s = self.decode_syllable_flag(sequence_item) + elif lfeat_type == 'word_segment': + s = self.decode_word_segment(sequence_item) + elif lfeat_type == 'emo_category': + s = self.decode_emo_category(sequence_item) + elif lfeat_type == 'speaker_category': + s = self.decode_speaker_category(sequence_item) + else: + raise Exception( + 'modelscope error: configuration lfeat type(%s) unknown.' + % lfeat_type) + result.append('%s:%s' % (lfeat_type, s)) + + return result + + def encode_sub_unit(self, this_lfeat_symbol, lfeat_type): + sequence = [] + if lfeat_type == 'sy': + this_lfeat_symbol = this_lfeat_symbol.strip().split(' ') + this_lfeat_symbol_format = '' + index = 0 + while index < len(this_lfeat_symbol): + this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[ + index] + '}' + ' ' + index = index + 1 + sequence = self.encode_text(this_lfeat_symbol_format, + self._cleaner_names) + elif lfeat_type == 'tone': + sequence = self.encode_tone(this_lfeat_symbol) + elif lfeat_type == 'syllable_flag': + sequence = self.encode_syllable_flag(this_lfeat_symbol) + elif lfeat_type == 'word_segment': + sequence = self.encode_word_segment(this_lfeat_symbol) + elif lfeat_type == 'emo_category': + sequence = self.encode_emo_category(this_lfeat_symbol) + elif lfeat_type == 'speaker_category': + sequence = self.encode_speaker_category(this_lfeat_symbol) + else: + raise Exception( + 'modelscope error: configuration lfeat type(%s) unknown.' + % lfeat_type) + + return sequence + + def encode_text(self, text, cleaner_names): + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += self.encode_sy(_clean_text(text, cleaner_names)) + break + sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names)) + sequence += self.encode_arpanet(m.group(2)) + text = m.group(3) + + # Append EOS token + sequence.append(self._sy_to_id['~']) + return sequence + + def encode_sy(self, sy): + return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)] + + def decode_sy(self, id): + s = self._id_to_sy[id] + if len(s) > 1 and s[0] == '@': + s = s[1:] + return s + + def should_keep_sy(self, s): + return s in self._sy_to_id and s != '_' and s != '~' + + def encode_arpanet(self, text): + return self.encode_sy(['@' + s for s in text.split()]) + + def encode_tone(self, tone): + tones = tone.strip().split(' ') + sequence = [] + for this_tone in tones: + sequence.append(self._tone_to_id[this_tone]) + sequence.append(self._tone_to_id['~']) + return sequence + + def decode_tone(self, id): + return self._id_to_tone[id] + + def encode_syllable_flag(self, syllable_flag): + syllable_flags = syllable_flag.strip().split(' ') + sequence = [] + for this_syllable_flag in syllable_flags: + sequence.append(self._syllable_flag_to_id[this_syllable_flag]) + sequence.append(self._syllable_flag_to_id['~']) + return sequence + + def decode_syllable_flag(self, id): + return self._id_to_syllable_flag[id] + + def encode_word_segment(self, word_segment): + word_segments = word_segment.strip().split(' ') + sequence = [] + for this_word_segment in word_segments: + sequence.append(self._word_segment_to_id[this_word_segment]) + sequence.append(self._word_segment_to_id['~']) + return sequence + + def decode_word_segment(self, id): + return self._id_to_word_segment[id] + + def encode_emo_category(self, emo_type): + emo_categories = emo_type.strip().split(' ') + sequence = [] + for this_category in emo_categories: + sequence.append(self._emo_category_to_id[this_category]) + sequence.append(self._emo_category_to_id['~']) + return sequence + + def decode_emo_category(self, id): + return self._id_to_emo_category[id] + + def encode_speaker_category(self, speaker): + speakers = speaker.strip().split(' ') + sequence = [] + for this_speaker in speakers: + sequence.append(self._speaker_to_id[this_speaker]) + sequence.append(self._speaker_to_id['~']) + return sequence + + def decode_speaker_category(self, id): + return self._id_to_speaker[id] diff --git a/modelscope/models/audio/tts/models/datasets/units/numbers.py b/modelscope/models/audio/tts/models/datasets/units/numbers.py new file mode 100644 index 00000000..d8835059 --- /dev/null +++ b/modelscope/models/audio/tts/models/datasets/units/numbers.py @@ -0,0 +1,73 @@ +# The implementation is adopted from tacotron, +# made publicly available under the MIT License at https://github.com/keithito/tacotron + +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words( + num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/modelscope/models/audio/tts/models/models/__init__.py b/modelscope/models/audio/tts/models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/models/models/hifigan/__init__.py b/modelscope/models/audio/tts/models/models/hifigan/__init__.py new file mode 100644 index 00000000..ae9d10ea --- /dev/null +++ b/modelscope/models/audio/tts/models/models/hifigan/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .hifigan import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/models/hifigan/hifigan.py b/modelscope/models/audio/tts/models/models/hifigan/hifigan.py new file mode 100755 index 00000000..0f950539 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/hifigan/hifigan.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from https://github.com/jik876/hifi-gan + +from distutils.version import LooseVersion + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from modelscope.models.audio.tts.models.utils import get_padding, init_weights +from modelscope.utils.logger import get_logger + +logger = get_logger() +is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B). + + """ + if is_pytorch_17plus: + x_stft = torch.stft( + x, fft_size, hop_size, win_length, window, return_complex=False) + else: + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) + + +LRELU_SLOPE = 0.1 + + +def get_padding_casual(kernel_size, dilation=1): + return int(kernel_size * dilation - dilation) + + +class Conv1dCasual(torch.nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros'): + super(Conv1dCasual, self).__init__() + self.pad = padding + self.conv1d = weight_norm( + Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode)) + self.conv1d.apply(init_weights) + + def forward(self, x): # bdt + # described starting from the last dimension and moving forward. + x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant') + x = self.conv1d(x) + return x + + def remove_weight_norm(self): + remove_weight_norm(self.conv1d) + + +class ConvTranspose1dCausal(torch.nn.Module): + """CausalConvTranspose1d module with customized initialization.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=0): + """Initialize CausalConvTranspose1d module.""" + super(ConvTranspose1dCausal, self).__init__() + self.deconv = weight_norm( + ConvTranspose1d(in_channels, out_channels, kernel_size, stride)) + self.stride = stride + self.deconv.apply(init_weights) + self.pad = kernel_size - stride + + def forward(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, in_channels, T_in). + Returns: + Tensor: Output tensor (B, out_channels, T_out). + """ + # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant") + return self.deconv(x)[:, :, :-self.pad] + + def remove_weight_norm(self): + remove_weight_norm(self.deconv) + + +class ResBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=dilation[i], + padding=get_padding_casual(kernel_size, dilation[i])) + for i in range(len(dilation)) + ]) + + self.convs2 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding_casual(kernel_size, 1)) + for i in range(len(dilation)) + ]) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + layer.remove_weight_norm() + for layer in self.convs2: + layer.remove_weight_norm() + + +class Generator(torch.nn.Module): + + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + logger.info('num_kernels={}, num_upsamples={}'.format( + self.num_kernels, self.num_upsamples)) + self.conv_pre = Conv1dCasual( + 80, h.upsample_initial_channel, 7, 1, padding=7 - 1) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + self.repeat_ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(h.upsample_rates, h.upsample_kernel_sizes)): + upsample = nn.Sequential( + nn.Upsample(mode='nearest', scale_factor=u), + nn.LeakyReLU(LRELU_SLOPE), + Conv1dCasual( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + kernel_size=7, + stride=1, + padding=7 - 1)) + self.repeat_ups.append(upsample) + self.ups.append( + ConvTranspose1dCausal( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = torch.sin(x) + x + # transconv + x1 = F.leaky_relu(x, LRELU_SLOPE) + x1 = self.ups[i](x1) + # repeat + x2 = self.repeat_ups[i](x) + x = x1 + x2 + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + logger.info('Removing weight norm...') + for layer in self.ups: + layer.remove_weight_norm() + for layer in self.repeat_ups: + layer[-1].remove_weight_norm() + for layer in self.resblocks: + layer.remove_weight_norm() + self.conv_pre.remove_weight_norm() + self.conv_post.remove_weight_norm() diff --git a/modelscope/models/audio/tts/models/models/sambert/__init__.py b/modelscope/models/audio/tts/models/models/sambert/__init__.py new file mode 100644 index 00000000..f0bf5290 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .kantts_sambert import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/models/sambert/adaptors.py b/modelscope/models/audio/tts/models/models/sambert/adaptors.py new file mode 100644 index 00000000..c171a1db --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/adaptors.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import Prenet +from .fsmn import FsmnEncoderV2 + + +class LengthRegulator(nn.Module): + + def __init__(self, r=1): + super(LengthRegulator, self).__init__() + + self.r = r + + def forward(self, inputs, durations, masks=None): + reps = (durations + 0.5).long() + output_lens = reps.sum(dim=1) + max_len = output_lens.max() + reps_cumsum = torch.cumsum( + F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] + range_ = torch.arange(max_len).to(inputs.device)[None, :, None] + mult = ((reps_cumsum[:, :, :-1] <= range_) + & (reps_cumsum[:, :, 1:] > range_)) # yapf:disable + mult = mult.float() + out = torch.matmul(mult, inputs) + + if masks is not None: + out = out.masked_fill(masks.unsqueeze(-1), 0.0) + + seq_len = out.size(1) + padding = self.r - int(seq_len) % self.r + if (padding < self.r): + out = F.pad( + out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0) + out = out.transpose(1, 2) + + return out, output_lens + + +class VarRnnARPredictor(nn.Module): + + def __init__(self, cond_units, prenet_units, rnn_units): + super(VarRnnARPredictor, self).__init__() + + self.prenet = Prenet(1, prenet_units) + self.lstm = nn.LSTM( + prenet_units[-1] + cond_units, + rnn_units, + num_layers=2, + batch_first=True, + bidirectional=False) + self.fc = nn.Linear(rnn_units, 1) + + def forward(self, inputs, cond, h=None, masks=None): + x = torch.cat([self.prenet(inputs), cond], dim=-1) + # The input can also be a packed variable length sequence, + # here we just omit it for simplicity due to the mask and uni-directional lstm. + x, h_new = self.lstm(x, h) + + x = self.fc(x).squeeze(-1) + x = F.relu(x) + + if masks is not None: + x = x.masked_fill(masks, 0.0) + + return x, h_new + + def infer(self, cond, masks=None): + batch_size, length = cond.size(0), cond.size(1) + + output = [] + x = torch.zeros((batch_size, 1)).to(cond.device) + h = None + + for i in range(length): + x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h) + output.append(x) + + output = torch.cat(output, dim=-1) + + if masks is not None: + output = output.masked_fill(masks, 0.0) + + return output + + +class VarFsmnRnnNARPredictor(nn.Module): + + def __init__(self, in_dim, filter_size, fsmn_num_layers, num_memory_units, + ffn_inner_dim, dropout, shift, lstm_units): + super(VarFsmnRnnNARPredictor, self).__init__() + + self.fsmn = FsmnEncoderV2(filter_size, fsmn_num_layers, in_dim, + num_memory_units, ffn_inner_dim, dropout, + shift) + self.blstm = nn.LSTM( + num_memory_units, + lstm_units, + num_layers=1, + batch_first=True, + bidirectional=True) + self.fc = nn.Linear(2 * lstm_units, 1) + + def forward(self, inputs, masks=None): + input_lengths = None + if masks is not None: + input_lengths = torch.sum((~masks).float(), dim=1).long() + + x = self.fsmn(inputs, masks) + + if input_lengths is not None: + x = nn.utils.rnn.pack_padded_sequence( + x, + input_lengths.tolist(), + batch_first=True, + enforce_sorted=False) + x, _ = self.blstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True, total_length=inputs.size(1)) + else: + x, _ = self.blstm(x) + + x = self.fc(x).squeeze(-1) + + if masks is not None: + x = x.masked_fill(masks, 0.0) + + return x diff --git a/modelscope/models/audio/tts/models/models/sambert/base.py b/modelscope/models/audio/tts/models/models/sambert/base.py new file mode 100644 index 00000000..873aecbf --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/base.py @@ -0,0 +1,369 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ScaledDotProductAttention(nn.Module): + """ Scaled Dot-Product Attention """ + + def __init__(self, temperature, dropatt=0.0): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=2) + self.dropatt = nn.Dropout(dropatt) + + def forward(self, q, k, v, mask=None): + + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + attn = self.dropatt(attn) + output = torch.bmm(attn, v) + + return output, attn + + +class Prenet(nn.Module): + + def __init__(self, in_units, prenet_units, out_units=0): + super(Prenet, self).__init__() + + self.fcs = nn.ModuleList() + for in_dim, out_dim in zip([in_units] + prenet_units[:-1], + prenet_units): + self.fcs.append(nn.Linear(in_dim, out_dim)) + self.fcs.append(nn.ReLU()) + self.fcs.append(nn.Dropout(0.5)) + + if (out_units): + self.fcs.append(nn.Linear(prenet_units[-1], out_units)) + + def forward(self, input): + output = input + for layer in self.fcs: + output = layer(output) + return output + + +class MultiHeadSelfAttention(nn.Module): + """ Multi-Head SelfAttention module """ + + def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0): + super().__init__() + + self.n_head = n_head + self.d_head = d_head + self.d_in = d_in + self.d_model = d_model + + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head) + + self.attention = ScaledDotProductAttention( + temperature=np.power(d_head, 0.5), dropatt=dropatt) + + self.fc = nn.Linear(n_head * d_head, d_model) + + self.dropout = nn.Dropout(dropout) + + def forward(self, input, mask=None): + d_head, n_head = self.d_head, self.n_head + + sz_b, len_in, _ = input.size() + + residual = input + + x = self.layer_norm(input) + qkv = self.w_qkv(x) + q, k, v = qkv.chunk(3, -1) + + q = q.view(sz_b, len_in, n_head, d_head) + k = k.view(sz_b, len_in, n_head, d_head) + v = v.view(sz_b, len_in, n_head, d_head) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, + d_head) # (n*b) x l x d + + if mask is not None: + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_in, d_head) + output = (output.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + + output = self.dropout(self.fc(output)) + if (output.size(-1) == residual.size(-1)): + output = output + residual + + return output, attn + + +class PositionwiseConvFeedForward(nn.Module): + """ A two-feed-forward-layer module """ + + def __init__(self, + d_in, + d_hid, + kernel_size=(3, 1), + dropout_inner=0.1, + dropout=0.1): + super().__init__() + # Use Conv1D + # position-wise + self.w_1 = nn.Conv1d( + d_in, + d_hid, + kernel_size=kernel_size[0], + padding=(kernel_size[0] - 1) // 2, + ) + # position-wise + self.w_2 = nn.Conv1d( + d_hid, + d_in, + kernel_size=kernel_size[1], + padding=(kernel_size[1] - 1) // 2, + ) + + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout_inner = nn.Dropout(dropout_inner) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + residual = x + x = self.layer_norm(x) + + output = x.transpose(1, 2) + output = F.relu(self.w_1(output)) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(1), 0) + output = self.dropout_inner(output) + output = self.w_2(output) + output = output.transpose(1, 2) + output = self.dropout(output) + + output = output + residual + + return output + + +class FFTBlock(nn.Module): + """FFT Block""" + + def __init__(self, + d_in, + d_model, + n_head, + d_head, + d_inner, + kernel_size, + dropout, + dropout_attn=0.0, + dropout_relu=0.0): + super(FFTBlock, self).__init__() + self.slf_attn = MultiHeadSelfAttention( + n_head, + d_in, + d_model, + d_head, + dropout=dropout, + dropatt=dropout_attn) + self.pos_ffn = PositionwiseConvFeedForward( + d_model, + d_inner, + kernel_size, + dropout_inner=dropout_relu, + dropout=dropout) + + def forward(self, input, mask=None, slf_attn_mask=None): + output, slf_attn = self.slf_attn(input, mask=slf_attn_mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + output = self.pos_ffn(output, mask=mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output, slf_attn + + +class MultiHeadPNCAAttention(nn.Module): + """ Multi-Head Attention PNCA module """ + + def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0): + super().__init__() + + self.n_head = n_head + self.d_head = d_head + self.d_model = d_model + self.d_mem = d_mem + + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head) + self.fc_x = nn.Linear(n_head * d_head, d_model) + + self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head) + self.fc_h = nn.Linear(n_head * d_head, d_model) + + self.attention = ScaledDotProductAttention( + temperature=np.power(d_head, 0.5), dropatt=dropatt) + + self.dropout = nn.Dropout(dropout) + + def update_x_state(self, x): + d_head, n_head = self.d_head, self.n_head + + sz_b, len_x, _ = x.size() + + x_qkv = self.w_x_qkv(x) + x_q, x_k, x_v = x_qkv.chunk(3, -1) + + x_q = x_q.view(sz_b, len_x, n_head, d_head) + x_k = x_k.view(sz_b, len_x, n_head, d_head) + x_v = x_v.view(sz_b, len_x, n_head, d_head) + + x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) + + if (self.x_state_size): + self.x_k = torch.cat([self.x_k, x_k], dim=1) + self.x_v = torch.cat([self.x_v, x_v], dim=1) + else: + self.x_k = x_k + self.x_v = x_v + + self.x_state_size += len_x + + return x_q, x_k, x_v + + def update_h_state(self, h): + if (self.h_state_size == h.size(1)): + return None, None + + d_head, n_head = self.d_head, self.n_head + + # H + sz_b, len_h, _ = h.size() + + h_kv = self.w_h_kv(h) + h_k, h_v = h_kv.chunk(2, -1) + + h_k = h_k.view(sz_b, len_h, n_head, d_head) + h_v = h_v.view(sz_b, len_h, n_head, d_head) + + self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) + self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) + + self.h_state_size += len_h + + return h_k, h_v + + def reset_state(self): + self.h_k = None + self.h_v = None + self.h_state_size = 0 + self.x_k = None + self.x_v = None + self.x_state_size = 0 + + def forward(self, x, h, mask_x=None, mask_h=None): + residual = x + self.update_h_state(h) + x_q, x_k, x_v = self.update_x_state(self.layer_norm(x)) + + d_head, n_head = self.d_head, self.n_head + + sz_b, len_in, _ = x.size() + + # X + if mask_x is not None: + mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x .. + output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x) + + output_x = output_x.view(n_head, sz_b, len_in, d_head) + output_x = (output_x.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + output_x = self.fc_x(output_x) + + # H + if mask_h is not None: + mask_h = mask_h.repeat(n_head, 1, 1) + output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h) + + output_h = output_h.view(n_head, sz_b, len_in, d_head) + output_h = (output_h.permute(1, 2, 0, + 3).contiguous().view(sz_b, len_in, + -1)) # b x l x (n*d) + output_h = self.fc_h(output_h) + + output = output_x + output_h + + output = self.dropout(output) + + output = output + residual + + return output, attn_x, attn_h + + +class PNCABlock(nn.Module): + """PNCA Block""" + + def __init__(self, + d_model, + d_mem, + n_head, + d_head, + d_inner, + kernel_size, + dropout, + dropout_attn=0.0, + dropout_relu=0.0): + super(PNCABlock, self).__init__() + self.pnca_attn = MultiHeadPNCAAttention( + n_head, + d_model, + d_mem, + d_head, + dropout=dropout, + dropatt=dropout_attn) + self.pos_ffn = PositionwiseConvFeedForward( + d_model, + d_inner, + kernel_size, + dropout_inner=dropout_relu, + dropout=dropout) + + def forward(self, + input, + memory, + mask=None, + pnca_x_attn_mask=None, + pnca_h_attn_mask=None): + output, pnca_attn_x, pnca_attn_h = self.pnca_attn( + input, memory, pnca_x_attn_mask, pnca_h_attn_mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + output = self.pos_ffn(output, mask=mask) + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output, pnca_attn_x, pnca_attn_h + + def reset_state(self): + self.pnca_attn.reset_state() diff --git a/modelscope/models/audio/tts/models/models/sambert/fsmn.py b/modelscope/models/audio/tts/models/models/sambert/fsmn.py new file mode 100644 index 00000000..c070ef35 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/fsmn.py @@ -0,0 +1,126 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +FSMN Pytorch Version +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FeedForwardNet(nn.Module): + """ A two-feed-forward-layer module """ + + def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1): + super().__init__() + + # Use Conv1D + # position-wise + self.w_1 = nn.Conv1d( + d_in, + d_hid, + kernel_size=kernel_size[0], + padding=(kernel_size[0] - 1) // 2, + ) + # position-wise + self.w_2 = nn.Conv1d( + d_hid, + d_out, + kernel_size=kernel_size[1], + padding=(kernel_size[1] - 1) // 2, + bias=False) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + output = x.transpose(1, 2) + output = F.relu(self.w_1(output)) + output = self.dropout(output) + output = self.w_2(output) + output = output.transpose(1, 2) + + return output + + +class MemoryBlockV2(nn.Module): + + def __init__(self, d, filter_size, shift, dropout=0.0): + super(MemoryBlockV2, self).__init__() + + left_padding = int(round((filter_size - 1) / 2)) + right_padding = int((filter_size - 1) / 2) + if shift > 0: + left_padding += shift + right_padding -= shift + + self.lp, self.rp = left_padding, right_padding + + self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, input, mask=None): + if mask is not None: + input = input.masked_fill(mask.unsqueeze(-1), 0) + + x = F.pad( + input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0) + output = self.conv_dw(x.contiguous().transpose( + 1, 2)).contiguous().transpose(1, 2) + output += input + output = self.dropout(output) + + if mask is not None: + output = output.masked_fill(mask.unsqueeze(-1), 0) + + return output + + +class FsmnEncoderV2(nn.Module): + + def __init__(self, + filter_size, + fsmn_num_layers, + input_dim, + num_memory_units, + ffn_inner_dim, + dropout=0.0, + shift=0): + super(FsmnEncoderV2, self).__init__() + + self.filter_size = filter_size + self.fsmn_num_layers = fsmn_num_layers + self.num_memory_units = num_memory_units + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.shift = shift + if not isinstance(shift, list): + self.shift = [shift for _ in range(self.fsmn_num_layers)] + + self.ffn_lst = nn.ModuleList() + self.ffn_lst.append( + FeedForwardNet( + input_dim, ffn_inner_dim, num_memory_units, dropout=dropout)) + for i in range(1, fsmn_num_layers): + self.ffn_lst.append( + FeedForwardNet( + num_memory_units, + ffn_inner_dim, + num_memory_units, + dropout=dropout)) + + self.memory_block_lst = nn.ModuleList() + for i in range(fsmn_num_layers): + self.memory_block_lst.append( + MemoryBlockV2(num_memory_units, filter_size, self.shift[i], + dropout)) + + def forward(self, input, mask=None): + x = F.dropout(input, self.dropout, self.training) + for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst): + context = ffn(x) + memory = memory_block(context, mask) + memory = F.dropout(memory, self.dropout, self.training) + if (memory.size(-1) == x.size(-1)): + memory += x + x = memory + + return x diff --git a/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py b/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py new file mode 100644 index 00000000..3837a2e8 --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/kantts_sambert.py @@ -0,0 +1,718 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.audio.tts.models.utils import get_mask_from_lengths +from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor, + VarRnnARPredictor) +from .base import FFTBlock, PNCABlock, Prenet +from .fsmn import FsmnEncoderV2 +from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder + + +class SelfAttentionEncoder(nn.Module): + + def __init__(self, n_layer, d_in, d_model, n_head, d_head, d_inner, + dropout, dropout_att, dropout_relu, position_encoder): + super(SelfAttentionEncoder, self).__init__() + + self.d_in = d_in + self.d_model = d_model + self.dropout = dropout + d_in_lst = [d_in] + [d_model] * (n_layer - 1) + self.fft = nn.ModuleList([ + FFTBlock(d, d_model, n_head, d_head, d_inner, (3, 1), dropout, + dropout_att, dropout_relu) for d in d_in_lst + ]) + self.ln = nn.LayerNorm(d_model, eps=1e-6) + self.position_enc = position_encoder + + def forward(self, input, mask=None, return_attns=False): + input *= self.d_model**0.5 + if (isinstance(self.position_enc, SinusoidalPositionEncoder)): + input = self.position_enc(input) + else: + raise NotImplementedError('modelscope error: position_enc invalid') + + input = F.dropout(input, p=self.dropout, training=self.training) + + enc_slf_attn_list = [] + max_len = input.size(1) + if mask is not None: + slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) + else: + slf_attn_mask = None + + enc_output = input + for id, layer in enumerate(self.fft): + enc_output, enc_slf_attn = layer( + enc_output, mask=mask, slf_attn_mask=slf_attn_mask) + if return_attns: + enc_slf_attn_list += [enc_slf_attn] + + enc_output = self.ln(enc_output) + + return enc_output, enc_slf_attn_list + + +class HybridAttentionDecoder(nn.Module): + + def __init__(self, d_in, prenet_units, n_layer, d_model, d_mem, n_head, + d_head, d_inner, dropout, dropout_att, dropout_relu, d_out): + super(HybridAttentionDecoder, self).__init__() + + self.d_model = d_model + self.dropout = dropout + self.prenet = Prenet(d_in, prenet_units, d_model) + self.dec_in_proj = nn.Linear(d_model + d_mem, d_model) + self.pnca = nn.ModuleList([ + PNCABlock(d_model, d_mem, n_head, d_head, d_inner, (1, 1), dropout, + dropout_att, dropout_relu) for _ in range(n_layer) + ]) + self.ln = nn.LayerNorm(d_model, eps=1e-6) + self.dec_out_proj = nn.Linear(d_model, d_out) + + def reset_state(self): + for layer in self.pnca: + layer.reset_state() + + def get_pnca_attn_mask(self, + device, + max_len, + x_band_width, + h_band_width, + mask=None): + if mask is not None: + pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) + else: + pnca_attn_mask = None + + range_ = torch.arange(max_len).to(device) + x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :] + x_end = (range_ + 1)[None, None, :] + h_start = range_[None, None, :] + h_end = torch.clamp_max(range_ + h_band_width + 1, + max_len + 1)[None, None, :] + + pnca_x_attn_mask = ~((x_start <= range_[None, :, None]) + & (x_end > range_[None, :, None])).transpose(1, 2) # yapf:disable + pnca_h_attn_mask = ~((h_start <= range_[None, :, None]) + & (h_end > range_[None, :, None])).transpose(1, 2) # yapf:disable + + if pnca_attn_mask is not None: + pnca_x_attn_mask = (pnca_x_attn_mask | pnca_attn_mask) + pnca_h_attn_mask = (pnca_h_attn_mask | pnca_attn_mask) + pnca_x_attn_mask = pnca_x_attn_mask.masked_fill( + pnca_attn_mask.transpose(1, 2), False) + pnca_h_attn_mask = pnca_h_attn_mask.masked_fill( + pnca_attn_mask.transpose(1, 2), False) + + return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask + + # must call reset_state before + def forward(self, + input, + memory, + x_band_width, + h_band_width, + mask=None, + return_attns=False): + input = self.prenet(input) + input = torch.cat([memory, input], dim=-1) + input = self.dec_in_proj(input) + + if mask is not None: + input = input.masked_fill(mask.unsqueeze(-1), 0) + + input *= self.d_model**0.5 + input = F.dropout(input, p=self.dropout, training=self.training) + + max_len = input.size(1) + pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( + input.device, max_len, x_band_width, h_band_width, mask) + + dec_pnca_attn_x_list = [] + dec_pnca_attn_h_list = [] + dec_output = input + for id, layer in enumerate(self.pnca): + dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( + dec_output, + memory, + mask=mask, + pnca_x_attn_mask=pnca_x_attn_mask, + pnca_h_attn_mask=pnca_h_attn_mask) + if return_attns: + dec_pnca_attn_x_list += [dec_pnca_attn_x] + dec_pnca_attn_h_list += [dec_pnca_attn_h] + + dec_output = self.ln(dec_output) + dec_output = self.dec_out_proj(dec_output) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + # must call reset_state before when step == 0 + def infer(self, + step, + input, + memory, + x_band_width, + h_band_width, + mask=None, + return_attns=False): + max_len = memory.size(1) + + input = self.prenet(input) + input = torch.cat([memory[:, step:step + 1, :], input], dim=-1) + input = self.dec_in_proj(input) + + input *= self.d_model**0.5 + input = F.dropout(input, p=self.dropout, training=self.training) + + pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( + input.device, max_len, x_band_width, h_band_width, mask) + + dec_pnca_attn_x_list = [] + dec_pnca_attn_h_list = [] + dec_output = input + for id, layer in enumerate(self.pnca): + if mask is not None: + mask_step = mask[:, step:step + 1] + else: + mask_step = None + dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( + dec_output, + memory, + mask=mask_step, + pnca_x_attn_mask=pnca_x_attn_mask[:, + step:step + 1, :(step + 1)], + pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :]) + if return_attns: + dec_pnca_attn_x_list += [dec_pnca_attn_x] + dec_pnca_attn_h_list += [dec_pnca_attn_h] + + dec_output = self.ln(dec_output) + dec_output = self.dec_out_proj(dec_output) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + +class TextFftEncoder(nn.Module): + + def __init__(self, config, ling_unit_size): + super(TextFftEncoder, self).__init__() + + # linguistic unit lookup table + nb_ling_sy = ling_unit_size['sy'] + nb_ling_tone = ling_unit_size['tone'] + nb_ling_syllable_flag = ling_unit_size['syllable_flag'] + nb_ling_ws = ling_unit_size['word_segment'] + + max_len = config['am']['max_len'] + + d_emb = config['am']['embedding_dim'] + nb_layers = config['am']['encoder_num_layers'] + nb_heads = config['am']['encoder_num_heads'] + d_model = config['am']['encoder_num_units'] + d_head = d_model // nb_heads + d_inner = config['am']['encoder_ffn_inner_dim'] + dropout = config['am']['encoder_dropout'] + dropout_attn = config['am']['encoder_attention_dropout'] + dropout_relu = config['am']['encoder_relu_dropout'] + d_proj = config['am']['encoder_projection_units'] + + self.d_model = d_model + + self.sy_emb = nn.Embedding(nb_ling_sy, d_emb) + self.tone_emb = nn.Embedding(nb_ling_tone, d_emb) + self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb) + self.ws_emb = nn.Embedding(nb_ling_ws, d_emb) + + position_enc = SinusoidalPositionEncoder(max_len, d_emb) + + self.ling_enc = SelfAttentionEncoder(nb_layers, d_emb, d_model, + nb_heads, d_head, d_inner, + dropout, dropout_attn, + dropout_relu, position_enc) + + self.ling_proj = nn.Linear(d_model, d_proj, bias=False) + + def forward(self, inputs_ling, masks=None, return_attns=False): + # Parse inputs_ling_seq + inputs_sy = inputs_ling[:, :, 0] + inputs_tone = inputs_ling[:, :, 1] + inputs_syllable_flag = inputs_ling[:, :, 2] + inputs_ws = inputs_ling[:, :, 3] + + # Lookup table + sy_embedding = self.sy_emb(inputs_sy) + tone_embedding = self.tone_emb(inputs_tone) + syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag) + ws_embedding = self.ws_emb(inputs_ws) + + ling_embedding = sy_embedding + tone_embedding + syllable_flag_embedding + ws_embedding + + enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks, + return_attns) + + enc_output = self.ling_proj(enc_output) + + return enc_output, enc_slf_attn_list + + +class VarianceAdaptor(nn.Module): + + def __init__(self, config): + super(VarianceAdaptor, self).__init__() + + input_dim = config['am']['encoder_projection_units'] + config['am'][ + 'emotion_units'] + config['am']['speaker_units'] + filter_size = config['am']['predictor_filter_size'] + fsmn_num_layers = config['am']['predictor_fsmn_num_layers'] + num_memory_units = config['am']['predictor_num_memory_units'] + ffn_inner_dim = config['am']['predictor_ffn_inner_dim'] + dropout = config['am']['predictor_dropout'] + shift = config['am']['predictor_shift'] + lstm_units = config['am']['predictor_lstm_units'] + + dur_pred_prenet_units = config['am']['dur_pred_prenet_units'] + dur_pred_lstm_units = config['am']['dur_pred_lstm_units'] + + self.pitch_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size, + fsmn_num_layers, + num_memory_units, + ffn_inner_dim, dropout, + shift, lstm_units) + self.energy_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size, + fsmn_num_layers, + num_memory_units, + ffn_inner_dim, dropout, + shift, lstm_units) + self.duration_predictor = VarRnnARPredictor(input_dim, + dur_pred_prenet_units, + dur_pred_lstm_units) + + self.length_regulator = LengthRegulator( + config['am']['outputs_per_step']) + self.dur_position_encoder = DurSinusoidalPositionEncoder( + config['am']['encoder_projection_units'], + config['am']['outputs_per_step']) + + self.pitch_emb = nn.Conv1d( + 1, + config['am']['encoder_projection_units'], + kernel_size=9, + padding=4) + self.energy_emb = nn.Conv1d( + 1, + config['am']['encoder_projection_units'], + kernel_size=9, + padding=4) + + def forward(self, + inputs_text_embedding, + inputs_emo_embedding, + inputs_spk_embedding, + masks=None, + output_masks=None, + duration_targets=None, + pitch_targets=None, + energy_targets=None): + + batch_size = inputs_text_embedding.size(0) + + variance_predictor_inputs = torch.cat([ + inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding + ], dim=-1) # yapf:disable + + pitch_predictions = self.pitch_predictor(variance_predictor_inputs, + masks) + energy_predictions = self.energy_predictor(variance_predictor_inputs, + masks) + + if pitch_targets is not None: + pitch_embeddings = self.pitch_emb( + pitch_targets.unsqueeze(1)).transpose(1, 2) + else: + pitch_embeddings = self.pitch_emb( + pitch_predictions.unsqueeze(1)).transpose(1, 2) + + if energy_targets is not None: + energy_embeddings = self.energy_emb( + energy_targets.unsqueeze(1)).transpose(1, 2) + else: + energy_embeddings = self.energy_emb( + energy_predictions.unsqueeze(1)).transpose(1, 2) + + inputs_text_embedding_aug = inputs_text_embedding + pitch_embeddings + energy_embeddings + duration_predictor_cond = torch.cat([ + inputs_text_embedding_aug, inputs_spk_embedding, + inputs_emo_embedding + ], dim=-1) # yapf:disable + if duration_targets is not None: + duration_predictor_go_frame = torch.zeros(batch_size, 1).to( + inputs_text_embedding.device) + duration_predictor_input = torch.cat([ + duration_predictor_go_frame, duration_targets[:, :-1].float() + ], dim=-1) # yapf:disable + duration_predictor_input = torch.log(duration_predictor_input + 1) + log_duration_predictions, _ = self.duration_predictor( + duration_predictor_input.unsqueeze(-1), + duration_predictor_cond, + masks=masks) + duration_predictions = torch.exp(log_duration_predictions) - 1 + else: + log_duration_predictions = self.duration_predictor.infer( + duration_predictor_cond, masks=masks) + duration_predictions = torch.exp(log_duration_predictions) - 1 + + if duration_targets is not None: + LR_text_outputs, LR_length_rounded = self.length_regulator( + inputs_text_embedding_aug, + duration_targets, + masks=output_masks) + LR_position_embeddings = self.dur_position_encoder( + duration_targets, masks=output_masks) + LR_emo_outputs, _ = self.length_regulator( + inputs_emo_embedding, duration_targets, masks=output_masks) + LR_spk_outputs, _ = self.length_regulator( + inputs_spk_embedding, duration_targets, masks=output_masks) + + else: + LR_text_outputs, LR_length_rounded = self.length_regulator( + inputs_text_embedding_aug, + duration_predictions, + masks=output_masks) + LR_position_embeddings = self.dur_position_encoder( + duration_predictions, masks=output_masks) + LR_emo_outputs, _ = self.length_regulator( + inputs_emo_embedding, duration_predictions, masks=output_masks) + LR_spk_outputs, _ = self.length_regulator( + inputs_spk_embedding, duration_predictions, masks=output_masks) + + LR_text_outputs = LR_text_outputs + LR_position_embeddings + + return (LR_text_outputs, LR_emo_outputs, LR_spk_outputs, + LR_length_rounded, log_duration_predictions, pitch_predictions, + energy_predictions) + + +class MelPNCADecoder(nn.Module): + + def __init__(self, config): + super(MelPNCADecoder, self).__init__() + + prenet_units = config['am']['decoder_prenet_units'] + nb_layers = config['am']['decoder_num_layers'] + nb_heads = config['am']['decoder_num_heads'] + d_model = config['am']['decoder_num_units'] + d_head = d_model // nb_heads + d_inner = config['am']['decoder_ffn_inner_dim'] + dropout = config['am']['decoder_dropout'] + dropout_attn = config['am']['decoder_attention_dropout'] + dropout_relu = config['am']['decoder_relu_dropout'] + outputs_per_step = config['am']['outputs_per_step'] + + d_mem = config['am'][ + 'encoder_projection_units'] * outputs_per_step + config['am'][ + 'emotion_units'] + config['am']['speaker_units'] + d_mel = config['am']['num_mels'] + + self.d_mel = d_mel + self.r = outputs_per_step + self.nb_layers = nb_layers + + self.mel_dec = HybridAttentionDecoder(d_mel, prenet_units, nb_layers, + d_model, d_mem, nb_heads, d_head, + d_inner, dropout, dropout_attn, + dropout_relu, + d_mel * outputs_per_step) + + def forward(self, + memory, + x_band_width, + h_band_width, + target=None, + mask=None, + return_attns=False): + batch_size = memory.size(0) + go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device) + + if target is not None: + self.mel_dec.reset_state() + input = target[:, self.r - 1::self.r, :] + input = torch.cat([go_frame, input], dim=1)[:, :-1, :] + dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec( + input, + memory, + x_band_width, + h_band_width, + mask=mask, + return_attns=return_attns) + + else: + dec_output = [] + dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)] + dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)] + self.mel_dec.reset_state() + input = go_frame + for step in range(memory.size(1)): + dec_output_step, dec_pnca_attn_x_step, dec_pnca_attn_h_step = self.mel_dec.infer( + step, + input, + memory, + x_band_width, + h_band_width, + mask=mask, + return_attns=return_attns) + input = dec_output_step[:, :, -self.d_mel:] + + dec_output.append(dec_output_step) + for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate( + zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)): + left = memory.size(1) - pnca_x_attn.size(-1) + if (left > 0): + padding = torch.zeros( + (pnca_x_attn.size(0), 1, left)).to(pnca_x_attn) + pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1) + dec_pnca_attn_x_list[layer_id].append(pnca_x_attn) + dec_pnca_attn_h_list[layer_id].append(pnca_h_attn) + + dec_output = torch.cat(dec_output, dim=1) + for layer_id in range(self.nb_layers): + dec_pnca_attn_x_list[layer_id] = torch.cat( + dec_pnca_attn_x_list[layer_id], dim=1) + dec_pnca_attn_h_list[layer_id] = torch.cat( + dec_pnca_attn_h_list[layer_id], dim=1) + + return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list + + +class PostNet(nn.Module): + + def __init__(self, config): + super(PostNet, self).__init__() + + self.filter_size = config['am']['postnet_filter_size'] + self.fsmn_num_layers = config['am']['postnet_fsmn_num_layers'] + self.num_memory_units = config['am']['postnet_num_memory_units'] + self.ffn_inner_dim = config['am']['postnet_ffn_inner_dim'] + self.dropout = config['am']['postnet_dropout'] + self.shift = config['am']['postnet_shift'] + self.lstm_units = config['am']['postnet_lstm_units'] + self.num_mels = config['am']['num_mels'] + + self.fsmn = FsmnEncoderV2(self.filter_size, self.fsmn_num_layers, + self.num_mels, self.num_memory_units, + self.ffn_inner_dim, self.dropout, self.shift) + self.lstm = nn.LSTM( + self.num_memory_units, + self.lstm_units, + num_layers=1, + batch_first=True) + self.fc = nn.Linear(self.lstm_units, self.num_mels) + + def forward(self, x, mask=None): + postnet_fsmn_output = self.fsmn(x, mask) + # The input can also be a packed variable length sequence, + # here we just omit it for simpliciy due to the mask and uni-directional lstm. + postnet_lstm_output, _ = self.lstm(postnet_fsmn_output) + mel_residual_output = self.fc(postnet_lstm_output) + + return mel_residual_output + + +def mel_recon_loss_fn(output_lengths, + mel_targets, + dec_outputs, + postnet_outputs=None): + mae_loss = nn.L1Loss(reduction='none') + + output_masks = get_mask_from_lengths( + output_lengths, max_len=mel_targets.size(1)) + output_masks = ~output_masks + valid_outputs = output_masks.sum() + + mel_loss_ = torch.sum( + mae_loss(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)) / ( + valid_outputs * mel_targets.size(-1)) + + if postnet_outputs is not None: + mel_loss = torch.sum( + mae_loss(mel_targets, postnet_outputs) + * output_masks.unsqueeze(-1)) / ( + valid_outputs * mel_targets.size(-1)) + else: + mel_loss = 0.0 + + return mel_loss_, mel_loss + + +def prosody_recon_loss_fn(input_lengths, duration_targets, pitch_targets, + energy_targets, log_duration_predictions, + pitch_predictions, energy_predictions): + mae_loss = nn.L1Loss(reduction='none') + + input_masks = get_mask_from_lengths( + input_lengths, max_len=duration_targets.size(1)) + input_masks = ~input_masks + valid_inputs = input_masks.sum() + + dur_loss = torch.sum( + mae_loss( + torch.log(duration_targets.float() + 1), log_duration_predictions) + * input_masks) / valid_inputs + pitch_loss = torch.sum( + mae_loss(pitch_targets, pitch_predictions) + * input_masks) / valid_inputs + energy_loss = torch.sum( + mae_loss(energy_targets, energy_predictions) + * input_masks) / valid_inputs + + return dur_loss, pitch_loss, energy_loss + + +class KanTtsSAMBERT(nn.Module): + + def __init__(self, config, ling_unit_size): + super(KanTtsSAMBERT, self).__init__() + + self.text_encoder = TextFftEncoder(config, ling_unit_size) + self.spk_tokenizer = nn.Embedding(ling_unit_size['speaker'], + config['am']['speaker_units']) + self.emo_tokenizer = nn.Embedding(ling_unit_size['emotion'], + config['am']['emotion_units']) + self.variance_adaptor = VarianceAdaptor(config) + self.mel_decoder = MelPNCADecoder(config) + self.mel_postnet = PostNet(config) + + def get_lfr_mask_from_lengths(self, lengths, max_len): + batch_size = lengths.size(0) + # padding according to the outputs_per_step + padded_lr_lengths = torch.zeros_like(lengths) + for i in range(batch_size): + len_item = int(lengths[i].item()) + padding = self.mel_decoder.r - len_item % self.mel_decoder.r + if (padding < self.mel_decoder.r): + padded_lr_lengths[i] = (len_item + + padding) // self.mel_decoder.r + else: + padded_lr_lengths[i] = len_item // self.mel_decoder.r + + return get_mask_from_lengths( + padded_lr_lengths, max_len=max_len // self.mel_decoder.r) + + def forward(self, + inputs_ling, + inputs_emotion, + inputs_speaker, + input_lengths, + output_lengths=None, + mel_targets=None, + duration_targets=None, + pitch_targets=None, + energy_targets=None): + + batch_size = inputs_ling.size(0) + + input_masks = get_mask_from_lengths( + input_lengths, max_len=inputs_ling.size(1)) + + text_hid, enc_sla_attn_lst = self.text_encoder( + inputs_ling, input_masks, return_attns=True) + + emo_hid = self.emo_tokenizer(inputs_emotion) + spk_hid = self.spk_tokenizer(inputs_speaker) + + if output_lengths is not None: + output_masks = get_mask_from_lengths( + output_lengths, max_len=mel_targets.size(1)) + else: + output_masks = None + + (LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded, + log_duration_predictions, pitch_predictions, + energy_predictions) = self.variance_adaptor( + text_hid, + emo_hid, + spk_hid, + masks=input_masks, + output_masks=output_masks, + duration_targets=duration_targets, + pitch_targets=pitch_targets, + energy_targets=energy_targets) + + if output_lengths is not None: + lfr_masks = self.get_lfr_mask_from_lengths( + output_lengths, max_len=LR_text_outputs.size(1)) + else: + output_masks = get_mask_from_lengths( + LR_length_rounded, max_len=LR_text_outputs.size(1)) + lfr_masks = None + + # LFR with the factor of outputs_per_step + LFR_text_inputs = LR_text_outputs.contiguous().view( + batch_size, -1, self.mel_decoder.r * text_hid.shape[-1]) + LFR_emo_inputs = LR_emo_outputs.contiguous().view( + batch_size, -1, + self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]] + LFR_spk_inputs = LR_spk_outputs.contiguous().view( + batch_size, -1, + self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]] + + memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], + dim=-1) + + if duration_targets is not None: + x_band_width = int( + duration_targets.float().masked_fill(input_masks, 0).max() + / self.mel_decoder.r + 0.5) + h_band_width = x_band_width + else: + x_band_width = int((torch.exp(log_duration_predictions) - 1).max() + / self.mel_decoder.r + 0.5) + h_band_width = x_band_width + + dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder( + memory, + x_band_width, + h_band_width, + target=mel_targets, + mask=lfr_masks, + return_attns=True) + + # De-LFR with the factor of outputs_per_step + dec_outputs = dec_outputs.contiguous().view(batch_size, -1, + self.mel_decoder.d_mel) + + if output_masks is not None: + dec_outputs = dec_outputs.masked_fill( + output_masks.unsqueeze(-1), 0) + + postnet_outputs = self.mel_postnet(dec_outputs, + output_masks) + dec_outputs + if output_masks is not None: + postnet_outputs = postnet_outputs.masked_fill( + output_masks.unsqueeze(-1), 0) + + res = { + 'x_band_width': x_band_width, + 'h_band_width': h_band_width, + 'enc_slf_attn_lst': enc_sla_attn_lst, + 'pnca_x_attn_lst': pnca_x_attn_lst, + 'pnca_h_attn_lst': pnca_h_attn_lst, + 'dec_outputs': dec_outputs, + 'postnet_outputs': postnet_outputs, + 'LR_length_rounded': LR_length_rounded, + 'log_duration_predictions': log_duration_predictions, + 'pitch_predictions': pitch_predictions, + 'energy_predictions': energy_predictions + } + + res['LR_text_outputs'] = LR_text_outputs + res['LR_emo_outputs'] = LR_emo_outputs + res['LR_spk_outputs'] = LR_spk_outputs + + return res diff --git a/modelscope/models/audio/tts/models/models/sambert/positions.py b/modelscope/models/audio/tts/models/models/sambert/positions.py new file mode 100644 index 00000000..9d1e375d --- /dev/null +++ b/modelscope/models/audio/tts/models/models/sambert/positions.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SinusoidalPositionEncoder(nn.Module): + + def __init__(self, max_len, depth): + super(SinusoidalPositionEncoder, self).__init__() + + self.max_len = max_len + self.depth = depth + self.position_enc = nn.Parameter( + self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0), + requires_grad=False) + + def forward(self, input): + bz_in, len_in, _ = input.size() + if len_in > self.max_len: + self.max_len = len_in + self.position_enc.data = self.get_sinusoid_encoding_table( + self.max_len, self.depth).unsqueeze(0).to(input.device) + + output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1) + + return output + + @staticmethod + def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + """ Sinusoid position encoding table """ + + def cal_angle(position, hid_idx): + return position / np.power(10000, hid_idx / float(d_hid / 2 - 1)) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)] + + scaled_time_table = np.array( + [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)]) + + sinusoid_table = np.zeros((n_position, d_hid)) + sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table) + sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table) + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0.0 + + return torch.FloatTensor(sinusoid_table) + + +class DurSinusoidalPositionEncoder(nn.Module): + + def __init__(self, depth, outputs_per_step): + super(DurSinusoidalPositionEncoder, self).__init__() + + self.depth = depth + self.outputs_per_step = outputs_per_step + + inv_timescales = [ + np.power(10000, 2 * (hid_idx // 2) / depth) + for hid_idx in range(depth) + ] + self.inv_timescales = nn.Parameter( + torch.FloatTensor(inv_timescales), requires_grad=False) + + def forward(self, durations, masks=None): + reps = (durations + 0.5).long() + output_lens = reps.sum(dim=1) + max_len = output_lens.max() + reps_cumsum = torch.cumsum( + F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] + range_ = torch.arange(max_len).to(durations.device)[None, :, None] + mult = ((reps_cumsum[:, :, :-1] <= range_) + & (reps_cumsum[:, :, 1:] > range_)) # yapf:disable + mult = mult.float() + offsets = torch.matmul(mult, + reps_cumsum[:, + 0, :-1].unsqueeze(-1)).squeeze(-1) + dur_pos = range_[:, :, 0] - offsets + 1 + + if masks is not None: + assert masks.size(1) == dur_pos.size(1) + dur_pos = dur_pos.masked_fill(masks, 0.0) + + seq_len = dur_pos.size(1) + padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step + if (padding < self.outputs_per_step): + dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0) + + position_embedding = dur_pos[:, :, None] / self.inv_timescales[None, + None, :] + position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :, + 0::2]) + position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :, + 1::2]) + + return position_embedding diff --git a/modelscope/models/audio/tts/models/utils/__init__.py b/modelscope/models/audio/tts/models/utils/__init__.py new file mode 100644 index 00000000..e07f08ea --- /dev/null +++ b/modelscope/models/audio/tts/models/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .utils import * # noqa F403 diff --git a/modelscope/models/audio/tts/models/utils/utils.py b/modelscope/models/audio/tts/models/utils/utils.py new file mode 100755 index 00000000..17ac8aee --- /dev/null +++ b/modelscope/models/audio/tts/models/utils/utils.py @@ -0,0 +1,136 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import glob +import os +import shutil + +import matplotlib +import matplotlib.pylab as plt +import torch + +matplotlib.use('Agg') + + +class AttrDict(dict): + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, aspect='auto', origin='lower', interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_alignment(alignment, info=None): + fig, ax = plt.subplots() + im = ax.imshow( + alignment, aspect='auto', origin='lower', interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Input timestep' + if info is not None: + xlabel += '\t' + info + plt.xlabel(xlabel) + plt.ylabel('Output timestep') + fig.canvas.draw() + plt.close() + + return fig + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + checkpoint_dict = torch.load(filepath, map_location=device) + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + torch.save(obj, filepath) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????.pkl') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ValueWindow(): + + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1):] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] + + +def get_model_size(model): + param_num = sum([p.numel() for p in model.parameters() if p.requires_grad]) + param_size = param_num * 4 / 1024 / 1024 + return param_size + + +def get_grad_norm(model): + total_norm = 0 + params = [ + p for p in model.parameters() if p.grad is not None and p.requires_grad + ] + for p in params: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item()**2 + total_norm = total_norm**0.5 + return total_norm + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(mean, std) + + +def get_mask_from_lengths(lengths, max_len=None): + batch_size = lengths.shape[0] + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, + -1).to(lengths.device) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + + return mask diff --git a/modelscope/models/audio/tts/sambert_hifi.py b/modelscope/models/audio/tts/sambert_hifi.py new file mode 100644 index 00000000..a9b55795 --- /dev/null +++ b/modelscope/models/audio/tts/sambert_hifi.py @@ -0,0 +1,97 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import os +import zipfile + +import json +import numpy as np + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.audio.tts_exceptions import ( + TtsFrontendInitializeFailedException, + TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException, + TtsVoiceNotExistsException) +from modelscope.utils.constant import Tasks +from .voice import Voice + +__all__ = ['SambertHifigan'] + + +@MODELS.register_module( + Tasks.text_to_speech, module_name=Models.sambert_hifigan) +class SambertHifigan(Model): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + if 'am' not in kwargs: + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing am!') + if 'vocoder' not in kwargs: + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing vocoder!') + if 'lang_type' not in kwargs: + raise TtsModelConfigurationException( + 'modelscope error: configuration model field missing lang_type!' + ) + am_cfg = kwargs['am'] + voc_cfg = kwargs['vocoder'] + # initialize frontend + import ttsfrd + frontend = ttsfrd.TtsFrontendEngine() + zip_file = os.path.join(model_dir, 'resource.zip') + self.__res_path = os.path.join(model_dir, 'resource') + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(model_dir) + if not frontend.initialize(self.__res_path): + raise TtsFrontendInitializeFailedException( + 'modelscope error: resource invalid: {}'.format( + self.__res_path)) + if not frontend.set_lang_type(kwargs['lang_type']): + raise TtsFrontendLanguageTypeInvalidException( + 'modelscope error: language type invalid: {}'.format( + kwargs['lang_type'])) + self.__frontend = frontend + zip_file = os.path.join(model_dir, 'voices.zip') + self.__voice_path = os.path.join(model_dir, 'voices') + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(model_dir) + voice_cfg_path = os.path.join(self.__voice_path, 'voices.json') + with open(voice_cfg_path, 'r') as f: + voice_cfg = json.load(f) + if 'voices' not in voice_cfg: + raise TtsModelConfigurationException( + 'modelscope error: voices invalid') + self.__voice = {} + for name in voice_cfg['voices']: + voice_path = os.path.join(self.__voice_path, name) + if not os.path.exists(voice_path): + continue + self.__voice[name] = Voice(name, voice_path, am_cfg, voc_cfg) + if voice_cfg['voices']: + self.__default_voice_name = voice_cfg['voices'][0] + else: + raise TtsVoiceNotExistsException( + 'modelscope error: voices is empty in voices.json') + + def __synthesis_one_sentences(self, voice_name, text): + if voice_name not in self.__voice: + raise TtsVoiceNotExistsException( + f'modelscope error: Voice {voice_name} not exists') + return self.__voice[voice_name].forward(text) + + def forward(self, text: str, voice_name: str = None): + voice = self.__default_voice_name + if voice_name is not None: + voice = voice_name + result = self.__frontend.gen_tacotron_symbols(text) + texts = [s for s in result.splitlines() if s != ''] + audio_total = np.empty((0), dtype='int16') + for line in texts: + line = line.strip().split('\t') + audio = self.__synthesis_one_sentences(voice, line[1]) + audio_total = np.append(audio_total, audio, axis=0) + return audio_total diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py new file mode 100644 index 00000000..b7240088 --- /dev/null +++ b/modelscope/models/audio/tts/voice.py @@ -0,0 +1,135 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import pickle as pkl +from threading import Lock + +import json +import numpy as np +import torch + +from modelscope.utils.audio.tts_exceptions import \ + TtsModelConfigurationException +from modelscope.utils.constant import ModelFile, Tasks +from .models.datasets.units import KanTtsLinguisticUnit +from .models.models.hifigan import Generator +from .models.models.sambert import KanTtsSAMBERT +from .models.utils import (AttrDict, build_env, init_weights, load_checkpoint, + plot_spectrogram, save_checkpoint, scan_checkpoint) + +MAX_WAV_VALUE = 32768.0 + + +class Voice: + + def __init__(self, voice_name, voice_path, am_config, voc_config): + self.__voice_name = voice_name + self.__voice_path = voice_path + self.__am_config = AttrDict(**am_config) + self.__voc_config = AttrDict(**voc_config) + self.__model_loaded = False + self.__lock = Lock() + if 'am' not in self.__am_config: + raise TtsModelConfigurationException( + 'modelscope error: am configuration invalid') + if 'linguistic_unit' not in self.__am_config: + raise TtsModelConfigurationException( + 'modelscope error: am configuration invalid') + self.__am_lingustic_unit_config = self.__am_config['linguistic_unit'] + + def __load_am(self): + local_am_ckpt_path = os.path.join(self.__voice_path, 'am') + self.__am_ckpt_path = os.path.join(local_am_ckpt_path, + ModelFile.TORCH_MODEL_BIN_FILE) + has_mask = True + if 'has_mask' in self.__am_lingustic_unit_config: + has_mask = self.__am_lingustic_unit_config.has_mask + self.__ling_unit = KanTtsLinguisticUnit( + self.__am_lingustic_unit_config, self.__voice_path, has_mask) + self.__am_net = KanTtsSAMBERT(self.__am_config, + self.__ling_unit.get_unit_size()).to( + self.__device) + state_dict_g = {} + try: + state_dict_g = load_checkpoint(self.__am_ckpt_path, self.__device) + except RuntimeError: + with open(self.__am_ckpt_path, 'rb') as f: + pth_var_dict = pkl.load(f) + state_dict_g['fsnet'] = { + k: torch.FloatTensor(v) + for k, v in pth_var_dict['fsnet'].items() + } + self.__am_net.load_state_dict(state_dict_g['fsnet'], strict=False) + self.__am_net.eval() + + def __load_vocoder(self): + local_voc_ckpy_path = os.path.join(self.__voice_path, 'vocoder') + self.__voc_ckpt_path = os.path.join(local_voc_ckpy_path, + ModelFile.TORCH_MODEL_BIN_FILE) + self.__generator = Generator(self.__voc_config).to(self.__device) + state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device) + self.__generator.load_state_dict(state_dict_g['generator']) + self.__generator.eval() + self.__generator.remove_weight_norm() + + def __am_forward(self, symbol_seq): + with self.__lock: + with torch.no_grad(): + inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( + symbol_seq) + inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to( + self.__device) + inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to( + self.__device) + inputs_syllable = torch.from_numpy( + inputs_feat_lst[2]).long().to(self.__device) + inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to( + self.__device) + inputs_ling = torch.stack( + [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], + dim=-1).unsqueeze(0) + inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to( + self.__device).unsqueeze(0) + inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to( + self.__device).unsqueeze(0) + inputs_len = torch.zeros(1).to(self.__device).long( + ) + inputs_emo.size(1) - 1 # minus 1 for "~" + res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1], + inputs_spk[:, :-1], inputs_len) + postnet_outputs = res['postnet_outputs'] + LR_length_rounded = res['LR_length_rounded'] + valid_length = int(LR_length_rounded[0].item()) + postnet_outputs = postnet_outputs[ + 0, :valid_length, :].cpu().numpy() + return postnet_outputs + + def __vocoder_forward(self, melspec): + dim0 = list(melspec.shape)[-1] + if dim0 != self.__voc_config.num_mels: + raise TtsVocoderMelspecShapeMismatchException( + 'modelscope error: input melspec mismatch require {} but {}'. + format(self.__voc_config.num_mels, dim0)) + with torch.no_grad(): + x = melspec.T + x = torch.FloatTensor(x).to(self.__device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + y_g_hat = self.__generator(x) + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + return audio + + def forward(self, symbol_seq): + with self.__lock: + if not self.__model_loaded: + torch.manual_seed(self.__am_config.seed) + if torch.cuda.is_available(): + torch.manual_seed(self.__am_config.seed) + self.__device = torch.device('cuda') + else: + self.__device = torch.device('cpu') + self.__load_am() + self.__load_vocoder() + self.__model_loaded = True + return self.__vocoder_forward(self.__am_forward(symbol_seq)) diff --git a/modelscope/models/base/__init__.py b/modelscope/models/base/__init__.py new file mode 100644 index 00000000..8c47ecaf --- /dev/null +++ b/modelscope/models/base/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .base_head import * # noqa F403 +from .base_model import * # noqa F403 +from .base_torch_head import * # noqa F403 +from .base_torch_model import * # noqa F403 diff --git a/modelscope/models/base/base_head.py b/modelscope/models/base/base_head.py new file mode 100644 index 00000000..11bda32f --- /dev/null +++ b/modelscope/models/base/base_head.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import Any, Dict, Union + +from modelscope.models.base.base_model import Model +from modelscope.utils.config import ConfigDict +from modelscope.utils.logger import get_logger + +logger = get_logger() + +Tensor = Union['torch.Tensor', 'tf.Tensor'] +Input = Union[Dict[str, Tensor], Model] + + +class Head(ABC): + """ + The head base class is for the tasks head method definition + + """ + + def __init__(self, **kwargs): + self.config = ConfigDict(kwargs) + + @abstractmethod + def forward(self, *args, **kwargs) -> Dict[str, Any]: + """ + This method will use the output from backbone model to do any + downstream tasks. Recieve The output from backbone model. + + Returns (Dict[str, Any]): The output from downstream task. + """ + pass + + @abstractmethod + def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: + """ + compute loss for head during the finetuning. + + Returns (Dict[str, Any]): The loss dict + """ + pass diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py new file mode 100644 index 00000000..721478c3 --- /dev/null +++ b/modelscope/models/base/base_model.py @@ -0,0 +1,174 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.builder import MODELS, build_model +from modelscope.utils.checkpoint import save_checkpoint, save_pretrained +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile, Tasks +from modelscope.utils.device import verify_device +from modelscope.utils.logger import get_logger + +logger = get_logger() + +Tensor = Union['torch.Tensor', 'tf.Tensor'] + + +class Model(ABC): + + def __init__(self, model_dir, *args, **kwargs): + self.model_dir = model_dir + device_name = kwargs.get('device', 'gpu') + verify_device(device_name) + self._device_name = device_name + + def __call__(self, *args, **kwargs) -> Dict[str, Any]: + return self.postprocess(self.forward(*args, **kwargs)) + + @abstractmethod + def forward(self, *args, **kwargs) -> Dict[str, Any]: + """ + Run the forward pass for a model. + + Returns: + Dict[str, Any]: output from the model forward pass + """ + pass + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """ Model specific postprocess and convert model output to + standard model outputs. + + Args: + inputs: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. + """ + return inputs + + @classmethod + def _instantiate(cls, **kwargs): + """ Define the instantiation method of a model,default method is by + calling the constructor. Note that in the case of no loading model + process in constructor of a task model, a load_model method is + added, and thus this method is overloaded + """ + return cls(**kwargs) + + @classmethod + def from_pretrained(cls, + model_name_or_path: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cfg_dict: Config = None, + device: str = None, + **kwargs): + """Instantiate a model from local directory or remote model repo. Note + that when loading from remote, the model revision can be specified. + + Args: + model_name_or_path(str): A model dir or a model id to be loaded + revision(str, `optional`): The revision used when the model_name_or_path is + a model id of the remote hub. default `master`. + cfg_dict(Config, `optional`): An optional model config. If provided, it will replace + the config read out of the `model_name_or_path` + device(str, `optional`): The device to load the model. + **kwargs: + task(str, `optional`): The `Tasks` enumeration value to replace the task value + read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not + equal to the model saved. + For example, load a `backbone` into a `text-classification` model. + Other kwargs will be directly fed into the `model` key, to replace the default configs. + Returns: + A model instance. + + Examples: + >>> from modelscope.models import Model + >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification') + """ + prefetched = kwargs.get('model_prefetched') + if prefetched is not None: + kwargs.pop('model_prefetched') + + if osp.exists(model_name_or_path): + local_model_dir = model_name_or_path + else: + if prefetched is True: + raise RuntimeError( + 'Expecting model is pre-fetched locally, but is not found.' + ) + local_model_dir = snapshot_download(model_name_or_path, revision) + logger.info(f'initialize model from {local_model_dir}') + if cfg_dict is not None: + cfg = cfg_dict + else: + cfg = Config.from_file( + osp.join(local_model_dir, ModelFile.CONFIGURATION)) + task_name = cfg.task + if 'task' in kwargs: + task_name = kwargs.pop('task') + model_cfg = cfg.model + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + model_cfg.model_dir = local_model_dir + for k, v in kwargs.items(): + model_cfg[k] = v + if device is not None: + model_cfg.device = device + model = build_model( + model_cfg, task_name=task_name, default_args=kwargs) + else: + model = build_model( + model_cfg, task_name=task_name, default_args=kwargs) + + # dynamically add pipeline info to model for pipeline inference + if hasattr(cfg, 'pipeline'): + model.pipeline = cfg.pipeline + + if not hasattr(model, 'cfg'): + model.cfg = cfg + + model.name = model_name_or_path + return model + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = save_checkpoint, + config: Optional[dict] = None, + **kwargs): + """save the pretrained model, its configuration and other related files to a directory, + so that it can be re-loaded + + Args: + target_folder (Union[str, os.PathLike]): + Directory to which to save. Will be created if it doesn't exist. + + save_checkpoint_names (Union[str, List[str]]): + The checkpoint names to be saved in the target_folder + + save_function (Callable, optional): + The function to use to save the state dictionary. + + config (Optional[dict], optional): + The config for the configuration.json, might not be identical with model.config + + """ + if config is None and hasattr(self, 'cfg'): + config = self.cfg + assert config is not None, 'Cannot save the model because the model config is empty.' + if isinstance(config, Config): + config = config.to_dict() + if 'preprocessor' in config and config['preprocessor'] is not None: + if 'mode' in config['preprocessor']: + config['preprocessor']['mode'] = 'inference' + elif 'val' in config['preprocessor'] and 'mode' in config[ + 'preprocessor']['val']: + config['preprocessor']['val']['mode'] = 'inference' + + save_pretrained(self, target_folder, save_checkpoint_names, + save_function, config, **kwargs) diff --git a/modelscope/models/base/base_torch_head.py b/modelscope/models/base/base_torch_head.py new file mode 100644 index 00000000..faee4296 --- /dev/null +++ b/modelscope/models/base/base_torch_head.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch + +from modelscope.models.base.base_head import Head +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + + +class TorchHead(Head, torch.nn.Module): + """ Base head interface for pytorch + + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + torch.nn.Module.__init__(self) + + def forward(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError + + def compute_loss(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError diff --git a/modelscope/models/base/base_torch_model.py b/modelscope/models/base/base_torch_model.py new file mode 100644 index 00000000..3c99a1f2 --- /dev/null +++ b/modelscope/models/base/base_torch_model.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch +from torch import nn + +from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.logger import get_logger +from .base_model import Model + +logger = get_logger(__name__) + + +class TorchModel(Model, torch.nn.Module): + """ Base model interface for pytorch + + """ + + def __init__(self, model_dir=None, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + torch.nn.Module.__init__(self) + + def __call__(self, *args, **kwargs) -> Dict[str, Any]: + # Adapting a model with only one dict arg, and the arg name must be input or inputs + if func_receive_dict_inputs(self.forward): + return self.postprocess(self.forward(args[0], **kwargs)) + else: + return self.postprocess(self.forward(*args, **kwargs)) + + def forward(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError + + def post_init(self): + """ + A method executed at the end of each model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + + def init_weights(self): + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) diff --git a/modelscope/models/builder.py b/modelscope/models/builder.py new file mode 100644 index 00000000..2804c6c7 --- /dev/null +++ b/modelscope/models/builder.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Tasks +from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule +from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg + +MODELS = Registry('models') +BACKBONES = MODELS +HEADS = Registry('heads') + +modules = LazyImportModule.AST_INDEX[INDEX_KEY] +for module_index in list(modules.keys()): + if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': + modules[(MODELS.name.upper(), module_index[1], + module_index[2])] = modules[module_index] + + +def build_model(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build model given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, MODELS, group_key=task_name, default_args=default_args) + + +def build_backbone(cfg: ConfigDict, default_args: dict = None): + """ build backbone given backbone config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for backbone object. + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) + + +def build_head(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build head given config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for head object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, HEADS, group_key=task_name, default_args=default_args) diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py new file mode 100644 index 00000000..64039863 --- /dev/null +++ b/modelscope/models/cv/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +# yapf: disable +from . import (action_recognition, animal_recognition, body_2d_keypoints, + body_3d_keypoints, cartoon, cmdssl_video_embedding, + crowd_counting, face_2d_keypoints, face_detection, + face_generation, human_wholebody_keypoint, image_classification, + image_color_enhance, image_colorization, image_denoise, + image_inpainting, image_instance_segmentation, + image_panoptic_segmentation, image_portrait_enhancement, + image_reid_person, image_semantic_segmentation, + image_to_image_generation, image_to_image_translation, + movie_scene_segmentation, object_detection, + product_retrieval_embedding, realtime_object_detection, + referring_video_object_segmentation, salient_detection, + shop_segmentation, super_resolution, + video_single_object_tracking, video_summarization, virual_tryon) + +# yapf: enable diff --git a/modelscope/models/cv/action_detection/__init__.py b/modelscope/models/cv/action_detection/__init__.py new file mode 100644 index 00000000..fedbe19c --- /dev/null +++ b/modelscope/models/cv/action_detection/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .action_detection_onnx import ActionDetONNX + +else: + _import_structure = {'action_detection_onnx': ['ActionDetONNX']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/action_detection/action_detection_onnx.py b/modelscope/models/cv/action_detection/action_detection_onnx.py new file mode 100644 index 00000000..223d77f7 --- /dev/null +++ b/modelscope/models/cv/action_detection/action_detection_onnx.py @@ -0,0 +1,182 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +import shutil +import subprocess +import uuid + +import cv2 +import numpy as np +import onnxruntime as rt + +from modelscope.models import Model +from modelscope.utils.constant import Devices +from modelscope.utils.device import verify_device + + +class ActionDetONNX(Model): + + def __init__(self, model_dir, config, *args, **kwargs): + super().__init__(self, model_dir, *args, **kwargs) + model_file = osp.join(config['model_file']) + device_type, device_id = verify_device(self._device_name) + options = rt.SessionOptions() + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + if device_type == Devices.gpu: + sess = rt.InferenceSession( + model_file, + providers=['CUDAExecutionProvider'], + sess_options=options, + provider_options=[{ + 'device_id': device_id + }]) + else: + sess = rt.InferenceSession( + model_file, + providers=['CPUExecutionProvider'], + sess_options=options) + self.input_name = sess.get_inputs()[0].name + self.sess = sess + self.num_stride = len(config['fpn_strides']) + self.score_thresh = np.asarray( + config['pre_nms_thresh'], dtype='float32').reshape((1, -1)) + self.size_divisibility = config['size_divisibility'] + self.nms_threshold = config['nms_thresh'] + self.tmp_dir = config['tmp_dir'] + self.temporal_stride = config['step'] + self.input_data_type = config['input_type'] + self.action_names = config['action_names'] + self.video_length_limit = config['video_length_limit'] + + def resize_box(self, det, height, width, scale_h, scale_w): + bboxs = det[0] + bboxs[:, [0, 2]] *= scale_w + bboxs[:, [1, 3]] *= scale_h + bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1) + bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1) + result = { + 'boxes': bboxs.round().astype('int32').tolist(), + 'scores': det[1].tolist(), + 'labels': [self.action_names[i] for i in det[2].tolist()] + } + return result + + def parse_frames(self, frame_names): + imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names] + imgs = np.stack(imgs).astype(self.input_data_type).transpose( + (3, 0, 1, 2)) # c,t,h,w + imgs = imgs[None] + return imgs + + def forward_img(self, imgs, h, w): + pred = self.sess.run(None, { + self.input_name: imgs, + 'height': np.asarray(h), + 'width': np.asarray(w) + }) + dets = self.post_nms( + pred, + score_threshold=self.score_thresh, + nms_threshold=self.nms_threshold) + return dets + + def forward_video(self, video_name, scale): + min_size, max_size = self._get_sizes(scale) + + tmp_dir = osp.join( + self.tmp_dir, + str(uuid.uuid1()) + '_' + osp.basename(video_name)[:-4]) + if osp.exists(tmp_dir): + shutil.rmtree(tmp_dir) + os.makedirs(tmp_dir) + frame_rate = 2 + cmd = f'ffmpeg -y -loglevel quiet -ss 0 -t {self.video_length_limit}' + \ + f' -i {video_name} -r {frame_rate} -f image2 {tmp_dir}/%06d.jpg' + + cmd = cmd.split(' ') + subprocess.call(cmd) + + frame_names = [ + osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir)) + if name.endswith('.jpg') + ] + frame_names = [ + frame_names[i:i + frame_rate * 2] + for i in range(0, + len(frame_names) - frame_rate * 2 + 1, frame_rate + * self.temporal_stride) + ] + timestamp = list( + range(1, + len(frame_names) * self.temporal_stride, + self.temporal_stride)) + batch_imgs = [self.parse_frames(names) for names in frame_names] + shutil.rmtree(tmp_dir) + + N, _, T, H, W = batch_imgs[0].shape + scale_min = min_size / min(H, W) + h, w = min(int(scale_min * H), + max_size), min(int(scale_min * W), max_size) + h = round(h / self.size_divisibility) * self.size_divisibility + w = round(w / self.size_divisibility) * self.size_divisibility + scale_h, scale_w = H / h, W / w + + results = [] + for imgs in batch_imgs: + det = self.forward_img(imgs, h, w) + det = self.resize_box(det[0], H, W, scale_h, scale_w) + results.append(det) + results = [{ + 'timestamp': t, + 'actions': res + } for t, res in zip(timestamp, results)] + return results + + def forward(self, video_name): + return self.forward_video(video_name, scale=1) + + def post_nms(self, pred, score_threshold, nms_threshold=0.3): + pred_bboxes, pred_scores = pred + N = len(pred_bboxes) + dets = [] + for i in range(N): + bboxes, scores = pred_bboxes[i], pred_scores[i] + candidate_inds = scores > score_threshold + scores = scores[candidate_inds] + candidate_nonzeros = candidate_inds.nonzero() + bboxes = bboxes[candidate_nonzeros[0]] + labels = candidate_nonzeros[1] + keep = self._nms(bboxes, scores, labels, nms_threshold) + bbox = bboxes[keep] + score = scores[keep] + label = labels[keep] + dets.append((bbox, score, label)) + return dets + + def _nms(self, boxes, scores, idxs, nms_threshold): + if len(boxes) == 0: + return [] + max_coordinate = boxes.max() + offsets = idxs * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None].astype('float32') + boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0] + boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1] + keep = cv2.dnn.NMSBoxes( + boxes_for_nms.tolist(), + scores.tolist(), + score_threshold=0, + nms_threshold=nms_threshold) + if len(keep.shape) == 2: + keep = np.squeeze(keep, 1) + return keep + + def _get_sizes(self, scale): + if scale == 1: + min_size, max_size = 512, 896 + elif scale == 2: + min_size, max_size = 768, 1280 + else: + min_size, max_size = 1024, 1792 + return min_size, max_size diff --git a/modelscope/models/cv/action_recognition/__init__.py b/modelscope/models/cv/action_recognition/__init__.py new file mode 100644 index 00000000..5e9dc310 --- /dev/null +++ b/modelscope/models/cv/action_recognition/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .models import BaseVideoModel + from .tada_convnext import TadaConvNeXt + from .temporal_patch_shift_transformer import PatchShiftTransformer + +else: + _import_structure = { + 'models': ['BaseVideoModel'], + 'tada_convnext': ['TadaConvNeXt'], + 'temporal_patch_shift_transformer': ['PatchShiftTransformer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/action_recognition/models.py b/modelscope/models/cv/action_recognition/models.py new file mode 100644 index 00000000..f16805fb --- /dev/null +++ b/modelscope/models/cv/action_recognition/models.py @@ -0,0 +1,134 @@ +# The implementation is also open-sourced by the authors, +# and available at https://github.com/alibaba-mmai-research/TAdaConv +# Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved. +import torch.nn as nn + +from .s3dg import Inception3D +from .tada_convnext import TadaConvNeXt + + +class BaseVideoModel(nn.Module): + """ + Standard video model. + The model is divided into the backbone and the head, where the backbone + extracts features and the head performs classification. + + The backbones can be defined in model/base/backbone.py or anywhere else + as long as the backbone is registered by the BACKBONE_REGISTRY. + The heads can be defined in model/module_zoo/heads/ or anywhere else + as long as the head is registered by the HEAD_REGISTRY. + + The registries automatically finds the registered modules and construct + the base video model. + """ + + def __init__(self, cfg): + """ + Args: + cfg (Config): global config object. + """ + super(BaseVideoModel, self).__init__() + # the backbone is created according to meta-architectures + # defined in models/base/backbone.py + if cfg.MODEL.NAME == 'ConvNeXt_tiny': + self.backbone = TadaConvNeXt(cfg) + elif cfg.MODEL.NAME == 'S3DG': + self.backbone = Inception3D(cfg) + else: + error_str = 'backbone {} is not supported, ConvNeXt_tiny or S3DG is supported'.format( + cfg.MODEL.NAME) + raise NotImplementedError(error_str) + + # the head is created according to the heads + # defined in models/module_zoo/heads + if cfg.VIDEO.HEAD.NAME == 'BaseHead': + self.head = BaseHead(cfg) + elif cfg.VIDEO.HEAD.NAME == 'AvgHead': + self.head = AvgHead(cfg) + else: + error_str = 'head {} is not supported, BaseHead or AvgHead is supported'.format( + cfg.VIDEO.HEAD.NAME) + raise NotImplementedError(error_str) + + def forward(self, x): + x = self.backbone(x) + x = self.head(x) + return x + + +class BaseHead(nn.Module): + """ + Constructs base head. + """ + + def __init__( + self, + cfg, + ): + """ + Args: + cfg (Config): global config object. + """ + super(BaseHead, self).__init__() + self.cfg = cfg + dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES + num_classes = cfg.VIDEO.HEAD.NUM_CLASSES + dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE + activation_func = cfg.VIDEO.HEAD.ACTIVATION + self._construct_head(dim, num_classes, dropout_rate, activation_func) + + def _construct_head(self, dim, num_classes, dropout_rate, activation_func): + self.global_avg_pool = nn.AdaptiveAvgPool3d(1) + + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + + self.out = nn.Linear(dim, num_classes, bias=True) + + if activation_func == 'softmax': + self.activation = nn.Softmax(dim=-1) + elif activation_func == 'sigmoid': + self.activation = nn.Sigmoid() + else: + raise NotImplementedError('{} is not supported as an activation' + 'function.'.format(activation_func)) + + def forward(self, x): + if len(x.shape) == 5: + x = self.global_avg_pool(x) + # (N, C, T, H, W) -> (N, T, H, W, C). + x = x.permute((0, 2, 3, 4, 1)) + if hasattr(self, 'dropout'): + out = self.dropout(x) + else: + out = x + out = self.out(out) + out = self.activation(out) + out = out.view(out.shape[0], -1) + return out, x.view(x.shape[0], -1) + + +class AvgHead(nn.Module): + """ + Constructs base head. + """ + + def __init__( + self, + cfg, + ): + """ + Args: + cfg (Config): global config object. + """ + super(AvgHead, self).__init__() + self.cfg = cfg + self.global_avg_pool = nn.AdaptiveAvgPool3d(1) + + def forward(self, x): + if len(x.shape) == 5: + x = self.global_avg_pool(x) + # (N, C, T, H, W) -> (N, T, H, W, C). + x = x.permute((0, 2, 3, 4, 1)) + out = x.view(x.shape[0], -1) + return out, x.view(x.shape[0], -1) diff --git a/modelscope/models/cv/action_recognition/s3dg.py b/modelscope/models/cv/action_recognition/s3dg.py new file mode 100644 index 00000000..46e76892 --- /dev/null +++ b/modelscope/models/cv/action_recognition/s3dg.py @@ -0,0 +1,304 @@ +# The implementation is adopted from https://github.com/TengdaHan/CoCLR, +# made pubicly available under the Apache License, Version 2.0 at https://github.com/TengdaHan/CoCLR +# Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved. +import torch +import torch.nn as nn + + +class InceptionBaseConv3D(nn.Module): + """ + Constructs basic inception 3D conv. + Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. + """ + + def __init__(self, + cfg, + in_planes, + out_planes, + kernel_size, + stride, + padding=0): + super(InceptionBaseConv3D, self).__init__() + self.conv = nn.Conv3d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False) + self.bn = nn.BatchNorm3d(out_planes) + self.relu = nn.ReLU(inplace=True) + + # init + self.conv.weight.data.normal_( + mean=0, std=0.01) # original s3d is truncated normal within 2 std + self.bn.weight.data.fill_(1) + self.bn.bias.data.zero_() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class InceptionBlock3D(nn.Module): + """ + Element constructing the S3D/S3DG. + See models/base/backbone.py L99-186. + + Modifed from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. + """ + + def __init__(self, cfg, in_planes, out_planes): + super(InceptionBlock3D, self).__init__() + + _gating = cfg.VIDEO.BACKBONE.BRANCH.GATING + + assert len(out_planes) == 6 + assert isinstance(out_planes, list) + + [ + num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a, + num_out_2_0b, num_out_3_0b + ] = out_planes + + self.branch0 = nn.Sequential( + InceptionBaseConv3D( + cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), ) + self.branch1 = nn.Sequential( + InceptionBaseConv3D( + cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1), + STConv3d( + cfg, + num_out_1_0a, + num_out_1_0b, + kernel_size=3, + stride=1, + padding=1), + ) + self.branch2 = nn.Sequential( + InceptionBaseConv3D( + cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1), + STConv3d( + cfg, + num_out_2_0a, + num_out_2_0b, + kernel_size=3, + stride=1, + padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), + InceptionBaseConv3D( + cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1), + ) + + self.out_channels = sum( + [num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) + + self.gating = _gating + if _gating: + self.gating_b0 = SelfGating(num_out_0_0a) + self.gating_b1 = SelfGating(num_out_1_0b) + self.gating_b2 = SelfGating(num_out_2_0b) + self.gating_b3 = SelfGating(num_out_3_0b) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + if self.gating: + x0 = self.gating_b0(x0) + x1 = self.gating_b1(x1) + x2 = self.gating_b2(x2) + x3 = self.gating_b3(x3) + + out = torch.cat((x0, x1, x2, x3), 1) + + return out + + +class SelfGating(nn.Module): + + def __init__(self, input_dim): + super(SelfGating, self).__init__() + self.fc = nn.Linear(input_dim, input_dim) + + def forward(self, input_tensor): + """Feature gating as used in S3D-G""" + spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) + weights = self.fc(spatiotemporal_average) + weights = torch.sigmoid(weights) + return weights[:, :, None, None, None] * input_tensor + + +class STConv3d(nn.Module): + """ + Element constructing the S3D/S3DG. + See models/base/backbone.py L99-186. + + Modifed from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. + """ + + def __init__(self, + cfg, + in_planes, + out_planes, + kernel_size, + stride, + padding=0): + super(STConv3d, self).__init__() + if isinstance(stride, tuple): + t_stride = stride[0] + stride = stride[-1] + else: # int + t_stride = stride + + self.bn_mmt = cfg.BN.MOMENTUM + self.bn_eps = float(cfg.BN.EPS) + self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride, + t_stride, padding) + + def _construct_branch(self, + cfg, + in_planes, + out_planes, + kernel_size, + stride, + t_stride, + padding=0): + self.conv1 = nn.Conv3d( + in_planes, + out_planes, + kernel_size=(1, kernel_size, kernel_size), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False) + self.conv2 = nn.Conv3d( + out_planes, + out_planes, + kernel_size=(kernel_size, 1, 1), + stride=(t_stride, 1, 1), + padding=(padding, 0, 0), + bias=False) + + self.bn1 = nn.BatchNorm3d( + out_planes, eps=self.bn_eps, momentum=self.bn_mmt) + self.bn2 = nn.BatchNorm3d( + out_planes, eps=self.bn_eps, momentum=self.bn_mmt) + self.relu = nn.ReLU(inplace=True) + + # init + self.conv1.weight.data.normal_( + mean=0, std=0.01) # original s3d is truncated normal within 2 std + self.conv2.weight.data.normal_( + mean=0, std=0.01) # original s3d is truncated normal within 2 std + self.bn1.weight.data.fill_(1) + self.bn1.bias.data.zero_() + self.bn2.weight.data.fill_(1) + self.bn2.bias.data.zero_() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + return x + + +class Inception3D(nn.Module): + """ + Backbone architecture for I3D/S3DG. + Modifed from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. + """ + + def __init__(self, cfg): + """ + Args: + cfg (Config): global config object. + """ + super(Inception3D, self).__init__() + _input_channel = cfg.DATA.NUM_INPUT_CHANNELS + self._construct_backbone(cfg, _input_channel) + + def _construct_backbone(self, cfg, input_channel): + # ------------------- Block 1 ------------------- + self.Conv_1a = STConv3d( + cfg, input_channel, 64, kernel_size=7, stride=2, padding=3) + + self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) + + # ------------------- Block 2 ------------------- + self.MaxPool_2a = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + self.Conv_2b = InceptionBaseConv3D( + cfg, 64, 64, kernel_size=1, stride=1) + self.Conv_2c = STConv3d( + cfg, 64, 192, kernel_size=3, stride=1, padding=1) + + self.block2 = nn.Sequential( + self.MaxPool_2a, # (64, 32, 56, 56) + self.Conv_2b, # (64, 32, 56, 56) + self.Conv_2c) # (192, 32, 56, 56) + + # ------------------- Block 3 ------------------- + self.MaxPool_3a = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + self.Mixed_3b = InceptionBlock3D( + cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32]) + self.Mixed_3c = InceptionBlock3D( + cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64]) + + self.block3 = nn.Sequential( + self.MaxPool_3a, # (192, 32, 28, 28) + self.Mixed_3b, # (256, 32, 28, 28) + self.Mixed_3c) # (480, 32, 28, 28) + + # ------------------- Block 4 ------------------- + self.MaxPool_4a = nn.MaxPool3d( + kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) + self.Mixed_4b = InceptionBlock3D( + cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64]) + self.Mixed_4c = InceptionBlock3D( + cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64]) + self.Mixed_4d = InceptionBlock3D( + cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64]) + self.Mixed_4e = InceptionBlock3D( + cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64]) + self.Mixed_4f = InceptionBlock3D( + cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128]) + + self.block4 = nn.Sequential( + self.MaxPool_4a, # (480, 16, 14, 14) + self.Mixed_4b, # (512, 16, 14, 14) + self.Mixed_4c, # (512, 16, 14, 14) + self.Mixed_4d, # (512, 16, 14, 14) + self.Mixed_4e, # (528, 16, 14, 14) + self.Mixed_4f) # (832, 16, 14, 14) + + # ------------------- Block 5 ------------------- + self.MaxPool_5a = nn.MaxPool3d( + kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)) + self.Mixed_5b = InceptionBlock3D( + cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128]) + self.Mixed_5c = InceptionBlock3D( + cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128]) + + self.block5 = nn.Sequential( + self.MaxPool_5a, # (832, 8, 7, 7) + self.Mixed_5b, # (832, 8, 7, 7) + self.Mixed_5c) # (1024, 8, 7, 7) + + def forward(self, x): + if isinstance(x, dict): + x = x['video'] + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + return x diff --git a/modelscope/models/cv/action_recognition/tada_convnext.py b/modelscope/models/cv/action_recognition/tada_convnext.py new file mode 100644 index 00000000..b1de7af8 --- /dev/null +++ b/modelscope/models/cv/action_recognition/tada_convnext.py @@ -0,0 +1,476 @@ +# The implementation is adopted from https://github.com/facebookresearch/ConvNeXt, +# made pubicly available under the MIT License at https://github.com/facebookresearch/ConvNeXt +# Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _triple + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """ + From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class TadaConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, cfg + # in_chans=3, num_classes=1000, + # depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., + # layer_scale_init_value=1e-6, head_init_scale=1., + ): + super().__init__() + in_chans = cfg.VIDEO.BACKBONE.NUM_INPUT_CHANNELS + dims = cfg.VIDEO.BACKBONE.NUM_FILTERS + drop_path_rate = cfg.VIDEO.BACKBONE.DROP_PATH + depths = cfg.VIDEO.BACKBONE.DEPTH + layer_scale_init_value = cfg.VIDEO.BACKBONE.LARGE_SCALE_INIT_VALUE + stem_t_kernel_size = cfg.VIDEO.BACKBONE.STEM.T_KERNEL_SIZE if hasattr( + cfg.VIDEO.BACKBONE.STEM, 'T_KERNEL_SIZE') else 2 + t_stride = cfg.VIDEO.BACKBONE.STEM.T_STRIDE if hasattr( + cfg.VIDEO.BACKBONE.STEM, 'T_STRIDE') else 2 + + self.downsample_layers = nn.ModuleList( + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv3d( + in_chans, + dims[0], + kernel_size=(stem_t_kernel_size, 4, 4), + stride=(t_stride, 4, 4), + padding=((stem_t_kernel_size - 1) // 2, 0, 0)), + LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv3d( + dims[i], + dims[i + 1], + kernel_size=(1, 2, 2), + stride=(1, 2, 2)), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList( + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + cur = 0 + for i in range(4): + stage = nn.Sequential(*[ + TAdaConvNeXtBlock( + cfg, + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) + for j in range(depths[i]) + ]) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean( + [-3, -2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + if isinstance(x, dict): + x = x['video'] + x = self.forward_features(x) + return x + + def get_num_layers(self): + return 12, 0 + + +class ConvNeXtBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv3d( + dim, dim, kernel_size=(1, 7, 7), padding=(0, 3, 3), + groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) + + x = input + self.drop_path(x) + return x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None, None] * x + self.bias[:, None, None, + None] + return x + + +class TAdaConvNeXtBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_fi rst) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + layer_scale_init_value = float(layer_scale_init_value) + self.dwconv = TAdaConv2d( + dim, + dim, + kernel_size=(1, 7, 7), + padding=(0, 3, 3), + groups=dim, + cal_dim='cout') + route_func_type = cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_TYPE + if route_func_type == 'normal': + self.dwconv_rf = RouteFuncMLP( + c_in=dim, + ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, + kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, + with_bias_cal=self.dwconv.bias is not None) + elif route_func_type == 'normal_lngelu': + self.dwconv_rf = RouteFuncMLPLnGelu( + c_in=dim, + ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R, + kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K, + with_bias_cal=self.dwconv.bias is not None) + else: + raise ValueError( + 'Unknown route_func_type: {}'.format(route_func_type)) + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x, self.dwconv_rf(x)) + x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) + + x = input + self.drop_path(x) + return x + + +class RouteFuncMLPLnGelu(nn.Module): + """ + The routing function for generating the calibration weights. + """ + + def __init__(self, + c_in, + ratio, + kernels, + with_bias_cal=False, + bn_eps=1e-5, + bn_mmt=0.1): + """ + Args: + c_in (int): number of input channels. + ratio (int): reduction ratio for the routing function. + kernels (list): temporal kernel size of the stacked 1D convolutions + """ + super(RouteFuncMLPLnGelu, self).__init__() + self.c_in = c_in + self.with_bias_cal = with_bias_cal + self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1)) + self.globalpool = nn.AdaptiveAvgPool3d(1) + self.g = nn.Conv3d( + in_channels=c_in, + out_channels=c_in, + kernel_size=1, + padding=0, + ) + self.a = nn.Conv3d( + in_channels=c_in, + out_channels=int(c_in // ratio), + kernel_size=[kernels[0], 1, 1], + padding=[kernels[0] // 2, 0, 0], + ) + # self.bn = nn.BatchNorm3d(int(c_in//ratio), eps=bn_eps, momentum=bn_mmt) + self.ln = LayerNorm( + int(c_in // ratio), eps=1e-6, data_format='channels_first') + self.gelu = nn.GELU() + # self.relu = nn.ReLU(inplace=True) + self.b = nn.Conv3d( + in_channels=int(c_in // ratio), + out_channels=c_in, + kernel_size=[kernels[1], 1, 1], + padding=[kernels[1] // 2, 0, 0], + bias=False) + self.b.skip_init = True + self.b.weight.data.zero_() # to make sure the initial values + # for the output is 1. + if with_bias_cal: + self.b_bias = nn.Conv3d( + in_channels=int(c_in // ratio), + out_channels=c_in, + kernel_size=[kernels[1], 1, 1], + padding=[kernels[1] // 2, 0, 0], + bias=False) + self.b_bias.skip_init = True + self.b_bias.weight.data.zero_() # to make sure the initial values + # for the output is 1. + + def forward(self, x): + g = self.globalpool(x) + x = self.avgpool(x) + x = self.a(x + self.g(g)) + # x = self.bn(x) + # x = self.relu(x) + x = self.ln(x) + x = self.gelu(x) + if self.with_bias_cal: + return [self.b(x) + 1, self.b_bias(x) + 1] + else: + return self.b(x) + 1 + + +class TAdaConv2d(nn.Module): + """ + Performs temporally adaptive 2D convolution. + Currently, only application on 5D tensors is supported, which makes TAdaConv2d + essentially a 3D convolution with temporal kernel size of 1. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + cal_dim='cin'): + super(TAdaConv2d, self).__init__() + """ + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (list): kernel size of TAdaConv2d. + stride (list): stride for the convolution in TAdaConv2d. + padding (list): padding for the convolution in TAdaConv2d. + dilation (list): dilation of the convolution in TAdaConv2d. + groups (int): number of groups for TAdaConv2d. + bias (bool): whether to use bias in TAdaConv2d. + calibration_mode (str): calibrated dimension in TAdaConv2d. + Supported input "cin", "cout". + """ + + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + assert kernel_size[0] == 1 + assert stride[0] == 1 + assert padding[0] == 0 + assert dilation[0] == 1 + assert cal_dim in ['cin', 'cout'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.cal_dim = cal_dim + + # base weights (W_b) + self.weight = nn.Parameter( + torch.Tensor(1, 1, out_channels, in_channels // groups, + kernel_size[1], kernel_size[2])) + if bias: + self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels)) + else: + self.register_parameter('bias', None) + + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x, alpha): + """ + Args: + x (tensor): feature to perform convolution on. + alpha (tensor): calibration weight for the base weights. + W_t = alpha_t * W_b + """ + if isinstance(alpha, list): + w_alpha, b_alpha = alpha[0], alpha[1] + else: + w_alpha = alpha + b_alpha = None + _, _, c_out, c_in, kh, kw = self.weight.size() + b, c_in, t, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w) + + if self.cal_dim == 'cin': + # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1) + # corresponding to calibrating the input channel + weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2) + * self.weight).reshape(-1, c_in // self.groups, kh, kw) + elif self.cal_dim == 'cout': + # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1) + # corresponding to calibrating the input channel + weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3) + * self.weight).reshape(-1, c_in // self.groups, kh, kw) + + bias = None + if self.bias is not None: + if b_alpha is not None: + # b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C + bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze() + * self.bias).reshape(-1) + else: + bias = self.bias.repeat(b, t, 1).reshape(-1) + output = F.conv2d( + x, + weight=weight, + bias=bias, + stride=self.stride[1:], + padding=self.padding[1:], + dilation=self.dilation[1:], + groups=self.groups * b * t) + + output = output.view(b, t, c_out, output.size(-2), + output.size(-1)).permute(0, 2, 1, 3, 4) + + return output + + def __repr__(self): + return f'TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, ' +\ + f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")" diff --git a/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py b/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py new file mode 100644 index 00000000..46596afd --- /dev/null +++ b/modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py @@ -0,0 +1,1198 @@ +# Part of the implementation is borrowed and modified from Video Swin Transformer, +# publicly available at https://github.com/SwinTransformer/Video-Swin-Transformer + +from abc import ABCMeta, abstractmethod +from functools import lru_cache, reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torchvision.transforms as T +from einops import rearrange +from timm.models.layers import DropPath, Mlp, trunc_normal_ + +from modelscope.models import TorchModel + + +def normal_init(module, mean=0., std=1., bias=0.): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def window_partition(x, window_size): + """ window_partition function. + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], + window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, + 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ window_reverse function. + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], + W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention3D(nn.Module): + """ This is PyTorch impl of TPS + + Window based multi-head self attention (W-MSA) module with relative position bias. + The coordinates of patches and patches are shifted together using Pattern C. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + shift (bool, optional): If True, conduct shift operation + shift_type (str, optional): shift operation type, either using 'psm' or 'tsm' + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + shift=False, + shift_type='psm'): + + super().__init__() + self.dim = dim + window_size = (16, 7, 7) + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.shift = shift + self.shift_type = shift_type + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + np.prod([2 * ws - 1 for ws in window_size]), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack( + torch.meshgrid(coords_d, coords_h, coords_w, + indexing='ij')) # 3, Wd, Wh, Ww + # Do the same rotation to coords + coords_old = coords.clone() + + # pattern patternC - 9 + coords[:, :, 0::3, 0::3] = torch.roll( + coords[:, :, 0::3, 0::3], shifts=-4, dims=1) + coords[:, :, 0::3, 1::3] = torch.roll( + coords[:, :, 0::3, 1::3], shifts=1, dims=1) + coords[:, :, 0::3, 2::3] = torch.roll( + coords[:, :, 0::3, 2::3], shifts=2, dims=1) + coords[:, :, 1::3, 2::3] = torch.roll( + coords[:, :, 1::3, 2::3], shifts=3, dims=1) + coords[:, :, 1::3, 0::3] = torch.roll( + coords[:, :, 1::3, 0::3], shifts=-1, dims=1) + coords[:, :, 2::3, 0::3] = torch.roll( + coords[:, :, 2::3, 0::3], shifts=-2, dims=1) + coords[:, :, 2::3, 1::3] = torch.roll( + coords[:, :, 2::3, 1::3], shifts=-3, dims=1) + coords[:, :, 2::3, 2::3] = torch.roll( + coords[:, :, 2::3, 2::3], shifts=4, dims=1) + + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + coords_old_flatten = torch.flatten(coords_old, 1) + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords_old = coords_old_flatten[:, :, + None] - coords_old_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords_old = relative_coords_old.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords_old[:, :, 0] += self.window_size[ + 0] - 1 # shift to start from 0 + relative_coords_old[:, :, 1] += self.window_size[1] - 1 + relative_coords_old[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + + relative_coords_old[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords_old[:, :, 1] *= (2 * self.window_size[2] - 1) + + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + + relative_position_index_old = relative_coords_old.sum(-1) + relative_position_index = relative_position_index.view( + window_size[0], window_size[1] * window_size[2], window_size[0], + window_size[1] * window_size[2]).permute(0, 2, 1, 3).reshape( + window_size[0] * window_size[0], + window_size[1] * window_size[2], + window_size[1] * window_size[2])[::window_size[0], :, :] + + relative_position_index_old = relative_position_index_old.view( + window_size[0], window_size[1] * window_size[2], window_size[0], + window_size[1] * window_size[2]).permute(0, 2, 1, 3).reshape( + window_size[0] * window_size[0], + window_size[1] * window_size[2], + window_size[1] * window_size[2])[::window_size[0], :, :] + + self.register_buffer('relative_position_index', + relative_position_index) + self.register_buffer('relative_position_index_old', + relative_position_index_old) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + if self.shift and self.shift_type == 'psm': + self.shift_op = PatchShift(False, 1) + self.shift_op_back = PatchShift(True, 1) + elif self.shift and self.shift_type == 'tsm': + self.shift_op = TemporalShift(8) + + def forward(self, x, mask=None, batch_size=8, frame_len=8): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + if self.shift: + x = x.view(B_, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + x = self.shift_op(x, batch_size, frame_len) + x = x.permute(0, 2, 1, 3).reshape(B_, N, C) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.shift and self.shift_type == 'psm': + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:].reshape(-1), :].reshape( + frame_len, N, N, -1) # 8frames ,Wd*Wh*Ww,Wd*Wh*Ww,nH + else: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index_old[:].reshape(-1), :].reshape( + frame_len, N, N, -1) # 8frames ,Wd*Wh*Ww,Wd*Wh*Ww,nH + + relative_position_bias = relative_position_bias.permute( + 0, 3, 1, 2).contiguous() # Frames, nH, Wd*Wh*Ww, Wd*Wh*Ww + + attn = attn.view( + batch_size, frame_len, -1, self.num_heads, N, N).permute( + 0, + 2, 1, 3, 4, 5) + relative_position_bias.unsqueeze(0).unsqueeze( + 1) # B_, nH, N, N + attn = attn.permute(0, 2, 1, 3, 4, 5).view(-1, self.num_heads, N, N) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + # Shift back for psm + if self.shift and self.shift_type == 'psm': + x = self.shift_op_back(attn @ v, batch_size, + frame_len).transpose(1, + 2).reshape(B_, N, C) + else: + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class PatchShift(nn.Module): + """ This is PyTorch impl of TPS + + The patches are shifted using Pattern C. + + It supports both of shifted and shift back. + + Args: + inv (bool): whether using inverse shifted (shift back) + ratio (float): ratio of channels to be shifted, patch shift using 1.0 + """ + + def __init__(self, inv=False, ratio=1): + super(PatchShift, self).__init__() + self.inv = inv + self.ratio = ratio + # if inv: + # print('=> Using inverse PatchShift, ratio {}, tps'.format(ratio)) + # else: + # print('=> Using bayershift, ratio {}, tps'.format(ratio)) + + def forward(self, x, batch_size, frame_len): + x = self.shift( + x, + inv=self.inv, + ratio=self.ratio, + batch_size=batch_size, + frame_len=frame_len) + return x + + @staticmethod + def shift(x, inv=False, ratio=0.5, batch_size=8, frame_len=8): + B, num_heads, N, c = x.size() + fold = int(num_heads * ratio) + feat = x + feat = feat.view(batch_size, frame_len, -1, num_heads, 7, 7, c) + out = feat.clone() + multiplier = 1 + stride = 1 + if inv: + multiplier = -1 + + # Pattern C + out[:, :, :, :fold, 0::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 0::3, :], + shifts=-4 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 0::3, 1::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 1::3, :], + shifts=multiplier * stride, + dims=1) + out[:, :, :, :fold, 1::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 1::3, 0::3, :], + shifts=-multiplier * stride, + dims=1) + out[:, :, :, :fold, 0::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 0::3, 2::3, :], + shifts=2 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 0::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 0::3, :], + shifts=-2 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 1::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 1::3, 2::3, :], + shifts=3 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 1::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 1::3, :], + shifts=-3 * multiplier * stride, + dims=1) + out[:, :, :, :fold, 2::3, 2::3, :] = torch.roll( + feat[:, :, :, :fold, 2::3, 2::3, :], + shifts=4 * multiplier * stride, + dims=1) + + out = out.view(B, num_heads, N, c) + return out + + +class TemporalShift(nn.Module): + """ This is PyTorch impl of TPS + + The temporal channel shift. + + The code is adopted from TSM: Temporal Shift Module for Efficient Video Understanding. ICCV19 + + https://github.com/mit-han-lab/temporal-shift-module/blob/master/ops/temporal_shift.py + + Args: + n_div (int): propotion of channel to be shifted. + """ + + def __init__(self, n_div=8): + super(TemporalShift, self).__init__() + self.fold_div = n_div + + def forward(self, x, batch_size, frame_len): + x = self.shift( + x, + fold_div=self.fold_div, + batch_size=batch_size, + frame_len=frame_len) + return x + + @staticmethod + def shift(x, fold_div=8, batch_size=8, frame_len=8): + B, num_heads, N, c = x.size() + fold = c // fold_div + feat = x + feat = feat.view(batch_size, frame_len, -1, num_heads, N, c) + out = feat.clone() + + out[:, 1:, :, :, :, :fold] = feat[:, :-1, :, :, :, :fold] # shift left + out[:, :-1, :, :, :, + fold:2 * fold] = feat[:, 1:, :, :, :, fold:2 * fold] # shift right + + out = out.view(B, num_heads, N, c) + + return out + + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block from Video Swin Transformer. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint=False, + shift=False, + shift_type='psm'): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.shift = shift + self.shift_type = shift_type + + assert 0 <= self.shift_size[0] < self.window_size[ + 0], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[1] < self.window_size[ + 1], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[2] < self.window_size[ + 2], 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + shift=self.shift, + shift_type=self.shift_type) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition(shifted_x, + window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask, batch_size=B, + frame_len=D) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C, ))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, + Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer from Video Swin Transformer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], + -shift_size[0]), slice( + -shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], + -shift_size[1]), slice( + -shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], + -shift_size[2]), slice( + -shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, + window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage from Video Swin Transformer. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + shift_type='psm'): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.shift_type = shift_type + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + shift=True, + shift_type='tsm' if (i % 2 == 0 and self.shift_type == 'psm') + or self.shift_type == 'tsm' else 'psm', + ) for i in range(depth) + ]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding from Video Swin Transformer. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad( + x, + (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + + +class SwinTransformer2D_TPS(nn.Module): + """ + Code is adopted from Video Swin Transformer. + + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4, 4, 4), + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if i_layer < self.num_layers - 1 else None, + use_checkpoint=use_checkpoint, + shift_type='psm') + self.layers.append(layer) + + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self): + """Inflate the swin2d parameters to swin3d. + + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + + Args: + logger (logging.Logger): The logger used to print + debugging infomation. + """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict[ + 'patch_embed.proj.weight'].unsqueeze(2).repeat( + 1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + # wd = self.window_size[0] + # to make it match + wd = 16 + if nH1 != nH2: + print(f'Error in loading {k}, passing') + else: + if L1 != L2: + S1 = int(L1**0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute( + 1, 0).view(1, nH1, S1, S1), + size=(2 * self.window_size[1] - 1, + 2 * self.window_size[2] - 1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( + nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat( + 2 * wd - 1, 1) + + msg = self.load_state_dict(state_dict, strict=False) + print(msg) + print(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if pretrained: + self.pretrained = pretrained + if isinstance(self.pretrained, str): + self.apply(_init_weights) + print(f'load model from: {self.pretrained}') + + if self.pretrained2d: + # Inflate 2D model into 3D model. + # self.inflate_weights(logger) + self.inflate_weights() + else: + # Directly load 3D model. + torch.load_checkpoint(self, self.pretrained, strict=False) + elif self.pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer2D_TPS, self).train(mode) + self._freeze_stages() + + +def top_k_accuracy(scores, labels, topk=(1, )): + """Calculate top k accuracy score from mmaction. + + Args: + scores (list[np.ndarray]): Prediction scores for each class. + labels (list[int]): Ground truth labels. + topk (tuple[int]): K value for top_k_accuracy. Default: (1, ). + + Returns: + list[float]: Top k accuracy score for each k. + """ + res = [] + labels = np.array(labels)[:, np.newaxis] + for k in topk: + max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1] + match_array = np.logical_or.reduce(max_k_preds == labels, axis=1) + topk_acc_score = match_array.sum() / match_array.shape[0] + res.append(topk_acc_score) + + return res + + +class BaseHead(nn.Module, metaclass=ABCMeta): + """Base class for head from mmaction. + + All Head should subclass it. + All subclass should overwrite: + - Methods:``init_weights``, initializing weights in some modules. + - Methods:``forward``, supporting to forward both for training and testing. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + loss_cls (dict): Config for building loss. + Default: dict(type='CrossEntropyLoss', loss_weight=1.0). + multi_class (bool): Determines whether it is a multi-class + recognition task. Default: False. + label_smooth_eps (float): Epsilon used in label smooth. + Reference: arxiv.org/abs/1906.02629. Default: 0. + """ + + def __init__(self, + num_classes, + in_channels, + loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0), + multi_class=False, + label_smooth_eps=0.0): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.loss_cls = torch.nn.CrossEntropyLoss() + self.multi_class = multi_class + self.label_smooth_eps = label_smooth_eps + + @abstractmethod + def init_weights(self): + """Initiate the parameters either from existing checkpoint or from + scratch.""" + + @abstractmethod + def forward(self, x): + """Defines the computation performed at every call.""" + + def loss(self, cls_score, labels, **kwargs): + """Calculate the loss given output ``cls_score``, target ``labels``. + + Args: + cls_score (torch.Tensor): The output of the model. + labels (torch.Tensor): The target output of the model. + + Returns: + dict: A dict containing field 'loss_cls'(mandatory) + and 'top1_acc', 'top5_acc'(optional). + """ + losses = dict() + if labels.shape == torch.Size([]): + labels = labels.unsqueeze(0) + elif labels.dim() == 1 and labels.size()[0] == self.num_classes \ + and cls_score.size()[0] == 1: + # Fix a bug when training with soft labels and batch size is 1. + # When using soft labels, `labels` and `cls_socre` share the same + # shape. + labels = labels.unsqueeze(0) + + if not self.multi_class and cls_score.size() != labels.size(): + top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(), + labels.detach().cpu().numpy(), (1, 5)) + losses['top1_acc'] = torch.tensor( + top_k_acc[0], device=cls_score.device) + losses['top5_acc'] = torch.tensor( + top_k_acc[1], device=cls_score.device) + + elif self.multi_class and self.label_smooth_eps != 0: + labels = ((1 - self.label_smooth_eps) * labels + + self.label_smooth_eps / self.num_classes) + + loss_cls = self.loss_cls(cls_score, labels, **kwargs) + # loss_cls may be dictionary or single tensor + if isinstance(loss_cls, dict): + losses.update(loss_cls) + else: + losses['loss_cls'] = loss_cls + + return losses + + +class I3DHead(BaseHead): + """Classification head for I3D from mmaction. + + Args: + num_classes (int): Number of classes to be classified. + in_channels (int): Number of channels in input feature. + loss_cls (dict): Config for building loss. + Default: dict(type='CrossEntropyLoss') + spatial_type (str): Pooling type in spatial dimension. Default: 'avg'. + dropout_ratio (float): Probability of dropout layer. Default: 0.5. + init_std (float): Std value for Initiation. Default: 0.01. + kwargs (dict, optional): Any keyword argument to be used to initialize + the head. + """ + + def __init__(self, + num_classes, + in_channels, + loss_cls=dict(type='CrossEntropyLoss'), + spatial_type='avg', + dropout_ratio=0.5, + init_std=0.01, + **kwargs): + super().__init__(num_classes, in_channels, loss_cls, **kwargs) + + self.spatial_type = spatial_type + self.dropout_ratio = dropout_ratio + self.init_std = init_std + if self.dropout_ratio != 0: + self.dropout = nn.Dropout(p=self.dropout_ratio) + else: + self.dropout = None + self.fc_cls = nn.Linear(self.in_channels, self.num_classes) + + if self.spatial_type == 'avg': + # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels. + self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + else: + self.avg_pool = None + + def init_weights(self): + """Initiate the parameters from scratch.""" + normal_init(self.fc_cls, std=self.init_std) + + def forward(self, x): + """Defines the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + + Returns: + torch.Tensor: The classification scores for input samples. + """ + # [N, in_channels, 4, 7, 7] + if self.avg_pool is not None: + x = self.avg_pool(x) + # [N, in_channels, 1, 1, 1] + if self.dropout is not None: + x = self.dropout(x) + # [N, in_channels, 1, 1, 1] + x = x.view(x.shape[0], -1) + # [N, in_channels] + cls_score = self.fc_cls(x) + # [N, num_classes] + return cls_score + + +class PatchShiftTransformer(TorchModel): + """ This is PyTorch impl of PST: + Spatiotemporal Self-attention Modeling with Temporal Patch Shift for Action Recognition, ECCV22. + """ + + def __init__(self, + model_dir=None, + num_classes=400, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + embed_dim=96, + in_channels=768, + pretrained=None): + super().__init__(model_dir) + self.backbone = SwinTransformer2D_TPS( + pretrained=pretrained, + pretrained2d=True, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False) + self.cls_head = I3DHead( + num_classes=num_classes, in_channels=in_channels) + + def forward(self, x): + feature = self.backbone(x) + output = self.cls_head(feature) + return output diff --git a/modelscope/models/cv/animal_recognition/__init__.py b/modelscope/models/cv/animal_recognition/__init__.py new file mode 100644 index 00000000..00a37a3f --- /dev/null +++ b/modelscope/models/cv/animal_recognition/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .resnet import ResNet, Bottleneck + from .splat import SplAtConv2d + +else: + _import_structure = { + 'resnet': ['ResNet', 'Bottleneck'], + 'splat': ['SplAtConv2d'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/animal_recognition/resnet.py b/modelscope/models/cv/animal_recognition/resnet.py new file mode 100644 index 00000000..d7c03c29 --- /dev/null +++ b/modelscope/models/cv/animal_recognition/resnet.py @@ -0,0 +1,430 @@ +# The implementation is adopted from Split-Attention Network, A New ResNet Variant, +# made pubicly available under the Apache License 2.0 License +# at https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/models/resnet.py +import math + +import torch +import torch.nn as nn + +from .splat import SplAtConv2d + +__all__ = ['ResNet', 'Bottleneck'] + + +class DropBlock2D(object): + + def __init__(self, *args, **kwargs): + raise NotImplementedError + + +class GlobalAvgPool2d(nn.Module): + + def __init__(self): + """Global average pooling over the input's spatial dimensions""" + super(GlobalAvgPool2d, self).__init__() + + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, + 1).view(inputs.size(0), -1) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + radix=1, + cardinality=1, + bottleneck_width=64, + avd=False, + avd_first=False, + dilation=1, + is_first=False, + rectified_conv=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + last_gamma=False): + super(Bottleneck, self).__init__() + group_width = int(planes * (bottleneck_width / 64.)) * cardinality + self.conv1 = nn.Conv2d( + inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.dropblock_prob = dropblock_prob + self.radix = radix + self.avd = avd and (stride > 1 or is_first) + self.avd_first = avd_first + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + stride = 1 + + if dropblock_prob > 0.0: + self.dropblock1 = DropBlock2D(dropblock_prob, 3) + if radix == 1: + self.dropblock2 = DropBlock2D(dropblock_prob, 3) + self.dropblock3 = DropBlock2D(dropblock_prob, 3) + + if radix >= 1: + self.conv2 = SplAtConv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + radix=radix, + rectify=rectified_conv, + rectify_avg=rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif rectified_conv: + self.conv2 = nn.Conv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False) + self.bn2 = norm_layer(group_width) + else: + self.conv2 = nn.Conv2d( + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False) + self.bn2 = norm_layer(group_width) + + self.conv3 = nn.Conv2d( + group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * 4) + + if last_gamma: + from torch.nn.init import zeros_ + zeros_(self.bn3.weight) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.dilation = dilation + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.dropblock_prob > 0.0: + out = self.dropblock1(out) + out = self.relu(out) + + if self.avd and self.avd_first: + out = self.avd_layer(out) + + out = self.conv2(out) + if self.radix == 0: + out = self.bn2(out) + if self.dropblock_prob > 0.0: + out = self.dropblock2(out) + out = self.relu(out) + + if self.avd and not self.avd_first: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.dropblock_prob > 0.0: + out = self.dropblock3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + block, + layers, + radix=1, + groups=1, + bottleneck_width=64, + num_classes=1000, + dilated=False, + dilation=1, + deep_stem=False, + stem_width=64, + avg_down=False, + rectified_conv=False, + rectify_avg=False, + avd=False, + avd_first=False, + final_drop=0.0, + dropblock_prob=0, + last_gamma=False, + norm_layer=nn.BatchNorm2d): + self.cardinality = groups + self.bottleneck_width = bottleneck_width + # ResNet-D params + self.inplanes = stem_width * 2 if deep_stem else 64 + self.avg_down = avg_down + self.last_gamma = last_gamma + # ResNeSt params + self.radix = radix + self.avd = avd + self.avd_first = avd_first + + super(ResNet, self).__init__() + self.rectified_conv = rectified_conv + self.rectify_avg = rectify_avg + if rectified_conv: + conv_layer = nn.Conv2d + else: + conv_layer = nn.Conv2d + conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} + if deep_stem: + self.conv1 = nn.Sequential( + conv_layer( + 3, + stem_width, + kernel_size=3, + stride=2, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer( + stem_width, + stem_width, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer( + stem_width, + stem_width * 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + ) + else: + self.conv1 = conv_layer( + 3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False, + **conv_kwargs) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + block, 64, layers[0], norm_layer=norm_layer, is_first=False) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, norm_layer=norm_layer) + if dilated or dilation == 4: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=1, + dilation=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=1, + dilation=4, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif dilation == 2: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilation=1, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=1, + dilation=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + else: + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + self.avgpool = GlobalAvgPool2d() + self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, norm_layer): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + norm_layer=None, + dropblock_prob=0.0, + is_first=True): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + down_layers = [] + if self.avg_down: + if dilation == 1: + down_layers.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + else: + down_layers.append( + nn.AvgPool2d( + kernel_size=1, + stride=1, + ceil_mode=True, + count_include_pad=False)) + down_layers.append( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False)) + else: + down_layers.append( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False)) + down_layers.append(norm_layer(planes * block.expansion)) + downsample = nn.Sequential(*down_layers) + + layers = [] + if dilation == 1 or dilation == 2: + layers.append( + block( + self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=1, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + elif dilation == 4: + layers.append( + block( + self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=2, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + else: + raise RuntimeError('=> unknown dilation size: {}'.format(dilation)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=dilation, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.drop: + x = self.drop(x) + x = self.fc(x) + + return x diff --git a/modelscope/models/cv/animal_recognition/splat.py b/modelscope/models/cv/animal_recognition/splat.py new file mode 100644 index 00000000..a10d0abe --- /dev/null +++ b/modelscope/models/cv/animal_recognition/splat.py @@ -0,0 +1,126 @@ +# The implementation is adopted from Split-Attention Network, A New ResNet Variant, +# made pubicly available under the Apache License 2.0 License +# at https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/models/splat.py +"""Split-Attention""" + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BatchNorm2d, Conv2d, Linear, Module, ReLU +from torch.nn.modules.utils import _pair + +__all__ = ['SplAtConv2d'] + + +class SplAtConv2d(Module): + """Split-Attention Conv2d + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + bias=True, + radix=2, + reduction_factor=4, + rectify=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + **kwargs): + super(SplAtConv2d, self).__init__() + padding = _pair(padding) + self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) + self.rectify_avg = rectify_avg + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.cardinality = groups + self.channels = channels + self.dropblock_prob = dropblock_prob + if self.rectify: + self.conv = Conv2d( + in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + **kwargs) + else: + self.conv = Conv2d( + in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + **kwargs) + self.use_bn = norm_layer is not None + if self.use_bn: + self.bn0 = norm_layer(channels * radix) + self.relu = ReLU(inplace=True) + self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) + if self.use_bn: + self.bn1 = norm_layer(inter_channels) + self.fc2 = Conv2d( + inter_channels, channels * radix, 1, groups=self.cardinality) + if dropblock_prob > 0.0: + self.dropblock = DropBlock2D(dropblock_prob, 3) + self.rsoftmax = rSoftMax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn0(x) + if self.dropblock_prob > 0.0: + x = self.dropblock(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + splited = torch.split(x, rchannel // self.radix, dim=1) + gap = sum(splited) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + if self.use_bn: + gap = self.bn1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = torch.split(atten, rchannel // self.radix, dim=1) + out = sum([att * split for (att, split) in zip(attens, splited)]) + else: + out = atten * x + return out.contiguous() + + +class rSoftMax(nn.Module): + + def __init__(self, radix, cardinality): + super().__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x diff --git a/modelscope/models/cv/body_2d_keypoints/__init__.py b/modelscope/models/cv/body_2d_keypoints/__init__.py new file mode 100644 index 00000000..ddc00cb3 --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .hrnet_v2 import PoseHighResolutionNetV2 + +else: + _import_structure = { + 'hrnet_v2': ['PoseHighResolutionNetV2'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py b/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py new file mode 100644 index 00000000..3b960688 --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py @@ -0,0 +1,397 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. + +import torch +import torch.nn as nn + +BN_MOMENTUM = 0.1 + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + + def __init__(self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, + num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + layers = [] + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + + +def upsample(scale, oup): + return nn.Sequential( + nn.Upsample(scale_factor=scale, mode='bilinear'), + nn.Conv2d( + in_channels=oup, + out_channels=oup, + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False), nn.BatchNorm2d(oup), nn.PReLU()) + + +class SE_Block(nn.Module): + + def __init__(self, c, r=16): + super().__init__() + self.squeeze = nn.AdaptiveAvgPool2d(1) + self.excitation = nn.Sequential( + nn.Linear(c, c // r, bias=False), nn.ReLU(inplace=True), + nn.Linear(c // r, c, bias=False), nn.Sigmoid()) + + def forward(self, x): + bs, c, _, _ = x.shape + y = self.squeeze(x).view(bs, c) + y = self.excitation(y).view(bs, c, 1, 1) + return x * y.expand_as(x) + + +class BasicBlockSE(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): + super(BasicBlockSE, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + self.se = SE_Block(planes, r) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckSE(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): + super(BottleneckSE, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + self.se = SE_Block(planes * self.expansion, r) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck, + 'BASICSE': BasicBlockSE, + 'BOTTLENECKSE': BottleneckSE, +} diff --git a/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py new file mode 100644 index 00000000..ebd69adb --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py @@ -0,0 +1,223 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. + +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import ( + BN_MOMENTUM, BasicBlock, Bottleneck, HighResolutionModule, blocks_dict) +from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.body_2d_keypoints, module_name=Models.body_2d_keypoints) +class PoseHighResolutionNetV2(TorchModel): + + def __init__(self, cfg=None, **kwargs): + if cfg is None: + cfg = cfg_128x128_15 + self.inplanes = 64 + extra = cfg['MODEL']['EXTRA'] + super(PoseHighResolutionNetV2, self).__init__(**kwargs) + + # stem net + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + 64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(Bottleneck, 64, 4) + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + """final four layers""" + last_inp_channels = np.int(np.sum(pre_stage_channels)) + self.final_layer = nn.Sequential( + nn.Conv2d( + in_channels=last_inp_channels, + out_channels=last_inp_channels, + kernel_size=1, + stride=1, + padding=0), + nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels=last_inp_channels, + out_channels=cfg['MODEL']['NUM_JOINTS'], + kernel_size=extra['FINAL_CONV_KERNEL'], + stride=1, + padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)) + + self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[ + i] if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, + layer_config, + num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + + y_list = self.stage4(x_list) + + y0_h, y0_w = y_list[0].size(2), y_list[0].size(3) + y1 = F.upsample(y_list[1], size=(y0_h, y0_w), mode='bilinear') + y2 = F.upsample(y_list[2], size=(y0_h, y0_w), mode='bilinear') + y3 = F.upsample(y_list[3], size=(y0_h, y0_w), mode='bilinear') + + y = torch.cat([y_list[0], y1, y2, y3], 1) + output = self.final_layer(y) + + return output diff --git a/modelscope/models/cv/body_2d_keypoints/w48.py b/modelscope/models/cv/body_2d_keypoints/w48.py new file mode 100644 index 00000000..e0317991 --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/w48.py @@ -0,0 +1,53 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. + +cfg_128x128_15 = { + 'DATASET': { + 'TYPE': 'DAMO', + 'PARENT_IDS': [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1], + 'LEFT_IDS': [2, 3, 4, 8, 9, 10], + 'RIGHT_IDS': [5, 6, 7, 11, 12, 13], + 'SPINE_IDS': [0, 1, 14] + }, + 'MODEL': { + 'INIT_WEIGHTS': True, + 'NAME': 'pose_hrnet', + 'NUM_JOINTS': 15, + 'PRETRAINED': '', + 'TARGET_TYPE': 'gaussian', + 'IMAGE_SIZE': [128, 128], + 'HEATMAP_SIZE': [32, 32], + 'SIGMA': 2.0, + 'EXTRA': { + 'PRETRAINED_LAYERS': [ + 'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1', + 'stage2', 'transition2', 'stage3', 'transition3', 'stage4' + ], + 'FINAL_CONV_KERNEL': + 1, + 'STAGE2': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 2, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4], + 'NUM_CHANNELS': [48, 96], + 'FUSE_METHOD': 'SUM' + }, + 'STAGE3': { + 'NUM_MODULES': 4, + 'NUM_BRANCHES': 3, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4, 4], + 'NUM_CHANNELS': [48, 96, 192], + 'FUSE_METHOD': 'SUM' + }, + 'STAGE4': { + 'NUM_MODULES': 3, + 'NUM_BRANCHES': 4, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4, 4, 4], + 'NUM_CHANNELS': [48, 96, 192, 384], + 'FUSE_METHOD': 'SUM' + }, + } + } +} diff --git a/modelscope/models/cv/body_3d_keypoints/__init__.py b/modelscope/models/cv/body_3d_keypoints/__init__.py new file mode 100644 index 00000000..4bb83936 --- /dev/null +++ b/modelscope/models/cv/body_3d_keypoints/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .body_3d_pose import BodyKeypointsDetection3D + +else: + _import_structure = { + 'body_3d_pose': ['BodyKeypointsDetection3D'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py new file mode 100644 index 00000000..3e920d12 --- /dev/null +++ b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py @@ -0,0 +1,248 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +import os.path as osp +from typing import Any, Dict, List, Union + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import ( + TemporalModel, TransCan3Dkeys) +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['BodyKeypointsDetection3D'] + + +class KeypointsTypes(object): + POSES_CAMERA = 'poses_camera' + POSES_TRAJ = 'poses_traj' + + +@MODELS.register_module( + Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints) +class BodyKeypointsDetection3D(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + + super().__init__(model_dir, *args, **kwargs) + + self.model_dir = model_dir + model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE) + cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(cfg_path) + self._create_model() + + if not osp.exists(model_path): + raise IOError(f'{model_path} is not exists.') + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.pretrained_state_dict = torch.load( + model_path, map_location=self._device) + + self.load_pretrained() + self.to_device(self._device) + self.eval() + + def _create_model(self): + self.model_pos = TemporalModel( + self.cfg.model.MODEL.IN_NUM_JOINTS, + self.cfg.model.MODEL.IN_2D_FEATURE, + self.cfg.model.MODEL.OUT_NUM_JOINTS, + filter_widths=self.cfg.model.MODEL.FILTER_WIDTHS, + causal=self.cfg.model.MODEL.CAUSAL, + dropout=self.cfg.model.MODEL.DROPOUT, + channels=self.cfg.model.MODEL.CHANNELS, + dense=self.cfg.model.MODEL.DENSE) + + receptive_field = self.model_pos.receptive_field() + self.pad = (receptive_field - 1) // 2 + if self.cfg.model.MODEL.CAUSAL: + self.causal_shift = self.pad + else: + self.causal_shift = 0 + + self.model_traj = TransCan3Dkeys( + in_channels=self.cfg.model.MODEL.IN_NUM_JOINTS + * self.cfg.model.MODEL.IN_2D_FEATURE, + num_features=1024, + out_channels=self.cfg.model.MODEL.OUT_3D_FEATURE, + num_blocks=4, + time_window=receptive_field) + + def eval(self): + self.model_pos.eval() + self.model_traj.eval() + + def train(self): + self.model_pos.train() + self.model_traj.train() + + def to_device(self, device): + self.model_pos = self.model_pos.to(device) + self.model_traj = self.model_traj.to(device) + + def load_pretrained(self): + if 'model_pos' in self.pretrained_state_dict: + self.model_pos.load_state_dict( + self.pretrained_state_dict['model_pos'], strict=False) + else: + logging.error( + 'Not load model pos from pretrained_state_dict, not in pretrained_state_dict' + ) + + if 'model_traj' in self.pretrained_state_dict: + self.model_traj.load_state_dict( + self.pretrained_state_dict['model_traj'], strict=False) + else: + logging.error( + 'Not load model traj from pretrained_state_dict, not in pretrained_state_dict' + ) + logging.info('Load pretrained model done.') + + def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: + """Proprocess of 2D input joints. + + Args: + input (Dict[str, Any]): [NUM_FRAME, NUM_JOINTS, 2], input 2d human body keypoints. + + Returns: + Dict[str, Any]: canonical 2d points and root relative joints. + """ + if 'cuda' == input.device.type: + input = input.data.cpu().numpy() + elif 'cpu' == input.device.type: + input = input.data.numpy() + pose2d = input + + pose2d_canonical = self.canonicalize_2Ds( + pose2d, self.cfg.model.INPUT.FOCAL_LENGTH, + self.cfg.model.INPUT.CENTER) + pose2d_normalized = self.normalize_screen_coordinates( + pose2d, self.cfg.model.INPUT.RES_W, self.cfg.model.INPUT.RES_H) + pose2d_rr = pose2d_normalized + pose2d_rr[:, 1:] -= pose2d_rr[:, :1] + + # expand [NUM_FRAME, NUM_JOINTS, 2] to [1, NUM_FRAME, NUM_JOINTS, 2] + pose2d_rr = np.expand_dims( + np.pad( + pose2d_rr, + ((self.pad + self.causal_shift, self.pad - self.causal_shift), + (0, 0), (0, 0)), 'edge'), + axis=0) + pose2d_canonical = np.expand_dims( + np.pad( + pose2d_canonical, + ((self.pad + self.causal_shift, self.pad - self.causal_shift), + (0, 0), (0, 0)), 'edge'), + axis=0) + pose2d_rr = torch.from_numpy(pose2d_rr.astype(np.float32)) + pose2d_canonical = torch.from_numpy( + pose2d_canonical.astype(np.float32)) + + inputs_2d = pose2d_rr.clone() + if torch.cuda.is_available(): + inputs_2d = inputs_2d.cuda(non_blocking=True) + + # Positional model + if self.cfg.model.MODEL.USE_2D_OFFSETS: + inputs_2d[:, :, 0] = 0 + else: + inputs_2d[:, :, 1:] += inputs_2d[:, :, :1] + + return { + 'inputs_2d': inputs_2d, + 'pose2d_rr': pose2d_rr, + 'pose2d_canonical': pose2d_canonical + } + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """3D human pose estimation. + + Args: + input (Dict): + inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2] + pose2d_rr: [1, NUM_FRAME, NUM_JOINTS, 2] + pose2d_canonical: [1, NUM_FRAME, NUM_JOINTS, 2] + NUM_FRAME = max(receptive_filed + video_frame_number, video_frame_number) + + Returns: + Dict[str, Any]: + "camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM], + 3D human pose keypoints in camera frame. + "camera_traj": Tensor, [1, NUM_FRAME, 1, 3], + root keypoints coordinates in camere frame. + """ + inputs_2d = input['inputs_2d'] + pose2d_rr = input['pose2d_rr'] + pose2d_canonical = input['pose2d_canonical'] + with torch.no_grad(): + # predict 3D pose keypoints + predicted_3d_pos = self.model_pos(inputs_2d) + + # predict global trajectory + b1, w1, n1, d1 = inputs_2d.shape + + input_pose2d_abs = self.get_abs_2d_pts(w1, pose2d_rr, + pose2d_canonical) + b1, w1, n1, d1 = input_pose2d_abs.size() + b2, w2, n2, d2 = predicted_3d_pos.size() + + if torch.cuda.is_available(): + input_pose2d_abs = input_pose2d_abs.cuda(non_blocking=True) + + predicted_3d_traj = self.model_traj( + input_pose2d_abs.view(b1, w1, n1 * d1), + predicted_3d_pos.view(b2 * w2, n2 * d2)).view(b2, w2, -1, 3) + + predict_dict = { + KeypointsTypes.POSES_CAMERA: predicted_3d_pos, + KeypointsTypes.POSES_TRAJ: predicted_3d_traj + } + + return predict_dict + + def get_abs_2d_pts(self, input_video_frame_num, pose2d_rr, + pose2d_canonical): + pad = self.pad + w = input_video_frame_num - pad * 2 + + lst_pose2d_rr = [] + lst_pose2d_cannoical = [] + for i in range(pad, w + pad): + lst_pose2d_rr.append(pose2d_rr[:, i - pad:i + pad + 1]) + lst_pose2d_cannoical.append(pose2d_canonical[:, + i - pad:i + pad + 1]) + + input_pose2d_rr = torch.concat(lst_pose2d_cannoical, axis=0) + input_pose2d_cannoical = torch.concat(lst_pose2d_cannoical, axis=0) + + if self.cfg.model.MODEL.USE_CANONICAL_COORDS: + input_pose2d_abs = input_pose2d_cannoical.clone() + else: + input_pose2d_abs = input_pose2d_rr.clone() + input_pose2d_abs[:, :, 1:] += input_pose2d_abs[:, :, :1] + + return input_pose2d_abs + + def canonicalize_2Ds(self, pos2d, f, c): + cs = np.array([c[0], c[1]]).reshape(1, 1, 2) + fs = np.array([f[0], f[1]]).reshape(1, 1, 2) + canoical_2Ds = (pos2d - cs) / fs + return canoical_2Ds + + def normalize_screen_coordinates(self, X, w, h): + assert X.shape[-1] == 2 + + # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio + return X / w * 2 - [1, h / w] diff --git a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py new file mode 100644 index 00000000..b7f0c4a3 --- /dev/null +++ b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py @@ -0,0 +1,233 @@ +# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D +import torch +import torch.nn as nn + + +class TemporalModelBase(nn.Module): + """ + Do not instantiate this class. + """ + + def __init__(self, num_joints_in, in_features, num_joints_out, + filter_widths, causal, dropout, channels): + super().__init__() + + # Validate input + for fw in filter_widths: + assert fw % 2 != 0, 'Only odd filter widths are supported' + + self.num_joints_in = num_joints_in + self.in_features = in_features + self.num_joints_out = num_joints_out + self.filter_widths = filter_widths + + self.drop = nn.Dropout(dropout) + self.relu = nn.ReLU(inplace=True) + + self.pad = [filter_widths[0] // 2] + self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1) + self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1) + + def set_bn_momentum(self, momentum): + self.expand_bn.momentum = momentum + for bn in self.layers_bn: + bn.momentum = momentum + + def receptive_field(self): + """ + Return the total receptive field of this model as # of frames. + """ + frames = 0 + for f in self.pad: + frames += f + return 1 + 2 * frames + + def total_causal_shift(self): + """ + Return the asymmetric offset for sequence padding. + The returned value is typically 0 if causal convolutions are disabled, + otherwise it is half the receptive field. + """ + frames = self.causal_shift[0] + next_dilation = self.filter_widths[0] + for i in range(1, len(self.filter_widths)): + frames += self.causal_shift[i] * next_dilation + next_dilation *= self.filter_widths[i] + return frames + + def forward(self, x): + assert len(x.shape) == 4 + assert x.shape[-2] == self.num_joints_in + assert x.shape[-1] == self.in_features + + sz = x.shape[:3] + x = x.view(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + + x = self._forward_blocks(x) + + x = x.permute(0, 2, 1) + x = x.view(sz[0], -1, self.num_joints_out, 3) + + return x + + +class TemporalModel(TemporalModelBase): + """ + Reference 3D pose estimation model with temporal convolutions. + This implementation can be used for all use-cases. + """ + + def __init__(self, + num_joints_in, + in_features, + num_joints_out, + filter_widths, + causal=False, + dropout=0.25, + channels=1024, + dense=False): + """ + Initialize this model. + + Arguments: + num_joints_in -- number of input joints (e.g. 17 for Human3.6M) + in_features -- number of input features for each joint (typically 2 for 2D input) + num_joints_out -- number of output joints (can be different than input) + filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field + causal -- use causal convolutions instead of symmetric convolutions (for real-time applications) + dropout -- dropout probability + channels -- number of convolution channels + dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment) + """ + super().__init__(num_joints_in, in_features, num_joints_out, + filter_widths, causal, dropout, channels) + + self.expand_conv = nn.Conv1d( + num_joints_in * in_features, + channels, + filter_widths[0], + bias=False) + + layers_conv = [] + layers_bn = [] + + self.causal_shift = [(filter_widths[0]) // 2 if causal else 0] + next_dilation = filter_widths[0] + for i in range(1, len(filter_widths)): + self.pad.append((filter_widths[i] - 1) * next_dilation // 2) + self.causal_shift.append((filter_widths[i] // 2 + * next_dilation) if causal else 0) + + layers_conv.append( + nn.Conv1d( + channels, + channels, + filter_widths[i] if not dense else (2 * self.pad[-1] + 1), + dilation=next_dilation if not dense else 1, + bias=False)) + layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) + layers_conv.append( + nn.Conv1d(channels, channels, 1, dilation=1, bias=False)) + layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) + + next_dilation *= filter_widths[i] + + self.layers_conv = nn.ModuleList(layers_conv) + self.layers_bn = nn.ModuleList(layers_bn) + + def _forward_blocks(self, x): + x = self.drop(self.relu(self.expand_bn(self.expand_conv(x)))) + for i in range(len(self.pad) - 1): + pad = self.pad[i + 1] + shift = self.causal_shift[i + 1] + res = x[:, :, pad + shift:x.shape[2] - pad + shift] + x = self.drop( + self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x)))) + x = res + self.drop( + self.relu(self.layers_bn[2 * i + 1]( + self.layers_conv[2 * i + 1](x)))) + + x = self.shrink(x) + return x + + +# regression of the trajectory +class TransCan3Dkeys(nn.Module): + + def __init__(self, + in_channels=74, + num_features=256, + out_channels=44, + time_window=10, + num_blocks=2): + super().__init__() + self.in_channels = in_channels + self.num_features = num_features + self.out_channels = out_channels + self.num_blocks = num_blocks + self.time_window = time_window + + self.expand_bn = nn.BatchNorm1d(self.num_features, momentum=0.1) + self.conv1 = nn.Sequential( + nn.ReplicationPad1d(1), + nn.Conv1d( + self.in_channels, self.num_features, kernel_size=3, + bias=False), self.expand_bn, nn.ReLU(inplace=True), + nn.Dropout(p=0.25)) + self._make_blocks() + self.pad = nn.ReplicationPad1d(4) + self.relu = nn.ReLU(inplace=True) + self.drop = nn.Dropout(p=0.25) + self.reduce = nn.Conv1d( + self.num_features, self.num_features, kernel_size=self.time_window) + self.embedding_3d_1 = nn.Linear(in_channels // 2 * 3, 500) + self.embedding_3d_2 = nn.Linear(500, 500) + self.LReLU1 = nn.LeakyReLU() + self.LReLU2 = nn.LeakyReLU() + self.LReLU3 = nn.LeakyReLU() + self.out1 = nn.Linear(self.num_features + 500, self.num_features) + self.out2 = nn.Linear(self.num_features, self.out_channels) + + def _make_blocks(self): + layers_conv = [] + layers_bn = [] + for i in range(self.num_blocks): + layers_conv.append( + nn.Conv1d( + self.num_features, + self.num_features, + kernel_size=5, + bias=False, + dilation=2)) + layers_bn.append(nn.BatchNorm1d(self.num_features)) + self.layers_conv = nn.ModuleList(layers_conv) + self.layers_bn = nn.ModuleList(layers_bn) + + def set_bn_momentum(self, momentum): + self.expand_bn.momentum = momentum + for bn in self.layers_bn: + bn.momentum = momentum + + def forward(self, p2ds, p3d): + """ + Args: + x - (B x T x J x C) + """ + B, T, C = p2ds.shape + x = p2ds.permute((0, 2, 1)) + x = self.conv1(x) + for i in range(self.num_blocks): + pre = x + x = self.pad(x) + x = self.layers_conv[i](x) + x = self.layers_bn[i](x) + x = self.drop(self.relu(x)) + x = pre + x + x_2d = self.relu(self.reduce(x)) + x_2d = x_2d.view(B, -1) + x_3d = self.LReLU1(self.embedding_3d_1(p3d)) + x = torch.cat((x_2d, x_3d), 1) + x = self.LReLU3(self.out1(x)) + x = self.out2(x) + return x diff --git a/modelscope/models/cv/cartoon/__init__.py b/modelscope/models/cv/cartoon/__init__.py new file mode 100644 index 00000000..131f5cac --- /dev/null +++ b/modelscope/models/cv/cartoon/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .facelib.facer import FaceAna + from .mtcnn_pytorch.src.align_trans import (get_reference_facial_points, + warp_and_crop_face) + from .utils import (get_f5p, padTo16x, resize_size) + +else: + _import_structure = { + 'facelib.facer': ['FaceAna'], + 'mtcnn_pytorch.src.align_trans': + ['get_reference_facial_points', 'warp_and_crop_face'], + 'utils': ['get_f5p', 'padTo16x', 'resize_size'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/cartoon/facelib/LICENSE b/modelscope/models/cv/cartoon/facelib/LICENSE new file mode 100644 index 00000000..8e497ab8 --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/LICENSE @@ -0,0 +1,4 @@ + +Copyright (c) Peppa_Pig_Face_Engine + +https://github.com/610265158/Peppa_Pig_Face_Engine diff --git a/modelscope/models/cv/cartoon/facelib/LK/__init__.py b/modelscope/models/cv/cartoon/facelib/LK/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/cartoon/facelib/LK/lk.py b/modelscope/models/cv/cartoon/facelib/LK/lk.py new file mode 100644 index 00000000..6fd95ad6 --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/LK/lk.py @@ -0,0 +1,99 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + +import numpy as np + +from modelscope.models.cv.cartoon.facelib.config import config as cfg + + +class GroupTrack(): + + def __init__(self): + self.old_frame = None + self.previous_landmarks_set = None + self.with_landmark = True + self.thres = cfg.TRACE.pixel_thres + self.alpha = cfg.TRACE.smooth_landmark + self.iou_thres = cfg.TRACE.iou_thres + + def calculate(self, img, current_landmarks_set): + if self.previous_landmarks_set is None: + self.previous_landmarks_set = current_landmarks_set + result = current_landmarks_set + else: + previous_lm_num = self.previous_landmarks_set.shape[0] + if previous_lm_num == 0: + self.previous_landmarks_set = current_landmarks_set + result = current_landmarks_set + return result + else: + result = [] + for i in range(current_landmarks_set.shape[0]): + not_in_flag = True + for j in range(previous_lm_num): + if self.iou(current_landmarks_set[i], + self.previous_landmarks_set[j] + ) > self.iou_thres: + result.append( + self.smooth(current_landmarks_set[i], + self.previous_landmarks_set[j])) + not_in_flag = False + break + if not_in_flag: + result.append(current_landmarks_set[i]) + + result = np.array(result) + self.previous_landmarks_set = result + + return result + + def iou(self, p_set0, p_set1): + rec1 = [ + np.min(p_set0[:, 0]), + np.min(p_set0[:, 1]), + np.max(p_set0[:, 0]), + np.max(p_set0[:, 1]) + ] + rec2 = [ + np.min(p_set1[:, 0]), + np.min(p_set1[:, 1]), + np.max(p_set1[:, 0]), + np.max(p_set1[:, 1]) + ] + + # computing area of each rectangles + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + x1 = max(rec1[0], rec2[0]) + y1 = max(rec1[1], rec2[1]) + x2 = min(rec1[2], rec2[2]) + y2 = min(rec1[3], rec2[3]) + + # judge if there is an intersect + intersect = max(0, x2 - x1) * max(0, y2 - y1) + + iou = intersect / (sum_area - intersect) + return iou + + def smooth(self, now_landmarks, previous_landmarks): + result = [] + for i in range(now_landmarks.shape[0]): + x = now_landmarks[i][0] - previous_landmarks[i][0] + y = now_landmarks[i][1] - previous_landmarks[i][1] + dis = np.sqrt(np.square(x) + np.square(y)) + if dis < self.thres: + result.append(previous_landmarks[i]) + else: + result.append( + self.do_moving_average(now_landmarks[i], + previous_landmarks[i])) + + return np.array(result) + + def do_moving_average(self, p_now, p_previous): + p = self.alpha * p_now + (1 - self.alpha) * p_previous + return p diff --git a/modelscope/models/cv/cartoon/facelib/__init__.py b/modelscope/models/cv/cartoon/facelib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/cartoon/facelib/config.py b/modelscope/models/cv/cartoon/facelib/config.py new file mode 100644 index 00000000..92b39db0 --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/config.py @@ -0,0 +1,25 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + +import os + +import numpy as np +from easydict import EasyDict as edict + +config = edict() +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +config.DETECT = edict() +config.DETECT.topk = 10 +config.DETECT.thres = 0.8 +config.DETECT.input_shape = (512, 512, 3) +config.KEYPOINTS = edict() +config.KEYPOINTS.p_num = 68 +config.KEYPOINTS.base_extend_range = [0.2, 0.3] +config.KEYPOINTS.input_shape = (160, 160, 3) +config.TRACE = edict() +config.TRACE.pixel_thres = 1 +config.TRACE.smooth_box = 0.3 +config.TRACE.smooth_landmark = 0.95 +config.TRACE.iou_thres = 0.5 +config.DATA = edict() +config.DATA.pixel_means = np.array([123., 116., 103.]) # RGB diff --git a/modelscope/models/cv/cartoon/facelib/face_detector.py b/modelscope/models/cv/cartoon/facelib/face_detector.py new file mode 100644 index 00000000..fa36d662 --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/face_detector.py @@ -0,0 +1,118 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + +import time + +import cv2 +import numpy as np +import tensorflow as tf + +from .config import config as cfg + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class FaceDetector: + + def __init__(self, dir): + + self.model_path = dir + '/detector.pb' + self.thres = cfg.DETECT.thres + self.input_shape = cfg.DETECT.input_shape + + self._graph = tf.Graph() + + with self._graph.as_default(): + self._graph, self._sess = self.init_model(self.model_path) + + self.input_image = tf.get_default_graph().get_tensor_by_name( + 'tower_0/images:0') + self.training = tf.get_default_graph().get_tensor_by_name( + 'training_flag:0') + self.output_ops = [ + tf.get_default_graph().get_tensor_by_name('tower_0/boxes:0'), + tf.get_default_graph().get_tensor_by_name('tower_0/scores:0'), + tf.get_default_graph().get_tensor_by_name( + 'tower_0/num_detections:0'), + ] + + def __call__(self, image): + + image, scale_x, scale_y = self.preprocess( + image, + target_width=self.input_shape[1], + target_height=self.input_shape[0]) + + image = np.expand_dims(image, 0) + + boxes, scores, num_boxes = self._sess.run( + self.output_ops, + feed_dict={ + self.input_image: image, + self.training: False + }) + + num_boxes = num_boxes[0] + boxes = boxes[0][:num_boxes] + + scores = scores[0][:num_boxes] + + to_keep = scores > self.thres + boxes = boxes[to_keep] + scores = scores[to_keep] + + y1 = self.input_shape[0] / scale_y + x1 = self.input_shape[1] / scale_x + y2 = self.input_shape[0] / scale_y + x2 = self.input_shape[1] / scale_x + scaler = np.array([y1, x1, y2, x2], dtype='float32') + boxes = boxes * scaler + + scores = np.expand_dims(scores, 0).reshape([-1, 1]) + + for i in range(boxes.shape[0]): + boxes[i] = np.array( + [boxes[i][1], boxes[i][0], boxes[i][3], boxes[i][2]]) + return np.concatenate([boxes, scores], axis=1) + + def preprocess(self, image, target_height, target_width, label=None): + + h, w, c = image.shape + + bimage = np.zeros( + shape=[target_height, target_width, c], + dtype=image.dtype) + np.array( + cfg.DATA.pixel_means, dtype=image.dtype) + long_side = max(h, w) + + scale_x = scale_y = target_height / long_side + + image = cv2.resize(image, None, fx=scale_x, fy=scale_y) + + h_, w_, _ = image.shape + bimage[:h_, :w_, :] = image + + return bimage, scale_x, scale_y + + def init_model(self, *args): + pb_path = args[0] + + def init_pb(model_path): + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.2 + compute_graph = tf.Graph() + compute_graph.as_default() + sess = tf.Session(config=config) + with tf.gfile.GFile(model_path, 'rb') as fid: + graph_def = tf.GraphDef() + graph_def.ParseFromString(fid.read()) + tf.import_graph_def(graph_def, name='') + + return (compute_graph, sess) + + model = init_pb(pb_path) + + graph = model[0] + sess = model[1] + + return graph, sess diff --git a/modelscope/models/cv/cartoon/facelib/face_landmark.py b/modelscope/models/cv/cartoon/facelib/face_landmark.py new file mode 100644 index 00000000..3b7cc1b9 --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/face_landmark.py @@ -0,0 +1,156 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + +import cv2 +import numpy as np +import tensorflow as tf + +from .config import config as cfg + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class FaceLandmark: + + def __init__(self, dir): + self.model_path = dir + '/keypoints.pb' + self.min_face = 60 + self.keypoint_num = cfg.KEYPOINTS.p_num * 2 + + self._graph = tf.Graph() + + with self._graph.as_default(): + + self._graph, self._sess = self.init_model(self.model_path) + self.img_input = tf.get_default_graph().get_tensor_by_name( + 'tower_0/images:0') + self.embeddings = tf.get_default_graph().get_tensor_by_name( + 'tower_0/prediction:0') + self.training = tf.get_default_graph().get_tensor_by_name( + 'training_flag:0') + + self.landmark = self.embeddings[:, :self.keypoint_num] + self.headpose = self.embeddings[:, -7:-4] * 90. + self.state = tf.nn.sigmoid(self.embeddings[:, -4:]) + + def __call__(self, img, bboxes): + landmark_result = [] + state_result = [] + for i, bbox in enumerate(bboxes): + landmark, state = self._one_shot_run(img, bbox, i) + if landmark is not None: + landmark_result.append(landmark) + state_result.append(state) + return np.array(landmark_result), np.array(state_result) + + def simple_run(self, cropped_img): + with self._graph.as_default(): + + cropped_img = np.expand_dims(cropped_img, axis=0) + landmark, p, states = self._sess.run( + [self.landmark, self.headpose, self.state], + feed_dict={ + self.img_input: cropped_img, + self.training: False + }) + + return landmark, states + + def _one_shot_run(self, image, bbox, i): + + bbox_width = bbox[2] - bbox[0] + bbox_height = bbox[3] - bbox[1] + if (bbox_width <= self.min_face and bbox_height <= self.min_face): + return None, None + add = int(max(bbox_width, bbox_height)) + bimg = cv2.copyMakeBorder( + image, + add, + add, + add, + add, + borderType=cv2.BORDER_CONSTANT, + value=cfg.DATA.pixel_means) + bbox += add + + one_edge = (1 + 2 * cfg.KEYPOINTS.base_extend_range[0]) * bbox_width + center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2] + + bbox[0] = center[0] - one_edge // 2 + bbox[1] = center[1] - one_edge // 2 + bbox[2] = center[0] + one_edge // 2 + bbox[3] = center[1] + one_edge // 2 + + bbox = bbox.astype(np.int) + crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :] + h, w, _ = crop_image.shape + crop_image = cv2.resize( + crop_image, + (cfg.KEYPOINTS.input_shape[1], cfg.KEYPOINTS.input_shape[0])) + crop_image = crop_image.astype(np.float32) + + keypoints, state = self.simple_run(crop_image) + + res = keypoints[0][:self.keypoint_num].reshape((-1, 2)) + res[:, 0] = res[:, 0] * w / cfg.KEYPOINTS.input_shape[1] + res[:, 1] = res[:, 1] * h / cfg.KEYPOINTS.input_shape[0] + + landmark = [] + for _index in range(res.shape[0]): + x_y = res[_index] + landmark.append([ + int(x_y[0] * cfg.KEYPOINTS.input_shape[0] + bbox[0] - add), + int(x_y[1] * cfg.KEYPOINTS.input_shape[1] + bbox[1] - add) + ]) + + landmark = np.array(landmark, np.float32) + + return landmark, state + + def init_model(self, *args): + + if len(args) == 1: + use_pb = True + pb_path = args[0] + else: + use_pb = False + meta_path = args[0] + restore_model_path = args[1] + + def ini_ckpt(): + graph = tf.Graph() + graph.as_default() + configProto = tf.ConfigProto() + configProto.gpu_options.allow_growth = True + sess = tf.Session(config=configProto) + # load_model(model_path, sess) + saver = tf.train.import_meta_graph(meta_path) + saver.restore(sess, restore_model_path) + + print('Model restred!') + return (graph, sess) + + def init_pb(model_path): + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.2 + compute_graph = tf.Graph() + compute_graph.as_default() + sess = tf.Session(config=config) + with tf.gfile.GFile(model_path, 'rb') as fid: + graph_def = tf.GraphDef() + graph_def.ParseFromString(fid.read()) + tf.import_graph_def(graph_def, name='') + + # saver = tf.train.Saver(tf.global_variables()) + # saver.save(sess, save_path='./tmp.ckpt') + return (compute_graph, sess) + + if use_pb: + model = init_pb(pb_path) + else: + model = ini_ckpt() + + graph = model[0] + sess = model[1] + + return graph, sess diff --git a/modelscope/models/cv/cartoon/facelib/facer.py b/modelscope/models/cv/cartoon/facelib/facer.py new file mode 100644 index 00000000..c6f34e9c --- /dev/null +++ b/modelscope/models/cv/cartoon/facelib/facer.py @@ -0,0 +1,152 @@ +# The implementation is adopted from https://github.com/610265158/Peppa_Pig_Face_Engine + +import time + +import cv2 +import numpy as np + +from .config import config as cfg +from .face_detector import FaceDetector +from .face_landmark import FaceLandmark +from .LK.lk import GroupTrack + + +class FaceAna(): + ''' + by default the top3 facea sorted by area will be calculated for time reason + ''' + + def __init__(self, model_dir): + self.face_detector = FaceDetector(model_dir) + self.face_landmark = FaceLandmark(model_dir) + self.trace = GroupTrack() + + self.track_box = None + self.previous_image = None + self.previous_box = None + + self.diff_thres = 5 + self.top_k = cfg.DETECT.topk + self.iou_thres = cfg.TRACE.iou_thres + self.alpha = cfg.TRACE.smooth_box + + def run(self, image): + + boxes = self.face_detector(image) + + if boxes.shape[0] > self.top_k: + boxes = self.sort(boxes) + + boxes_return = np.array(boxes) + landmarks, states = self.face_landmark(image, boxes) + + if 1: + track = [] + for i in range(landmarks.shape[0]): + track.append([ + np.min(landmarks[i][:, 0]), + np.min(landmarks[i][:, 1]), + np.max(landmarks[i][:, 0]), + np.max(landmarks[i][:, 1]) + ]) + tmp_box = np.array(track) + + self.track_box = self.judge_boxs(boxes_return, tmp_box) + + self.track_box, landmarks = self.sort_res(self.track_box, landmarks) + return self.track_box, landmarks, states + + def sort_res(self, bboxes, points): + area = [] + for bbox in bboxes: + bbox_width = bbox[2] - bbox[0] + bbox_height = bbox[3] - bbox[1] + area.append(bbox_height * bbox_width) + + area = np.array(area) + picked = area.argsort()[::-1] + sorted_bboxes = [bboxes[x] for x in picked] + sorted_points = [points[x] for x in picked] + return np.array(sorted_bboxes), np.array(sorted_points) + + def diff_frames(self, previous_frame, image): + if previous_frame is None: + return True + else: + _diff = cv2.absdiff(previous_frame, image) + diff = np.sum( + _diff) / previous_frame.shape[0] / previous_frame.shape[1] / 3. + return diff > self.diff_thres + + def sort(self, bboxes): + if self.top_k > 100: + return bboxes + area = [] + for bbox in bboxes: + + bbox_width = bbox[2] - bbox[0] + bbox_height = bbox[3] - bbox[1] + area.append(bbox_height * bbox_width) + + area = np.array(area) + + picked = area.argsort()[-self.top_k:][::-1] + sorted_bboxes = [bboxes[x] for x in picked] + return np.array(sorted_bboxes) + + def judge_boxs(self, previuous_bboxs, now_bboxs): + + def iou(rec1, rec2): + + # computing area of each rectangles + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + x1 = max(rec1[0], rec2[0]) + y1 = max(rec1[1], rec2[1]) + x2 = min(rec1[2], rec2[2]) + y2 = min(rec1[3], rec2[3]) + + # judge if there is an intersect + intersect = max(0, x2 - x1) * max(0, y2 - y1) + + return intersect / (sum_area - intersect) + + if previuous_bboxs is None: + return now_bboxs + + result = [] + + for i in range(now_bboxs.shape[0]): + contain = False + for j in range(previuous_bboxs.shape[0]): + if iou(now_bboxs[i], previuous_bboxs[j]) > self.iou_thres: + result.append( + self.smooth(now_bboxs[i], previuous_bboxs[j])) + contain = True + break + if not contain: + result.append(now_bboxs[i]) + + return np.array(result) + + def smooth(self, now_box, previous_box): + + return self.do_moving_average(now_box[:4], previous_box[:4]) + + def do_moving_average(self, p_now, p_previous): + p = self.alpha * p_now + (1 - self.alpha) * p_previous + return p + + def reset(self): + ''' + reset the previous info used foe tracking, + :return: + ''' + self.track_box = None + self.previous_image = None + self.previous_box = None diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/LICENSE b/modelscope/models/cv/cartoon/mtcnn_pytorch/LICENSE new file mode 100644 index 00000000..9210f5b8 --- /dev/null +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Dan Antoshchenko + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/README.md b/modelscope/models/cv/cartoon/mtcnn_pytorch/README.md new file mode 100644 index 00000000..b748cf58 --- /dev/null +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/README.md @@ -0,0 +1,26 @@ +# MTCNN + +`pytorch` implementation of **inference stage** of face detection algorithm described in +[Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks](https://arxiv.org/abs/1604.02878). + +## Example +![example of a face detection](images/example.png) + +## How to use it +Just download the repository and then do this +```python +from src import detect_faces +from PIL import Image + +image = Image.open('image.jpg') +bounding_boxes, landmarks = detect_faces(image) +``` +For examples see `test_on_images.ipynb`. + +## Requirements +* pytorch 0.2 +* Pillow, numpy + +## Credit +This implementation is heavily inspired by: +* [pangyupo/mxnet_mtcnn_face_detection](https://github.com/pangyupo/mxnet_mtcnn_face_detection) diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/__init__.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/__init__.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py new file mode 100644 index 00000000..eb542042 --- /dev/null +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/align_trans.py @@ -0,0 +1,185 @@ +# The implementation is adopted from https://github.com/TreB1eN/InsightFace_Pytorch/tree/master/mtcnn_pytorch + +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +dx = 1 +dy = 1 +REFERENCE_FACIAL_POINTS = [ + [30.29459953 + dx, 51.69630051 + dy], # left eye + [65.53179932 + dx, 51.50139999 + dy], # right eye + [48.02519989 + dx, 71.73660278 + dy], # nose + [33.54930115 + dx, 92.3655014 + dy], # left mouth + [62.72990036 + dx, 92.20410156 + dy] # right mouth +] + +DEFAULT_CROP_SIZE = (96, 112) + +global FACIAL_POINTS + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + h_crop = tmp_crop_size[0] + w_crop = tmp_crop_size[1] + if (output_size): + if (output_size[0] == h_crop and output_size[1] == w_crop): + return tmp_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format( + tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + factor = inner_padding_factor > 0 or outer_padding[0] > 0 + factor = factor or outer_padding[1] > 0 + if (factor and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + + cond1 = outer_padding[0] < output_size[0] + cond2 = outer_padding[1] < output_size[1] + if not (cond1 and cond2): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[ + 1] * tmp_crop_size[0]: + raise FaceWarpException( + 'Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, + facial_pts, + ratio=0.84, + reference_pts=None, + crop_size=(96, 112), + align_type='similarity' + '', + return_trans_inv=False): + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points( + output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + + factor = ratio + ref_pts = (ref_pts - 112 / 2) * factor + 112 / 2 + ref_pts *= crop_size[0] / 112. + + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts, ref_pts) + tfm_inv = cv2.getAffineTransform(ref_pts, src_pts) + + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + tfm, tfm_inv = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine( + src_img, + tfm, (crop_size[0], crop_size[1]), + borderValue=(255, 255, 255)) + + if return_trans_inv: + return face_img, tfm_inv + else: + return face_img diff --git a/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py new file mode 100644 index 00000000..ea9fbacf --- /dev/null +++ b/modelscope/models/cv/cartoon/mtcnn_pytorch/src/matlab_cp2tform.py @@ -0,0 +1,335 @@ +# The implementation is adopted from https://github.com/TreB1eN/InsightFace_Pytorch/tree/master/mtcnn_pytorch + +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + # print('--->x, y:\n', x, y + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + # print('--->X.shape: ', X.shape + # print('X:\n', X + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + # print('--->U.shape: ', U.shape + # print('U:\n', U + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + + # print('--->r:\n', r + + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + + # print('--->Tinv:\n', Tinv + + T = inv(Tinv) + # print('--->T:\n', T + + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + cv2_trans_inv = cvt_tform_mat_for_cv2(trans_inv) + + return cv2_trans, cv2_trans_inv + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/modelscope/models/cv/cartoon/utils.py b/modelscope/models/cv/cartoon/utils.py new file mode 100644 index 00000000..59b4e879 --- /dev/null +++ b/modelscope/models/cv/cartoon/utils.py @@ -0,0 +1,93 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os + +import cv2 +import numpy as np + + +def resize_size(image, size=720): + h, w, c = np.shape(image) + if min(h, w) > size: + if h > w: + h, w = int(size * h / w), size + else: + h, w = size, int(size * w / h) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA) + return image + + +def padTo16x(image): + h, w, c = np.shape(image) + if h % 16 == 0 and w % 16 == 0: + return image, h, w + nh, nw = (h // 16 + 1) * 16, (w // 16 + 1) * 16 + img_new = np.ones((nh, nw, 3), np.uint8) * 255 + img_new[:h, :w, :] = image + + return img_new, h, w + + +def get_f5p(landmarks, np_img): + eye_left = find_pupil(landmarks[36:41], np_img) + eye_right = find_pupil(landmarks[42:47], np_img) + if eye_left is None or eye_right is None: + print('cannot find 5 points with find_puil, used mean instead.!') + eye_left = landmarks[36:41].mean(axis=0) + eye_right = landmarks[42:47].mean(axis=0) + nose = landmarks[30] + mouth_left = landmarks[48] + mouth_right = landmarks[54] + f5p = [[eye_left[0], eye_left[1]], [eye_right[0], eye_right[1]], + [nose[0], nose[1]], [mouth_left[0], mouth_left[1]], + [mouth_right[0], mouth_right[1]]] + return f5p + + +def find_pupil(landmarks, np_img): + h, w, _ = np_img.shape + xmax = int(landmarks[:, 0].max()) + xmin = int(landmarks[:, 0].min()) + ymax = int(landmarks[:, 1].max()) + ymin = int(landmarks[:, 1].min()) + + if ymin >= ymax or xmin >= xmax or ymin < 0 or xmin < 0 or ymax > h or xmax > w: + return None + eye_img_bgr = np_img[ymin:ymax, xmin:xmax, :] + eye_img = cv2.cvtColor(eye_img_bgr, cv2.COLOR_BGR2GRAY) + eye_img = cv2.equalizeHist(eye_img) + n_marks = landmarks - np.array([xmin, ymin]).reshape([1, 2]) + eye_mask = cv2.fillConvexPoly( + np.zeros_like(eye_img), n_marks.astype(np.int32), 1) + ret, thresh = cv2.threshold(eye_img, 100, 255, + cv2.THRESH_BINARY | cv2.THRESH_OTSU) + thresh = (1 - thresh / 255.) * eye_mask + cnt = 0 + xm = [] + ym = [] + for i in range(thresh.shape[0]): + for j in range(thresh.shape[1]): + if thresh[i, j] > 0.5: + xm.append(j) + ym.append(i) + cnt += 1 + if cnt != 0: + xm.sort() + ym.sort() + xm = xm[cnt // 2] + ym = ym[cnt // 2] + else: + xm = thresh.shape[1] / 2 + ym = thresh.shape[0] / 2 + + return xm + xmin, ym + ymin + + +def all_file(file_dir): + L = [] + for root, dirs, files in os.walk(file_dir): + for file in files: + extend = os.path.splitext(file)[1] + if extend == '.png' or extend == '.jpg' or extend == '.jpeg': + L.append(os.path.join(root, file)) + return L diff --git a/modelscope/models/cv/cmdssl_video_embedding/__init__.py b/modelscope/models/cv/cmdssl_video_embedding/__init__.py new file mode 100644 index 00000000..5bc67b63 --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .c3d import C3D + from .resnet2p1d import resnet26_2p1d + from .resnet3d import resnet26_3d + +else: + _import_structure = { + 'c3d': ['C3D'], + 'resnet2p1d': ['resnet26_2p1d'], + 'resnet3d': ['resnet26_3d'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/cmdssl_video_embedding/c3d.py b/modelscope/models/cv/cmdssl_video_embedding/c3d.py new file mode 100644 index 00000000..53dd05a1 --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/c3d.py @@ -0,0 +1,129 @@ +# Copyright 2022 Davide Abati. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +# The implementation here is modified based on c3d-pytorch, +# originally MIT License, Copyright (c) 2022 Davide Abati, +# and publicly available at https://github.com/DavideA/c3d-pytorch +""" C3D Model Architecture.""" + +import torch +import torch.nn as nn + + +def conv3x3x3(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1) + + +class C3D(nn.Module): + + def __init__(self, + num_classes=1000, + dropout=0.5, + inplanes=3, + norm_layer=None, + last_pool=True): + super(C3D, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + + self.conv1 = conv3x3x3(inplanes, 64) + self.bn1 = norm_layer(64) + self.relu1 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) + + self.conv2 = conv3x3x3(64, 128) + self.bn2 = norm_layer(128) + self.relu2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv3a = conv3x3x3(128, 256) + self.bn3a = norm_layer(256) + self.relu3a = nn.ReLU(inplace=True) + + self.conv3b = conv3x3x3(256, 256) + self.bn3b = norm_layer(256) + self.relu3b = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv4a = conv3x3x3(256, 512) + self.bn4a = norm_layer(512) + self.relu4a = nn.ReLU(inplace=True) + + self.conv4b = conv3x3x3(512, 512) + self.bn4b = norm_layer(512) + self.relu4b = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) + + self.conv5a = conv3x3x3(512, 512) + self.bn5a = norm_layer(512) + self.relu5a = nn.ReLU(inplace=True) + + self.conv5b = conv3x3x3(512, 512) + self.bn5b = norm_layer(512) + self.relu5b = nn.ReLU(inplace=True) + self.pool5 = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512, num_classes) + self.out_planes = 512 + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.pool1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + x = self.pool2(x) + + x = self.conv3a(x) + x = self.bn3a(x) + x = self.relu3a(x) + + x = self.conv3b(x) + x = self.bn3b(x) + x = self.relu3b(x) + x = self.pool3(x) + + x = self.conv4a(x) + x = self.bn4a(x) + x = self.relu4a(x) + + x = self.conv4b(x) + x = self.bn4b(x) + x = self.relu4b(x) + x = self.pool4(x) + + x = self.conv5a(x) + x = self.bn5a(x) + x = self.relu5a(x) + + x = self.conv5b(x) + x = self.bn5b(x) + x = self.relu5b(x) + + if self.pool5: + x = self.pool5(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x diff --git a/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py b/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py new file mode 100644 index 00000000..b49069d1 --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/resnet2p1d.py @@ -0,0 +1,347 @@ +# Copyright (c) 2022 Kensho Hara. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +# The implementation here is modified based on 3D-ResNets-PyTorch, +# originally MIT License, Copyright (c) 2022 Kensho Hara, +# and publicly available at https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/models/resnet2p1d.py +""" ResNet2plus1d Model Architecture.""" + +import torch +import torch.nn as nn + + +def conv1x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, dilation, dilation), + groups=groups, + bias=False, + dilation=(1, dilation, dilation)) + + +def conv3x1x1(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(3, 1, 1), + stride=(stride, 1, 1), + padding=(dilation, 0, 0), + groups=groups, + bias=False, + dilation=(dilation, 1, 1)) + + +def conv1x1x1(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + + midplanes1 = (inplanes * planes * 3 * 3 * 3) // ( + inplanes * 3 * 3 + planes * 3) + self.conv1_s = conv1x3x3(inplanes, midplanes1, stride) + self.bn1_s = norm_layer(midplanes1) + self.conv1_t = conv3x1x1(midplanes1, planes, stride) + self.bn1_t = norm_layer(planes) + + midplanes2 = (planes * planes * 3 * 3 * 3) // ( + planes * 3 * 3 + planes * 3) + self.conv2_s = conv1x3x3(planes, midplanes2) + self.bn2_s = norm_layer(midplanes2) + self.conv2_t = conv3x1x1(midplanes2, planes) + self.bn2_t = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1_s(x) + out = self.bn1_s(out) + out = self.relu(out) + out = self.conv1_t(out) + out = self.bn1_t(out) + out = self.relu(out) + + out = self.conv2_s(out) + out = self.bn2_s(out) + out = self.relu(out) + out = self.conv2_t(out) + out = self.bn2_t(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = conv1x1x1(inplanes, width) + self.bn1 = norm_layer(width) + + midplanes = (width * width * 3 * 3 * 3) // (width * 3 * 3 + width * 3) + self.conv2_s = conv1x3x3(width, midplanes, stride, groups, dilation) + self.bn2_s = norm_layer(midplanes) + self.conv2_t = conv3x1x1(midplanes, width, stride, groups, dilation) + self.bn2_t = norm_layer(width) + + self.conv3 = conv1x1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2_s(out) + out = self.bn2_s(out) + out = self.relu(out) + out = self.conv2_t(out) + out = self.bn2_t(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet2p1d(nn.Module): + + def __init__(self, + block, + layers, + num_classes=None, + zero_init_residual=True, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + dropout=0.5, + inplanes=3, + first_stride=2, + norm_layer=None, + last_pool=True): + super(ResNet2p1d, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + self._norm_layer = norm_layer + self.first_stride = first_stride + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + midplanes = (3 * self.inplanes * 3 * 7 * 7) // (3 * 7 * 7 + + self.inplanes * 3) + self.conv1_s = nn.Conv3d( + inplanes, + midplanes, + kernel_size=(1, 7, 7), + stride=(1, first_stride, first_stride), + padding=(0, 3, 3), + bias=False) + self.bn1_s = norm_layer(midplanes) + self.conv1_t = nn.Conv3d( + midplanes, + self.inplanes, + kernel_size=(3, 1, 1), + stride=(1, 1, 1), + padding=(1, 0, 0), + bias=False) + self.bn1_t = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_planes = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2_t.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion)) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1_s(x) + x = self.bn1_s(x) + x = self.relu(x) + x = self.conv1_t(x) + x = self.bn1_t(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.avgpool: + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x + + +def resnet10_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [1, 1, 1, 1], **kwargs) + + +def resnet18_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet26_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [2, 2, 2, 2], **kwargs) + + +def resnet34_2p1d(**kwargs): + return ResNet2p1d(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet50_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet101_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet152_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnet200_2p1d(**kwargs): + return ResNet2p1d(Bottleneck, [3, 24, 36, 3], **kwargs) diff --git a/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py b/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py new file mode 100644 index 00000000..dddba06f --- /dev/null +++ b/modelscope/models/cv/cmdssl_video_embedding/resnet3d.py @@ -0,0 +1,292 @@ +# Copyright (c) 2022 Kensho Hara. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +# The implementation here is modified based on 3D-ResNets-PyTorch, +# originally MIT License, Copyright (c) 2022 Kensho Hara, +# and publicly available at https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/models/resnet.py +""" ResNet3D Model Architecture.""" + +import torch +import torch.nn as nn + + +def conv3x3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1x1(in_planes, out_planes, stride=1): + return nn.Conv3d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + self.conv1 = conv3x3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + width = int(planes * (base_width / 64.)) * groups + self.conv1 = conv1x1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet3d(nn.Module): + + def __init__(self, + block, + layers, + num_classes=1000, + zero_init_residual=True, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + dropout=0.5, + inplanes=3, + first_stride=2, + norm_layer=None, + last_pool=True): + super(ResNet3d, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if not last_pool and num_classes is not None: + raise ValueError('num_classes should be None when last_pool=False') + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv3d( + inplanes, + self.inplanes, + kernel_size=(3, 7, 7), + stride=(1, first_stride, first_stride), + padding=(1, 3, 3), + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) if last_pool else None + if num_classes is None: + self.dropout = None + self.fc = None + else: + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_planes = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion)) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.avgpool: + x = self.avgpool(x) + x = torch.flatten(x, 1) + if self.dropout and self.fc: + x = self.dropout(x) + x = self.fc(x) + + return x + + +def resnet10_3d(**kwargs): + return ResNet3d(BasicBlock, [1, 1, 1, 1], **kwargs) + + +def resnet18_3d(**kwargs): + return ResNet3d(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet26_3d(**kwargs): + return ResNet3d(Bottleneck, [2, 2, 2, 2], **kwargs) + + +def resnet34_3d(**kwargs): + return ResNet3d(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet50_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet101_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet152_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnet200_3d(**kwargs): + return ResNet3d(Bottleneck, [3, 24, 36, 3], **kwargs) diff --git a/modelscope/models/cv/crowd_counting/__init__.py b/modelscope/models/cv/crowd_counting/__init__.py new file mode 100644 index 00000000..b5eeb937 --- /dev/null +++ b/modelscope/models/cv/crowd_counting/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .cc_model import HRNetCrowdCounting + +else: + _import_structure = { + 'cc_model': ['HRNetCrowdCounting'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py new file mode 100644 index 00000000..16fbc261 --- /dev/null +++ b/modelscope/models/cv/crowd_counting/cc_model.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.crowd_counting, module_name=Models.crowd_counting) +class HRNetCrowdCounting(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) + + from .hrnet_aspp_relu import HighResolutionNet as HRNet_aspp_relu + + domain_center_model = os.path.join( + model_dir, 'average_clip_domain_center_54.97.npz') + net = HRNet_aspp_relu( + attn_weight=1.0, + fix_domain=0, + domain_center_model=domain_center_model) + net.load_state_dict( + torch.load( + os.path.join(model_dir, 'DCANet_final.pth'), + map_location='cpu')) + self.model = net + + def forward(self, inputs): + return self.model(inputs) diff --git a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py new file mode 100644 index 00000000..0d1bd3ca --- /dev/null +++ b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py @@ -0,0 +1,638 @@ +""" +Copyright (c) Microsoft +Licensed under the MIT License. +Written by Bin Xiao (Bin.Xiao@microsoft.com) +Modified by Ke Sun (sunk@mail.ustc.edu.cn) +https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py +""" + +import functools +import logging +import os + +import numpy as np +import torch +import torch._utils +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger + +BN_MOMENTUM = 0.01 # 0.01 for seg +logger = get_logger() + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + + def __init__(self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, + num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.info(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.info(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.info(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d( + num_inchannels[i], momentum=BN_MOMENTUM), + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d( + num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d( + num_outchannels_conv3x3, + momentum=BN_MOMENTUM), nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + +class HighResolutionNet(nn.Module): + + def __init__(self, + leaky_relu=False, + attn_weight=1, + fix_domain=1, + domain_center_model='', + **kwargs): + super(HighResolutionNet, self).__init__() + + self.criterion_attn = torch.nn.MSELoss(reduction='sum') + self.domain_center_model = domain_center_model + self.attn_weight = attn_weight + self.fix_domain = fix_domain + self.cosine = 1 + + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + 64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + num_channels = 64 + block = blocks_dict['BOTTLENECK'] + num_blocks = 4 + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + # -- stage 2 + self.stage2_cfg = {} + self.stage2_cfg['NUM_MODULES'] = 1 + self.stage2_cfg['NUM_BRANCHES'] = 2 + self.stage2_cfg['BLOCK'] = 'BASIC' + self.stage2_cfg['NUM_BLOCKS'] = [4, 4] + self.stage2_cfg['NUM_CHANNELS'] = [40, 80] + self.stage2_cfg['FUSE_METHOD'] = 'SUM' + + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([stage1_out_channel], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # -- stage 3 + self.stage3_cfg = {} + self.stage3_cfg['NUM_MODULES'] = 4 + self.stage3_cfg['NUM_BRANCHES'] = 3 + self.stage3_cfg['BLOCK'] = 'BASIC' + self.stage3_cfg['NUM_BLOCKS'] = [4, 4, 4] + self.stage3_cfg['NUM_CHANNELS'] = [40, 80, 160] + self.stage3_cfg['FUSE_METHOD'] = 'SUM' + + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + last_inp_channels = np.int(np.sum(pre_stage_channels)) + 256 + self.redc_layer = nn.Sequential( + nn.Conv2d( + in_channels=last_inp_channels, + out_channels=128, + kernel_size=3, + stride=1, + padding=1), + nn.BatchNorm2d(128, momentum=BN_MOMENTUM), + nn.ReLU(True), + ) + + self.aspp = nn.ModuleList(aspp(in_channel=128)) + + # additional layers specfic for Phase 3 + self.pred_conv = nn.Conv2d(128, 512, 3, padding=1) + self.pred_bn = nn.BatchNorm2d(512) + self.GAP = nn.AdaptiveAvgPool2d(1) + + # Specially for hidden domain + # Set the domain for learnable parameters + domain_center_src = np.load(self.domain_center_model) + G_SHA = torch.from_numpy(domain_center_src['G_SHA']).view(1, -1, 1, 1) + G_SHB = torch.from_numpy(domain_center_src['G_SHB']).view(1, -1, 1, 1) + G_QNRF = torch.from_numpy(domain_center_src['G_QNRF']).view( + 1, -1, 1, 1) + + self.n_domain = 3 + + self.G_all = torch.cat( + [G_SHA.clone(), G_SHB.clone(), + G_QNRF.clone()], dim=0) + + self.G_all = nn.Parameter(self.G_all) + + self.last_layer = nn.Sequential( + nn.Conv2d( + in_channels=128, + out_channels=64, + kernel_size=3, + stride=1, + padding=1), + nn.BatchNorm2d(64, momentum=BN_MOMENTUM), + nn.ReLU(True), + nn.Conv2d( + in_channels=64, + out_channels=32, + kernel_size=3, + stride=1, + padding=1), + nn.BatchNorm2d(32, momentum=BN_MOMENTUM), + nn.ReLU(True), + nn.Conv2d( + in_channels=32, + out_channels=1, + kernel_size=1, + stride=1, + padding=0), + ) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], + momentum=BN_MOMENTUM), nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, + layer_config, + num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + x_head_1 = x + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + + x = self.stage3(x_list) + + # Replace the classification heaeder with custom setting + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x = torch.cat([x[0], x1, x2, x_head_1], 1) + # first, reduce the channel down + x = self.redc_layer(x) + + pred_attn = self.GAP(F.relu_(self.pred_bn(self.pred_conv(x)))) + pred_attn = F.softmax(pred_attn, dim=1) + pred_attn_list = torch.chunk(pred_attn, 4, dim=1) + + aspp_out = [] + for k, v in enumerate(self.aspp): + if k % 2 == 0: + aspp_out.append(self.aspp[k + 1](v(x))) + else: + continue + # Using Aspp add, and relu inside + for i in range(4): + x = x + F.relu_(aspp_out[i] * 0.25) * pred_attn_list[i] + + bz = x.size(0) + # -- Besides, we also need to let the prediction attention be close to visable domain + # -- Calculate the domain distance and get the weights + # - First, detach domains + G_all_d = self.G_all.detach() # use detached G_all for calulcating + pred_attn_d = pred_attn.detach().view(bz, 512, 1, 1) + + if self.cosine == 1: + G_A, G_B, G_Q = torch.chunk(G_all_d, self.n_domain, dim=0) + + cos_dis_A = F.cosine_similarity(pred_attn_d, G_A, dim=1).view(-1) + cos_dis_B = F.cosine_similarity(pred_attn_d, G_B, dim=1).view(-1) + cos_dis_Q = F.cosine_similarity(pred_attn_d, G_Q, dim=1).view(-1) + + cos_dis_all = torch.stack([cos_dis_A, cos_dis_B, + cos_dis_Q]).view(bz, -1) # bz*3 + + cos_dis_all = F.softmax(cos_dis_all, dim=1) + + target_attn = cos_dis_all.view(bz, self.n_domain, 1, 1, 1).expand( + bz, self.n_domain, 512, 1, 1) * self.G_all.view( + 1, self.n_domain, 512, 1, 1).expand( + bz, self.n_domain, 512, 1, 1) + target_attn = torch.sum( + target_attn, dim=1, keepdim=False) # bz * 512 * 1 * 1 + + if self.fix_domain: + target_attn = target_attn.detach() + + else: + raise ValueError('Have not implemented not cosine distance yet') + + x = self.last_layer(x) + x = F.relu_(x) + + x = F.interpolate( + x, size=(x0_h * 2, x0_w * 2), mode='bilinear', align_corners=False) + + return x, pred_attn, target_attn + + def init_weights( + self, + pretrained='', + ): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + logger.info(f'=> loading pretrained model {pretrained}') + model_dict = self.state_dict() + pretrained_dict = { + k: v + for k, v in pretrained_dict.items() if k in model_dict.keys() + } + for k, _ in pretrained_dict.items(): + logger.info(f'=> loading {k} pretrained model {pretrained}') + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + else: + assert 1 == 2 + + +def aspp(aspp_num=4, aspp_stride=2, in_channel=512, use_bn=True): + aspp_list = [] + for i in range(aspp_num): + pad = (i + 1) * aspp_stride + dilate = pad + conv_aspp = nn.Conv2d( + in_channel, in_channel, 3, padding=pad, dilation=dilate) + aspp_list.append(conv_aspp) + if use_bn: + aspp_list.append(nn.BatchNorm2d(in_channel)) + + return aspp_list diff --git a/modelscope/models/cv/easycv_base.py b/modelscope/models/cv/easycv_base.py new file mode 100644 index 00000000..7bc35e84 --- /dev/null +++ b/modelscope/models/cv/easycv_base.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.base import BaseModel +from easycv.utils.ms_utils import EasyCVMeta + +from modelscope.models.base import TorchModel + + +class EasyCVBaseModel(BaseModel, TorchModel): + """Base model for EasyCV.""" + + def __init__(self, model_dir=None, args=(), kwargs={}): + kwargs.pop(EasyCVMeta.ARCH, None) # pop useless keys + BaseModel.__init__(self) + TorchModel.__init__(self, model_dir=model_dir) + + def forward(self, img, mode='train', **kwargs): + if self.training: + losses = self.forward_train(img, **kwargs) + loss, log_vars = self._parse_losses(losses) + return dict(loss=loss, log_vars=log_vars) + else: + return self.forward_test(img, **kwargs) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/modelscope/models/cv/face_2d_keypoints/__init__.py b/modelscope/models/cv/face_2d_keypoints/__init__.py new file mode 100644 index 00000000..636ba0f4 --- /dev/null +++ b/modelscope/models/cv/face_2d_keypoints/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .face_2d_keypoints_align import Face2DKeypoints + +else: + _import_structure = {'face_2d_keypoints_align': ['Face2DKeypoints']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_2d_keypoints/face_2d_keypoints_align.py b/modelscope/models/cv/face_2d_keypoints/face_2d_keypoints_align.py new file mode 100644 index 00000000..468662a0 --- /dev/null +++ b/modelscope/models/cv/face_2d_keypoints/face_2d_keypoints_align.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.face.face_keypoint import FaceKeypoint + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.face_2d_keypoints, module_name=Models.face_2d_keypoints) +class Face2DKeypoints(EasyCVBaseModel, FaceKeypoint): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + FaceKeypoint.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/face_detection/__init__.py b/modelscope/models/cv/face_detection/__init__.py new file mode 100644 index 00000000..27d1bd4c --- /dev/null +++ b/modelscope/models/cv/face_detection/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .mogface import MogFaceDetector + from .mtcnn import MtcnnFaceDetector + from .retinaface import RetinaFaceDetection + from .ulfd_slim import UlfdFaceDetector + from .scrfd import ScrfdDetect +else: + _import_structure = { + 'ulfd_slim': ['UlfdFaceDetector'], + 'retinaface': ['RetinaFaceDetection'], + 'mtcnn': ['MtcnnFaceDetector'], + 'mogface': ['MogFaceDetector'], + 'scrfd': ['ScrfdDetect'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_detection/mogface/__init__.py b/modelscope/models/cv/face_detection/mogface/__init__.py new file mode 100644 index 00000000..a58268d0 --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .models.detectors import MogFaceDetector diff --git a/modelscope/models/cv/face_detection/mogface/models/__init__.py b/modelscope/models/cv/face_detection/mogface/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/mogface/models/detectors.py b/modelscope/models/cv/face_detection/mogface/models/detectors.py new file mode 100644 index 00000000..8c1d9150 --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/models/detectors.py @@ -0,0 +1,98 @@ +# The implementation is based on MogFace, available at +# https://github.com/damo-cv/MogFace +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .mogface import MogFace +from .utils import MogPriorBox, mogdecode, py_cpu_nms + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.mogface) +class MogFaceDetector(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + self.net = MogFace() + self.load_model() + self.net = self.net.to(device) + + self.mean = np.array([[104, 117, 123]]) + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def forward(self, input): + img_raw = input['img'] + img = np.array(img_raw.cpu().detach()) + img = img[:, :, ::-1] + + im_height, im_width = img.shape[:2] + ss = 1.0 + # tricky + if max(im_height, im_width) > 1500: + ss = 1000.0 / max(im_height, im_width) + img = cv2.resize(img, (0, 0), fx=ss, fy=ss) + im_height, im_width = img.shape[:2] + + scale = torch.Tensor( + [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= np.array([[103.53, 116.28, 123.675]]) + img /= np.array([[57.375, 57.120003, 58.395]]) + img /= 255 + img = img[:, :, ::-1].copy() + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + scale = scale.to(self.device) + + conf, loc = self.net(img) # forward pass + + confidence_threshold = 0.82 + nms_threshold = 0.4 + top_k = 5000 + keep_top_k = 750 + + priorbox = MogPriorBox(scale_list=[0.68]) + priors = priorbox(im_height, im_width) + priors = torch.tensor(priors).to(self.device) + prior_data = priors.data + + boxes = mogdecode(loc.data.squeeze(0), prior_data) + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 0] + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + + return dets / ss diff --git a/modelscope/models/cv/face_detection/mogface/models/mogface.py b/modelscope/models/cv/face_detection/mogface/models/mogface.py new file mode 100644 index 00000000..294c2c6b --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/models/mogface.py @@ -0,0 +1,135 @@ +# -------------------------------------------------------- +# The implementation is also open-sourced by the authors as Yang Liu, and is available publicly on +# https://github.com/damo-cv/MogFace +# -------------------------------------------------------- +import torch.nn as nn +import torch.nn.functional as F + +from .mogprednet import MogPredNet +from .resnet import ResNet + + +class MogFace(nn.Module): + + def __init__(self): + super(MogFace, self).__init__() + self.backbone = ResNet(depth=101) + self.fpn = LFPN() + self.pred_net = MogPredNet() + + def forward(self, x): + feature_list = self.backbone(x) + fpn_list = self.fpn(feature_list) + pyramid_feature_list = fpn_list[0] + conf, loc = self.pred_net(pyramid_feature_list) + return conf, loc + + +class FeatureFusion(nn.Module): + + def __init__(self, lat_ch=256, **channels): + super(FeatureFusion, self).__init__() + self.main_conv = nn.Conv2d(channels['main'], lat_ch, kernel_size=1) + + def forward(self, up, main): + main = self.main_conv(main) + _, _, H, W = main.size() + res = F.upsample(up, scale_factor=2, mode='bilinear') + if res.size(2) != main.size(2) or res.size(3) != main.size(3): + res = res[:, :, 0:H, 0:W] + res = res + main + return res + + +class LFPN(nn.Module): + + def __init__(self, + c2_out_ch=256, + c3_out_ch=512, + c4_out_ch=1024, + c5_out_ch=2048, + c6_mid_ch=512, + c6_out_ch=512, + c7_mid_ch=128, + c7_out_ch=256, + out_dsfd_ft=True): + super(LFPN, self).__init__() + self.out_dsfd_ft = out_dsfd_ft + if self.out_dsfd_ft: + dsfd_module = [] + dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) + dsfd_module.append(nn.Conv2d(512, 256, kernel_size=3, padding=1)) + dsfd_module.append(nn.Conv2d(1024, 256, kernel_size=3, padding=1)) + dsfd_module.append(nn.Conv2d(2048, 256, kernel_size=3, padding=1)) + dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) + dsfd_module.append(nn.Conv2d(256, 256, kernel_size=3, padding=1)) + self.dsfd_modules = nn.ModuleList(dsfd_module) + + c6_input_ch = c5_out_ch + self.c6 = nn.Sequential(*[ + nn.Conv2d( + c6_input_ch, + c6_mid_ch, + kernel_size=1, + ), + nn.BatchNorm2d(c6_mid_ch), + nn.ReLU(inplace=True), + nn.Conv2d( + c6_mid_ch, c6_out_ch, kernel_size=3, padding=1, stride=2), + nn.BatchNorm2d(c6_out_ch), + nn.ReLU(inplace=True) + ]) + self.c7 = nn.Sequential(*[ + nn.Conv2d( + c6_out_ch, + c7_mid_ch, + kernel_size=1, + ), + nn.BatchNorm2d(c7_mid_ch), + nn.ReLU(inplace=True), + nn.Conv2d( + c7_mid_ch, c7_out_ch, kernel_size=3, padding=1, stride=2), + nn.BatchNorm2d(c7_out_ch), + nn.ReLU(inplace=True) + ]) + + self.p2_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.p3_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.p4_lat = nn.Conv2d(256, 256, kernel_size=3, padding=1) + + self.c5_lat = nn.Conv2d(c6_input_ch, 256, kernel_size=3, padding=1) + self.c6_lat = nn.Conv2d(c6_out_ch, 256, kernel_size=3, padding=1) + self.c7_lat = nn.Conv2d(c7_out_ch, 256, kernel_size=3, padding=1) + + self.ff_c5_c4 = FeatureFusion(main=c4_out_ch) + self.ff_c4_c3 = FeatureFusion(main=c3_out_ch) + self.ff_c3_c2 = FeatureFusion(main=c2_out_ch) + + def forward(self, feature_list): + c2, c3, c4, c5 = feature_list + c6 = self.c6(c5) + c7 = self.c7(c6) + + c5 = self.c5_lat(c5) + c6 = self.c6_lat(c6) + c7 = self.c7_lat(c7) + + if self.out_dsfd_ft: + dsfd_fts = [] + dsfd_fts.append(self.dsfd_modules[0](c2)) + dsfd_fts.append(self.dsfd_modules[1](c3)) + dsfd_fts.append(self.dsfd_modules[2](c4)) + dsfd_fts.append(self.dsfd_modules[3](feature_list[-1])) + dsfd_fts.append(self.dsfd_modules[4](c6)) + dsfd_fts.append(self.dsfd_modules[5](c7)) + + p4 = self.ff_c5_c4(c5, c4) + p3 = self.ff_c4_c3(p4, c3) + p2 = self.ff_c3_c2(p3, c2) + + p2 = self.p2_lat(p2) + p3 = self.p3_lat(p3) + p4 = self.p4_lat(p4) + + if self.out_dsfd_ft: + return ([p2, p3, p4, c5, c6, c7], dsfd_fts) diff --git a/modelscope/models/cv/face_detection/mogface/models/mogprednet.py b/modelscope/models/cv/face_detection/mogface/models/mogprednet.py new file mode 100644 index 00000000..31384976 --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/models/mogprednet.py @@ -0,0 +1,164 @@ +# -------------------------------------------------------- +# The implementation is also open-sourced by the authors as Yang Liu, and is available publicly on +# https://github.com/damo-cv/MogFace +# -------------------------------------------------------- +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class conv_bn(nn.Module): + """docstring for conv""" + + def __init__(self, in_plane, out_plane, kernel_size, stride, padding): + super(conv_bn, self).__init__() + self.conv1 = nn.Conv2d( + in_plane, + out_plane, + kernel_size=kernel_size, + stride=stride, + padding=padding) + self.bn1 = nn.BatchNorm2d(out_plane) + + def forward(self, x): + x = self.conv1(x) + return self.bn1(x) + + +class SSHContext(nn.Module): + + def __init__(self, channels, Xchannels=256): + super(SSHContext, self).__init__() + + self.conv1 = nn.Conv2d( + channels, Xchannels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d( + channels, + Xchannels // 2, + kernel_size=3, + dilation=2, + stride=1, + padding=2) + self.conv2_1 = nn.Conv2d( + Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d( + Xchannels // 2, + Xchannels // 2, + kernel_size=3, + dilation=2, + stride=1, + padding=2) + self.conv2_2_1 = nn.Conv2d( + Xchannels // 2, Xchannels // 2, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x1 = F.relu(self.conv1(x), inplace=True) + x2 = F.relu(self.conv2(x), inplace=True) + x2_1 = F.relu(self.conv2_1(x2), inplace=True) + x2_2 = F.relu(self.conv2_2(x2), inplace=True) + x2_2 = F.relu(self.conv2_2_1(x2_2), inplace=True) + + return torch.cat([x1, x2_1, x2_2], 1) + + +class DeepHead(nn.Module): + + def __init__(self, + in_channel=256, + out_channel=256, + use_gn=False, + num_conv=4): + super(DeepHead, self).__init__() + self.use_gn = use_gn + self.num_conv = num_conv + self.conv1 = nn.Conv2d(in_channel, out_channel, 3, 1, 1) + self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) + self.conv3 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) + self.conv4 = nn.Conv2d(out_channel, out_channel, 3, 1, 1) + if self.use_gn: + self.gn1 = nn.GroupNorm(16, out_channel) + self.gn2 = nn.GroupNorm(16, out_channel) + self.gn3 = nn.GroupNorm(16, out_channel) + self.gn4 = nn.GroupNorm(16, out_channel) + + def forward(self, x): + if self.use_gn: + x1 = F.relu(self.gn1(self.conv1(x)), inplace=True) + x2 = F.relu(self.gn2(self.conv1(x1)), inplace=True) + x3 = F.relu(self.gn3(self.conv1(x2)), inplace=True) + x4 = F.relu(self.gn4(self.conv1(x3)), inplace=True) + else: + x1 = F.relu(self.conv1(x), inplace=True) + x2 = F.relu(self.conv1(x1), inplace=True) + if self.num_conv == 2: + return x2 + x3 = F.relu(self.conv1(x2), inplace=True) + x4 = F.relu(self.conv1(x3), inplace=True) + + return x4 + + +class MogPredNet(nn.Module): + + def __init__(self, + num_anchor_per_pixel=1, + num_classes=1, + input_ch_list=[256, 256, 256, 256, 256, 256], + use_deep_head=True, + deep_head_with_gn=True, + use_ssh=True, + deep_head_ch=512): + super(MogPredNet, self).__init__() + self.num_classes = num_classes + self.use_deep_head = use_deep_head + self.deep_head_with_gn = deep_head_with_gn + + self.use_ssh = use_ssh + + self.deep_head_ch = deep_head_ch + + if self.use_ssh: + self.conv_SSH = SSHContext(input_ch_list[0], + self.deep_head_ch // 2) + + if self.use_deep_head: + if self.deep_head_with_gn: + self.deep_loc_head = DeepHead( + self.deep_head_ch, self.deep_head_ch, use_gn=True) + self.deep_cls_head = DeepHead( + self.deep_head_ch, self.deep_head_ch, use_gn=True) + + self.pred_cls = nn.Conv2d(self.deep_head_ch, + 1 * num_anchor_per_pixel, 3, 1, 1) + self.pred_loc = nn.Conv2d(self.deep_head_ch, + 4 * num_anchor_per_pixel, 3, 1, 1) + + self.sigmoid = nn.Sigmoid() + + def forward(self, pyramid_feature_list, dsfd_ft_list=None): + loc = [] + conf = [] + + if self.use_deep_head: + for x in pyramid_feature_list: + if self.use_ssh: + x = self.conv_SSH(x) + x_cls = self.deep_cls_head(x) + x_loc = self.deep_loc_head(x) + + conf.append( + self.pred_cls(x_cls).permute(0, 2, 3, 1).contiguous()) + loc.append( + self.pred_loc(x_loc).permute(0, 2, 3, 1).contiguous()) + + loc = torch.cat([o.view(o.size(0), -1, 4) for o in loc], 1) + conf = torch.cat( + [o.view(o.size(0), -1, self.num_classes) for o in conf], 1) + output = ( + self.sigmoid(conf.view(conf.size(0), -1, self.num_classes)), + loc.view(loc.size(0), -1, 4), + ) + + return output diff --git a/modelscope/models/cv/face_detection/mogface/models/resnet.py b/modelscope/models/cv/face_detection/mogface/models/resnet.py new file mode 100644 index 00000000..045f6fa3 --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/models/resnet.py @@ -0,0 +1,193 @@ +# The implementation is modified from original resent implementaiton, which is +# also open-sourced by the authors as Yang Liu, +# and is available publicly on https://github.com/damo-cv/MogFace + +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + depth=50, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + inplanes=64, + shrink_ch_ratio=1): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + if depth == 50: + block = Bottleneck + layers = [3, 4, 6, 3] + elif depth == 101: + block = Bottleneck + layers = [3, 4, 23, 3] + elif depth == 152: + block = Bottleneck + layers = [3, 4, 36, 3] + elif depth == 18: + block = BasicBlock + layers = [2, 2, 2, 2] + else: + raise ValueError('only support depth in [18, 50, 101, 152]') + + shrink_input_ch = int(inplanes * shrink_ch_ratio) + self.inplanes = int(inplanes * shrink_ch_ratio) + if shrink_ch_ratio == 0.125: + layers = [2, 3, 3, 3] + + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, shrink_input_ch, layers[0]) + self.layer2 = self._make_layer( + block, + shrink_input_ch * 2, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + shrink_input_ch * 4, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + shrink_input_ch * 8, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + four_conv_layer = [] + x = self.layer1(x) + four_conv_layer.append(x) + x = self.layer2(x) + four_conv_layer.append(x) + x = self.layer3(x) + four_conv_layer.append(x) + x = self.layer4(x) + four_conv_layer.append(x) + + return four_conv_layer diff --git a/modelscope/models/cv/face_detection/mogface/models/utils.py b/modelscope/models/cv/face_detection/mogface/models/utils.py new file mode 100755 index 00000000..377ceb3d --- /dev/null +++ b/modelscope/models/cv/face_detection/mogface/models/utils.py @@ -0,0 +1,212 @@ +# Modified from https://github.com/biubug6/Pytorch_Retinaface + +import math +from itertools import product as product +from math import ceil + +import numpy as np +import torch + + +def transform_anchor(anchors): + """ + from [x0, x1, y0, y1] to [c_x, cy, w, h] + x1 = x0 + w - 1 + c_x = (x0 + x1) / 2 = (2x0 + w - 1) / 2 = x0 + (w - 1) / 2 + """ + return np.concatenate(((anchors[:, :2] + anchors[:, 2:]) / 2, + anchors[:, 2:] - anchors[:, :2] + 1), + axis=1) + + +def normalize_anchor(anchors): + """ + from [c_x, cy, w, h] to [x0, x1, y0, y1] + """ + item_1 = anchors[:, :2] - (anchors[:, 2:] - 1) / 2 + item_2 = anchors[:, :2] + (anchors[:, 2:] - 1) / 2 + return np.concatenate((item_1, item_2), axis=1) + + +class MogPriorBox(object): + """ + both for fpn and single layer, single layer need to test + return (np.array) [num_anchros, 4] [x0, y0, x1, y1] + """ + + def __init__(self, + scale_list=[1.], + aspect_ratio_list=[1.0], + stride_list=[4, 8, 16, 32, 64, 128], + anchor_size_list=[16, 32, 64, 128, 256, 512]): + self.scale_list = scale_list + self.aspect_ratio_list = aspect_ratio_list + self.stride_list = stride_list + self.anchor_size_list = anchor_size_list + + def __call__(self, img_height, img_width): + final_anchor_list = [] + + for idx, stride in enumerate(self.stride_list): + anchor_list = [] + cur_img_height = img_height + cur_img_width = img_width + tmp_stride = stride + + while tmp_stride != 1: + tmp_stride = tmp_stride // 2 + cur_img_height = (cur_img_height + 1) // 2 + cur_img_width = (cur_img_width + 1) // 2 + + for i in range(cur_img_height): + for j in range(cur_img_width): + for scale in self.scale_list: + cx = (j + 0.5) * stride + cy = (i + 0.5) * stride + side_x = self.anchor_size_list[idx] * scale + side_y = self.anchor_size_list[idx] * scale + for ratio in self.aspect_ratio_list: + anchor_list.append([ + cx, cy, side_x / math.sqrt(ratio), + side_y * math.sqrt(ratio) + ]) + + final_anchor_list.append(anchor_list) + final_anchor_arr = np.concatenate(final_anchor_list, axis=0) + normalized_anchor_arr = normalize_anchor(final_anchor_arr).astype( + 'float32') + transformed_anchor = transform_anchor(normalized_anchor_arr) + + return transformed_anchor + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ + ceil(self.image_size[0] / step), + ceil(self.image_size[1] / step) + ] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [ + x * self.steps[k] / self.image_size[1] + for x in [j + 0.5] + ] + dense_cy = [ + y * self.steps[k] / self.image_size[0] + for y in [i + 0.5] + ] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def mogdecode(loc, anchors): + """ + loc: torch.Tensor + anchors: 2-d, torch.Tensor (cx, cy, w, h) + boxes: 2-d, torch.Tensor (x0, y0, x1, y1) + """ + + boxes = torch.cat((anchors[:, :2] + loc[:, :2] * anchors[:, 2:], + anchors[:, 2:] * torch.exp(loc[:, 2:])), 1) + + boxes[:, 0] -= (boxes[:, 2] - 1) / 2 + boxes[:, 1] -= (boxes[:, 3] - 1) / 2 + boxes[:, 2] += boxes[:, 0] - 1 + boxes[:, 3] += boxes[:, 1] - 1 + + return boxes + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] + b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] + c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] + d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] + e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] + landms = torch.cat((a, b, c, d, e), dim=1) + return landms diff --git a/modelscope/models/cv/face_detection/mtcnn/__init__.py b/modelscope/models/cv/face_detection/mtcnn/__init__.py new file mode 100644 index 00000000..9fddab9c --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .models.detector import MtcnnFaceDetector diff --git a/modelscope/models/cv/face_detection/mtcnn/models/__init__.py b/modelscope/models/cv/face_detection/mtcnn/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py b/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py new file mode 100644 index 00000000..f6a27b05 --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/box_utils.py @@ -0,0 +1,240 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import numpy as np +from PIL import Image + + +def nms(boxes, overlap_threshold=0.5, mode='union'): + """Non-maximum suppression. + + Arguments: + boxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + overlap_threshold: a float number. + mode: 'union' or 'min'. + + Returns: + list with indices of the selected boxes + """ + + # if there are no boxes, return the empty list + if len(boxes) == 0: + return [] + + # list of picked indices + pick = [] + + # grab the coordinates of the bounding boxes + x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] + + area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) + ids = np.argsort(score) # in increasing order + + while len(ids) > 0: + + # grab index of the largest value + last = len(ids) - 1 + i = ids[last] + pick.append(i) + + # compute intersections + # of the box with the largest score + # with the rest of boxes + + # left top corner of intersection boxes + ix1 = np.maximum(x1[i], x1[ids[:last]]) + iy1 = np.maximum(y1[i], y1[ids[:last]]) + + # right bottom corner of intersection boxes + ix2 = np.minimum(x2[i], x2[ids[:last]]) + iy2 = np.minimum(y2[i], y2[ids[:last]]) + + # width and height of intersection boxes + w = np.maximum(0.0, ix2 - ix1 + 1.0) + h = np.maximum(0.0, iy2 - iy1 + 1.0) + + # intersections' areas + inter = w * h + if mode == 'min': + overlap = inter / np.minimum(area[i], area[ids[:last]]) + elif mode == 'union': + # intersection over union (IoU) + overlap = inter / (area[i] + area[ids[:last]] - inter) + + # delete all boxes where overlap is too big + ids = np.delete( + ids, + np.concatenate([[last], + np.where(overlap > overlap_threshold)[0]])) + + return pick + + +def convert_to_square(bboxes): + """Convert bounding boxes to a square form. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. + + Returns: + a float numpy array of shape [n, 5], + squared bounding boxes. + """ + + square_bboxes = np.zeros_like(bboxes) + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + h = y2 - y1 + 1.0 + w = x2 - x1 + 1.0 + max_side = np.maximum(h, w) + square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 + square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 + square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 + square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 + return square_bboxes + + +def calibrate_box(bboxes, offsets): + """Transform bounding boxes to be more like true bounding boxes. + 'offsets' is one of the outputs of the nets. + + Arguments: + bboxes: a float numpy array of shape [n, 5]. + offsets: a float numpy array of shape [n, 4]. + + Returns: + a float numpy array of shape [n, 5]. + """ + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w = x2 - x1 + 1.0 + h = y2 - y1 + 1.0 + w = np.expand_dims(w, 1) + h = np.expand_dims(h, 1) + + # this is what happening here: + # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] + # x1_true = x1 + tx1*w + # y1_true = y1 + ty1*h + # x2_true = x2 + tx2*w + # y2_true = y2 + ty2*h + # below is just more compact form of this + + # are offsets always such that + # x1 < x2 and y1 < y2 ? + + translation = np.hstack([w, h, w, h]) * offsets + bboxes[:, 0:4] = bboxes[:, 0:4] + translation + return bboxes + + +def get_image_boxes(bounding_boxes, img, size=24): + """Cut out boxes from the image. + + Arguments: + bounding_boxes: a float numpy array of shape [n, 5]. + img: an instance of PIL.Image. + size: an integer, size of cutouts. + + Returns: + a float numpy array of shape [n, 3, size, size]. + """ + + num_boxes = len(bounding_boxes) + width, height = img.size + + [dy, edy, dx, edx, y, ey, x, ex, w, + h] = correct_bboxes(bounding_boxes, width, height) + img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') + + for i in range(num_boxes): + img_box = np.zeros((h[i], w[i], 3), 'uint8') + + img_array = np.asarray(img, 'uint8') + img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\ + img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] + + # resize + img_box = Image.fromarray(img_box) + img_box = img_box.resize((size, size), Image.BILINEAR) + img_box = np.asarray(img_box, 'float32') + + img_boxes[i, :, :, :] = _preprocess(img_box) + + return img_boxes + + +def correct_bboxes(bboxes, width, height): + """Crop boxes that are too big and get coordinates + with respect to cutouts. + + Arguments: + bboxes: a float numpy array of shape [n, 5], + where each row is (xmin, ymin, xmax, ymax, score). + width: a float number. + height: a float number. + + Returns: + dy, dx, edy, edx: a int numpy arrays of shape [n], + coordinates of the boxes with respect to the cutouts. + y, x, ey, ex: a int numpy arrays of shape [n], + corrected ymin, xmin, ymax, xmax. + h, w: a int numpy arrays of shape [n], + just heights and widths of boxes. + + in the following order: + [dy, edy, dx, edx, y, ey, x, ex, w, h]. + """ + + x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] + w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 + num_boxes = bboxes.shape[0] + + # 'e' stands for end + # (x, y) -> (ex, ey) + x, y, ex, ey = x1, y1, x2, y2 + + # we need to cut out a box from the image. + # (x, y, ex, ey) are corrected coordinates of the box + # in the image. + # (dx, dy, edx, edy) are coordinates of the box in the cutout + # from the image. + dx, dy = np.zeros((num_boxes, )), np.zeros((num_boxes, )) + edx, edy = w.copy() - 1.0, h.copy() - 1.0 + + # if box's bottom right corner is too far right + ind = np.where(ex > width - 1.0)[0] + edx[ind] = w[ind] + width - 2.0 - ex[ind] + ex[ind] = width - 1.0 + + # if box's bottom right corner is too low + ind = np.where(ey > height - 1.0)[0] + edy[ind] = h[ind] + height - 2.0 - ey[ind] + ey[ind] = height - 1.0 + + # if box's top left corner is too far left + ind = np.where(x < 0.0)[0] + dx[ind] = 0.0 - x[ind] + x[ind] = 0.0 + + # if box's top left corner is too high + ind = np.where(y < 0.0)[0] + dy[ind] = 0.0 - y[ind] + y[ind] = 0.0 + + return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] + return_list = [i.astype('int32') for i in return_list] + + return return_list + + +def _preprocess(img): + """Preprocessing step before feeding the network. + + Arguments: + img: a float numpy array of shape [h, w, c]. + + Returns: + a float numpy array of shape [1, c, h, w]. + """ + img = img.transpose((2, 0, 1)) + img = np.expand_dims(img, 0) + img = (img - 127.5) * 0.0078125 + return img diff --git a/modelscope/models/cv/face_detection/mtcnn/models/detector.py b/modelscope/models/cv/face_detection/mtcnn/models/detector.py new file mode 100644 index 00000000..9c3aca3a --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/detector.py @@ -0,0 +1,149 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import os + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from PIL import Image +from torch.autograd import Variable + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .box_utils import calibrate_box, convert_to_square, get_image_boxes, nms +from .first_stage import run_first_stage +from .get_nets import ONet, PNet, RNet + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.mtcnn) +class MtcnnFaceDetector(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + + self.pnet = PNet(model_path=os.path.join(self.model_path, 'pnet.npy')) + self.rnet = RNet(model_path=os.path.join(self.model_path, 'rnet.npy')) + self.onet = ONet(model_path=os.path.join(self.model_path, 'onet.npy')) + + self.pnet = self.pnet.to(device) + self.rnet = self.rnet.to(device) + self.onet = self.onet.to(device) + + def forward(self, input): + image = Image.fromarray(np.uint8(input['img'].cpu().numpy())) + pnet = self.pnet + rnet = self.rnet + onet = self.onet + onet.eval() + + min_face_size = 20.0 + thresholds = [0.7, 0.8, 0.9] + nms_thresholds = [0.7, 0.7, 0.7] + + # BUILD AN IMAGE PYRAMID + width, height = image.size + min_length = min(height, width) + + min_detection_size = 12 + factor = 0.707 # sqrt(0.5) + + # scales for scaling the image + scales = [] + + m = min_detection_size / min_face_size + min_length *= m + + factor_count = 0 + while min_length > min_detection_size: + scales.append(m * factor**factor_count) + min_length *= factor + factor_count += 1 + + # STAGE 1 + + # it will be returned + bounding_boxes = [] + + # run P-Net on different scales + for s in scales: + boxes = run_first_stage( + image, + pnet, + scale=s, + threshold=thresholds[0], + device=self.device) + bounding_boxes.append(boxes) + + # collect boxes (and offsets, and scores) from different scales + bounding_boxes = [i for i in bounding_boxes if i is not None] + bounding_boxes = np.vstack(bounding_boxes) + + keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) + bounding_boxes = bounding_boxes[keep] + + # use offsets predicted by pnet to transform bounding boxes + bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], + bounding_boxes[:, 5:]) + # shape [n_boxes, 5] + + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 2 + + img_boxes = get_image_boxes(bounding_boxes, image, size=24) + img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True) + output = rnet(img_boxes.to(self.device)) + offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[1])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) + offsets = offsets[keep] + + keep = nms(bounding_boxes, nms_thresholds[1]) + bounding_boxes = bounding_boxes[keep] + bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) + bounding_boxes = convert_to_square(bounding_boxes) + bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) + + # STAGE 3 + + img_boxes = get_image_boxes(bounding_boxes, image, size=48) + if len(img_boxes) == 0: + return [], [] + img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True) + output = onet(img_boxes.to(self.device)) + landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] + offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] + probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] + + keep = np.where(probs[:, 1] > thresholds[2])[0] + bounding_boxes = bounding_boxes[keep] + bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) + offsets = offsets[keep] + landmarks = landmarks[keep] + + # compute landmark points + width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 + height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 + xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] + landmarks[:, 0:5] = np.expand_dims( + xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] + landmarks[:, 5:10] = np.expand_dims( + ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] + + bounding_boxes = calibrate_box(bounding_boxes, offsets) + keep = nms(bounding_boxes, nms_thresholds[2], mode='min') + bounding_boxes = bounding_boxes[keep] + landmarks = landmarks[keep] + landmarks = landmarks.reshape(-1, 2, 5).transpose( + (0, 2, 1)).reshape(-1, 10) + + return bounding_boxes, landmarks diff --git a/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py b/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py new file mode 100644 index 00000000..e2aba47e --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/first_stage.py @@ -0,0 +1,100 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +import math + +import numpy as np +import torch +from PIL import Image +from torch.autograd import Variable + +from .box_utils import _preprocess, nms + + +def run_first_stage(image, net, scale, threshold, device='cuda'): + """Run P-Net, generate bounding boxes, and do NMS. + + Arguments: + image: an instance of PIL.Image. + net: an instance of pytorch's nn.Module, P-Net. + scale: a float number, + scale width and height of the image by this number. + threshold: a float number, + threshold on the probability of a face when generating + bounding boxes from predictions of the net. + + Returns: + a float numpy array of shape [n_boxes, 9], + bounding boxes with scores and offsets (4 + 1 + 4). + """ + + # scale the image and convert it to a float array + width, height = image.size + sw, sh = math.ceil(width * scale), math.ceil(height * scale) + img = image.resize((sw, sh), Image.BILINEAR) + img = np.asarray(img, 'float32') + + img = Variable( + torch.FloatTensor(_preprocess(img)), volatile=True).to(device) + output = net(img) + probs = output[1].cpu().data.numpy()[0, 1, :, :] + offsets = output[0].cpu().data.numpy() + # probs: probability of a face at each sliding window + # offsets: transformations to true bounding boxes + + boxes = _generate_bboxes(probs, offsets, scale, threshold) + if len(boxes) == 0: + return None + + keep = nms(boxes[:, 0:5], overlap_threshold=0.5) + return boxes[keep] + + +def _generate_bboxes(probs, offsets, scale, threshold): + """Generate bounding boxes at places + where there is probably a face. + + Arguments: + probs: a float numpy array of shape [n, m]. + offsets: a float numpy array of shape [1, 4, n, m]. + scale: a float number, + width and height of the image were scaled by this number. + threshold: a float number. + + Returns: + a float numpy array of shape [n_boxes, 9] + """ + + # applying P-Net is equivalent, in some sense, to + # moving 12x12 window with stride 2 + stride = 2 + cell_size = 12 + + # indices of boxes where there is probably a face + inds = np.where(probs > threshold) + + if inds[0].size == 0: + return np.array([]) + + # transformations of bounding boxes + tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] + # they are defined as: + # w = x2 - x1 + 1 + # h = y2 - y1 + 1 + # x1_true = x1 + tx1*w + # x2_true = x2 + tx2*w + # y1_true = y1 + ty1*h + # y2_true = y2 + ty2*h + + offsets = np.array([tx1, ty1, tx2, ty2]) + score = probs[inds[0], inds[1]] + + # P-Net is applied to scaled images + # so we need to rescale bounding boxes back + bounding_boxes = np.vstack([ + np.round((stride * inds[1] + 1.0) / scale), + np.round((stride * inds[0] + 1.0) / scale), + np.round((stride * inds[1] + 1.0 + cell_size) / scale), + np.round((stride * inds[0] + 1.0 + cell_size) / scale), score, offsets + ]) + # why one is added? + + return bounding_boxes.T diff --git a/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py b/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py new file mode 100644 index 00000000..5fbbd33b --- /dev/null +++ b/modelscope/models/cv/face_detection/mtcnn/models/get_nets.py @@ -0,0 +1,160 @@ +# The implementation is based on mtcnn, available at https://github.com/TropComplique/mtcnn-pytorch +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Flatten(nn.Module): + + def __init__(self): + super(Flatten, self).__init__() + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, c, h, w]. + Returns: + a float tensor with shape [batch_size, c*h*w]. + """ + + # without this pretrained model isn't working + x = x.transpose(3, 2).contiguous() + + return x.view(x.size(0), -1) + + +class PNet(nn.Module): + + def __init__(self, model_path=None): + + super(PNet, self).__init__() + + # suppose we have input with size HxW, then + # after first layer: H - 2, + # after pool: ceil((H - 2)/2), + # after second conv: ceil((H - 2)/2) - 2, + # after last conv: ceil((H - 2)/2) - 4, + # and the same for W + + self.features = nn.Sequential( + OrderedDict([('conv1', nn.Conv2d(3, 10, 3, 1)), + ('prelu1', nn.PReLU(10)), + ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(10, 16, 3, 1)), + ('prelu2', nn.PReLU(16)), + ('conv3', nn.Conv2d(16, 32, 3, 1)), + ('prelu3', nn.PReLU(32))])) + + self.conv4_1 = nn.Conv2d(32, 2, 1, 1) + self.conv4_2 = nn.Conv2d(32, 4, 1, 1) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4, h', w']. + a: a float tensor with shape [batch_size, 2, h', w']. + """ + x = self.features(x) + a = self.conv4_1(x) + b = self.conv4_2(x) + a = F.softmax(a) + return b, a + + +class RNet(nn.Module): + + def __init__(self, model_path=None): + + super(RNet, self).__init__() + + self.features = nn.Sequential( + OrderedDict([('conv1', nn.Conv2d(3, 28, 3, 1)), + ('prelu1', nn.PReLU(28)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(28, 48, 3, 1)), + ('prelu2', nn.PReLU(48)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv3', nn.Conv2d(48, 64, 2, 1)), + ('prelu3', nn.PReLU(64)), ('flatten', Flatten()), + ('conv4', nn.Linear(576, 128)), + ('prelu4', nn.PReLU(128))])) + + self.conv5_1 = nn.Linear(128, 2) + self.conv5_2 = nn.Linear(128, 4) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. + """ + x = self.features(x) + a = self.conv5_1(x) + b = self.conv5_2(x) + a = F.softmax(a) + return b, a + + +class ONet(nn.Module): + + def __init__(self, model_path=None): + + super(ONet, self).__init__() + + self.features = nn.Sequential( + OrderedDict([ + ('conv1', nn.Conv2d(3, 32, 3, 1)), + ('prelu1', nn.PReLU(32)), + ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv2', nn.Conv2d(32, 64, 3, 1)), + ('prelu2', nn.PReLU(64)), + ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), + ('conv3', nn.Conv2d(64, 64, 3, 1)), + ('prelu3', nn.PReLU(64)), + ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), + ('conv4', nn.Conv2d(64, 128, 2, 1)), + ('prelu4', nn.PReLU(128)), + ('flatten', Flatten()), + ('conv5', nn.Linear(1152, 256)), + ('drop5', nn.Dropout(0.25)), + ('prelu5', nn.PReLU(256)), + ])) + + self.conv6_1 = nn.Linear(256, 2) + self.conv6_2 = nn.Linear(256, 4) + self.conv6_3 = nn.Linear(256, 10) + + weights = np.load(model_path, allow_pickle=True)[()] + for n, p in self.named_parameters(): + p.data = torch.FloatTensor(weights[n]) + + def forward(self, x): + """ + Arguments: + x: a float tensor with shape [batch_size, 3, h, w]. + Returns: + c: a float tensor with shape [batch_size, 10]. + b: a float tensor with shape [batch_size, 4]. + a: a float tensor with shape [batch_size, 2]. + """ + x = self.features(x) + a = self.conv6_1(x) + b = self.conv6_2(x) + c = self.conv6_3(x) + a = F.softmax(a) + return c, b, a diff --git a/modelscope/models/cv/face_detection/retinaface/__init__.py b/modelscope/models/cv/face_detection/retinaface/__init__.py new file mode 100644 index 00000000..e7b589a1 --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .detection import RetinaFaceDetection diff --git a/modelscope/models/cv/face_detection/retinaface/detection.py b/modelscope/models/cv/face_detection/retinaface/detection.py new file mode 100755 index 00000000..3dd31659 --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/detection.py @@ -0,0 +1,137 @@ +# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .models.retinaface import RetinaFace +from .utils import PriorBox, decode, decode_landm, py_cpu_nms + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.retinaface) +class RetinaFaceDetection(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.cfg = Config.from_file( + model_path.replace(ModelFile.TORCH_MODEL_FILE, + ModelFile.CONFIGURATION))['models'] + self.net = RetinaFace(cfg=self.cfg) + self.load_model() + self.device = device + self.net = self.net.to(self.device) + + self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) + + def check_keys(self, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(self.net.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + assert len( + used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + def remove_prefix(self, state_dict, prefix): + new_state_dict = dict() + for k, v in state_dict.items(): + if k.startswith(prefix): + new_state_dict[k[len(prefix):]] = v + else: + new_state_dict[k] = v + return new_state_dict + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + if 'state_dict' in pretrained_dict.keys(): + pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], + 'module.') + else: + pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') + self.check_keys(pretrained_dict) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def forward(self, input): + img_raw = input['img'].cpu().numpy() + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + ss = 1.0 + # tricky + if max(im_height, im_width) > 1500: + ss = 1000.0 / max(im_height, im_width) + img = cv2.resize(img, (0, 0), fx=ss, fy=ss) + im_height, im_width = img.shape[:2] + + scale = torch.Tensor( + [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + scale = scale.to(self.device) + + loc, conf, landms = self.net(img) # forward pass + del img + + confidence_threshold = 0.9 + nms_threshold = 0.4 + top_k = 5000 + keep_top_k = 750 + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + im_width, im_height, im_width, im_height, im_width, im_height, + im_width, im_height, im_width, im_height + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss diff --git a/modelscope/models/cv/face_detection/retinaface/models/__init__.py b/modelscope/models/cv/face_detection/retinaface/models/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/retinaface/models/net.py b/modelscope/models/cv/face_detection/retinaface/models/net.py new file mode 100755 index 00000000..3be7c4b9 --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/models/net.py @@ -0,0 +1,149 @@ +# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +from torch.autograd import Variable + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + x = x.view(-1, 256) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/face_detection/retinaface/models/retinaface.py b/modelscope/models/cv/face_detection/retinaface/models/retinaface.py new file mode 100755 index 00000000..8d2001dd --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/models/retinaface.py @@ -0,0 +1,145 @@ +# The implementation is based on resnet, available at https://github.com/biubug6/Pytorch_Retinaface +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +import torchvision.models.detection.backbone_utils as backbone_utils + +from .net import FPN, SSH, MobileNetV1 + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d( + inchannels, + self.num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, cfg=None): + """ + :param cfg: Network related settings. + """ + super(RetinaFace, self).__init__() + backbone = None + if cfg['name'] == 'Resnet50': + backbone = models.resnet50(pretrained=cfg['pretrain']) + else: + raise Exception('Invalid name') + + self.body = _utils.IntermediateLayerGetter(backbone, + cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, inchannels=cfg['out_channel']) + + def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead + + def forward(self, inputs): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = torch.cat( + [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], + dim=1) + + output = (bbox_regressions, F.softmax(classifications, + dim=-1), ldm_regressions) + return output diff --git a/modelscope/models/cv/face_detection/retinaface/utils.py b/modelscope/models/cv/face_detection/retinaface/utils.py new file mode 100755 index 00000000..60c9e2dd --- /dev/null +++ b/modelscope/models/cv/face_detection/retinaface/utils.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------- +# Modified from https://github.com/biubug6/Pytorch_Retinaface +# -------------------------------------------------------- + +from itertools import product as product +from math import ceil + +import numpy as np +import torch + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ + ceil(self.image_size[0] / step), + ceil(self.image_size[1] / step) + ] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [ + x * self.steps[k] / self.image_size[1] + for x in [j + 0.5] + ] + dense_cy = [ + y * self.steps[k] / self.image_size[0] + for y in [i + 0.5] + ] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] + b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] + c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] + d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] + e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] + landms = torch.cat((a, b, c, d, e), dim=1) + return landms diff --git a/modelscope/models/cv/face_detection/scrfd/__init__.py b/modelscope/models/cv/face_detection/scrfd/__init__.py new file mode 100644 index 00000000..92f81f7a --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .scrfd_detect import ScrfdDetect diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py new file mode 100755 index 00000000..5a895582 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/__init__.py @@ -0,0 +1,4 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet +""" diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py new file mode 100644 index 00000000..cf1b7313 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox +""" +from .transforms import bbox2result, distance2kps, kps2distance + +__all__ = ['bbox2result', 'distance2kps', 'kps2distance'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py new file mode 100755 index 00000000..75e32d85 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/bbox/transforms.py @@ -0,0 +1,87 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/bbox/transforms.py +""" +import numpy as np +import torch + + +def bbox2result(bboxes, labels, num_classes, kps=None, num_kps=5): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + bbox_len = 5 if kps is None else 5 + num_kps * 2 # if has kps, add num_kps*2 into bbox + if bboxes.shape[0] == 0: + return [ + np.zeros((0, bbox_len), dtype=np.float32) + for i in range(num_classes) + ] + else: + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + if kps is None: + return [bboxes[labels == i, :] for i in range(num_classes)] + else: # with kps + if isinstance(kps, torch.Tensor): + kps = kps.detach().cpu().numpy() + return [ + np.hstack([bboxes[labels == i, :], kps[labels == i, :]]) + for i in range(num_classes) + ] + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded kps. + """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i % 2] + distance[:, i] + py = points[:, i % 2 + 1] + distance[:, i + 1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return torch.stack(preds, -1) + + +def kps2distance(points, kps, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + kps (Tensor): Shape (n, K), "xyxy" format + max_dis (float): Upper bound of the distance. + eps (float): a small value to ensure target < max_dis, instead <= + + Returns: + Tensor: Decoded distances. + """ + + preds = [] + for i in range(0, kps.shape[1], 2): + px = kps[:, i] - points[:, i % 2] + py = kps[:, i + 1] - points[:, i % 2 + 1] + if max_dis is not None: + px = px.clamp(min=0, max=max_dis - eps) + py = py.clamp(min=0, max=max_dis - eps) + preds.append(px) + preds.append(py) + return torch.stack(preds, -1) diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py new file mode 100755 index 00000000..61602fd3 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py +""" +from .bbox_nms import multiclass_nms + +__all__ = ['multiclass_nms'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py new file mode 100644 index 00000000..697b7338 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/core/post_processing/bbox_nms.py @@ -0,0 +1,89 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/core/post_processing/bbox_nms.py +""" +import torch + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + nms_cfg, + max_num=-1, + score_factors=None, + return_inds=False, + multi_kps=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_kps (Tensor): shape (n, #class*num_kps*2) or (n, num_kps*2) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int, optional): if there are more than max_num bboxes after + NMS, only top max_num will be kept. Default to -1. + score_factors (Tensor, optional): The factors multiplied to scores + before applying NMS. Default to None. + return_inds (bool, optional): Whether return the indices of kept + bboxes. Default to False. + + Returns: + tuple: (bboxes, labels, indices (optional)), tensors of shape (k, 5), + (k), and (k). Labels are 0-based. + """ + num_classes = multi_scores.size(1) - 1 + # exclude background category + kps = None + if multi_kps is not None: + num_kps = int((multi_kps.shape[1] / num_classes) / 2) + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + if multi_kps is not None: + kps = multi_kps.view(multi_scores.size(0), -1, num_kps * 2) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + if multi_kps is not None: + kps = multi_kps[:, None].expand( + multi_scores.size(0), num_classes, num_kps * 2) + + scores = multi_scores[:, :-1] + if score_factors is not None: + scores = scores * score_factors[:, None] + + labels = torch.arange(num_classes, dtype=torch.long) + labels = labels.view(1, -1).expand_as(scores) + + bboxes = bboxes.reshape(-1, 4) + if kps is not None: + kps = kps.reshape(-1, num_kps * 2) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + # remove low scoring boxes + valid_mask = scores > score_thr + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + if kps is not None: + kps = kps[inds] + if inds.numel() == 0: + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + return bboxes, labels, kps + + # TODO: add size check before feed into batched_nms + from mmcv.ops.nms import batched_nms + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + if return_inds: + return dets, labels[keep], kps[keep], keep + else: + return dets, labels[keep], kps[keep] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py new file mode 100644 index 00000000..cea179b0 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets +""" +from .retinaface import RetinaFaceDataset + +__all__ = ['RetinaFaceDataset'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py new file mode 100755 index 00000000..a2cafd1a --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/__init__.py @@ -0,0 +1,13 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines +""" +from .auto_augment import RotateV2 +from .formating import DefaultFormatBundleV2 +from .loading import LoadAnnotationsV2 +from .transforms import RandomSquareCrop + +__all__ = [ + 'RandomSquareCrop', 'LoadAnnotationsV2', 'RotateV2', + 'DefaultFormatBundleV2' +] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py new file mode 100644 index 00000000..ee60c2e0 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/auto_augment.py @@ -0,0 +1,271 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/auto_augment.py +""" +import copy + +import cv2 +import mmcv +import numpy as np +from mmdet.datasets.builder import PIPELINES + +_MAX_LEVEL = 10 + + +def level_to_value(level, max_value): + """Map from level to values based on max_value.""" + return (level / _MAX_LEVEL) * max_value + + +def random_negative(value, random_negative_prob): + """Randomly negate value based on random_negative_prob.""" + return -value if np.random.rand() < random_negative_prob else value + + +def bbox2fields(): + """The key correspondence from bboxes to labels, masks and + segmentations.""" + bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + bbox2seg = { + 'gt_bboxes': 'gt_semantic_seg', + } + return bbox2label, bbox2mask, bbox2seg + + +@PIPELINES.register_module() +class RotateV2(object): + """Apply Rotate Transformation to image (and its corresponding bbox, mask, + segmentation). + + Args: + level (int | float): The level should be in range (0,_MAX_LEVEL]. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. + center (int | float | tuple[float]): Center point (w, h) of the + rotation in the source image. If None, the center of the + image will be used. Same in ``mmcv.imrotate``. + img_fill_val (int | float | tuple): The fill value for image border. + If float, the same value will be used for all the three + channels of image. If tuple, the should be 3 elements (e.g. + equals the number of channels for image). + seg_ignore_label (int): The fill value used for segmentation map. + Note this value must equals ``ignore_label`` in ``semantic_head`` + of the corresponding config. Default 255. + prob (float): The probability for perform transformation and + should be in range 0 to 1. + max_rotate_angle (int | float): The maximum angles for rotate + transformation. + random_negative_prob (float): The probability that turns the + offset negative. + """ + + def __init__(self, + level, + scale=1, + center=None, + img_fill_val=128, + seg_ignore_label=255, + prob=0.5, + max_rotate_angle=30, + random_negative_prob=0.5): + assert isinstance(level, (int, float)), \ + f'The level must be type int or float. got {type(level)}.' + assert 0 <= level <= _MAX_LEVEL, \ + f'The level should be in range (0,{_MAX_LEVEL}]. got {level}.' + assert isinstance(scale, (int, float)), \ + f'The scale must be type int or float. got type {type(scale)}.' + if isinstance(center, (int, float)): + center = (center, center) + elif isinstance(center, tuple): + assert len(center) == 2, 'center with type tuple must have '\ + f'2 elements. got {len(center)} elements.' + else: + assert center is None, 'center must be None or type int, '\ + f'float or tuple, got type {type(center)}.' + if isinstance(img_fill_val, (float, int)): + img_fill_val = tuple([float(img_fill_val)] * 3) + elif isinstance(img_fill_val, tuple): + assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\ + f'have 3 elements. got {len(img_fill_val)}.' + img_fill_val = tuple([float(val) for val in img_fill_val]) + else: + raise ValueError( + 'img_fill_val must be float or tuple with 3 elements.') + assert np.all([0 <= val <= 255 for val in img_fill_val]), \ + 'all elements of img_fill_val should between range [0,255]. '\ + f'got {img_fill_val}.' + assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\ + f'got {prob}.' + assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\ + f'should be type int or float. got type {type(max_rotate_angle)}.' + self.level = level + self.scale = scale + # Rotation angle in degrees. Positive values mean + # clockwise rotation. + self.angle = level_to_value(level, max_rotate_angle) + self.center = center + self.img_fill_val = img_fill_val + self.seg_ignore_label = seg_ignore_label + self.prob = prob + self.max_rotate_angle = max_rotate_angle + self.random_negative_prob = random_negative_prob + + def _rotate_img(self, results, angle, center=None, scale=1.0): + """Rotate the image. + + Args: + results (dict): Result dict from loading pipeline. + angle (float): Rotation angle in degrees, positive values + mean clockwise rotation. Same in ``mmcv.imrotate``. + center (tuple[float], optional): Center point (w, h) of the + rotation. Same in ``mmcv.imrotate``. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. + """ + for key in results.get('img_fields', ['img']): + img = results[key].copy() + img_rotated = mmcv.imrotate( + img, angle, center, scale, border_value=self.img_fill_val) + results[key] = img_rotated.astype(img.dtype) + results['img_shape'] = results[key].shape + + def _rotate_bboxes(self, results, rotate_matrix): + """Rotate the bboxes.""" + h, w, c = results['img_shape'] + for key in results.get('bbox_fields', []): + min_x, min_y, max_x, max_y = np.split( + results[key], results[key].shape[-1], axis=-1) + coordinates = np.stack([[min_x, min_y], [max_x, min_y], + [min_x, max_y], + [max_x, max_y]]) # [4, 2, nb_bbox, 1] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coordinates = np.concatenate( + (coordinates, + np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)), + axis=1) # [4, 3, nb_bbox, 1] + coordinates = coordinates.transpose( + (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] + rotated_coords = np.matmul(rotate_matrix, + coordinates) # [nb_bbox, 4, 2, 1] + rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] + min_x, min_y = np.min( + rotated_coords[:, :, 0], axis=1), np.min( + rotated_coords[:, :, 1], axis=1) + max_x, max_y = np.max( + rotated_coords[:, :, 0], axis=1), np.max( + rotated_coords[:, :, 1], axis=1) + results[key] = np.stack([min_x, min_y, max_x, max_y], + axis=-1).astype(results[key].dtype) + + def _rotate_keypoints90(self, results, angle): + """Rotate the keypoints, only valid when angle in [-90,90,-180,180]""" + if angle not in [-90, 90, 180, -180 + ] or self.scale != 1 or self.center is not None: + return + for key in results.get('keypoints_fields', []): + k = results[key] + if angle == 90: + w, h, c = results['img'].shape + new = np.stack([h - k[..., 1], k[..., 0], k[..., 2]], axis=-1) + elif angle == -90: + w, h, c = results['img'].shape + new = np.stack([k[..., 1], w - k[..., 0], k[..., 2]], axis=-1) + else: + h, w, c = results['img'].shape + new = np.stack([w - k[..., 0], h - k[..., 1], k[..., 2]], + axis=-1) + # a kps is invalid if thrid value is -1 + kps_invalid = new[..., -1][:, -1] == -1 + new[kps_invalid] = np.zeros(new.shape[1:]) - 1 + results[key] = new + + def _rotate_masks(self, + results, + angle, + center=None, + scale=1.0, + fill_val=0): + """Rotate the masks.""" + h, w, c = results['img_shape'] + for key in results.get('mask_fields', []): + masks = results[key] + results[key] = masks.rotate((h, w), angle, center, scale, fill_val) + + def _rotate_seg(self, + results, + angle, + center=None, + scale=1.0, + fill_val=255): + """Rotate the segmentation map.""" + for key in results.get('seg_fields', []): + seg = results[key].copy() + results[key] = mmcv.imrotate( + seg, angle, center, scale, + border_value=fill_val).astype(seg.dtype) + + def _filter_invalid(self, results, min_bbox_size=0): + """Filter bboxes and corresponding masks too small after rotate + augmentation.""" + bbox2label, bbox2mask, _ = bbox2fields() + for key in results.get('bbox_fields', []): + bbox_w = results[key][:, 2] - results[key][:, 0] + bbox_h = results[key][:, 3] - results[key][:, 1] + valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size) + valid_inds = np.nonzero(valid_inds)[0] + results[key] = results[key][valid_inds] + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][valid_inds] + # mask fields, e.g. gt_masks and gt_masks_ignore + mask_key = bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][valid_inds] + + def __call__(self, results): + """Call function to rotate images, bounding boxes, masks and semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + if np.random.rand() > self.prob: + return results + h, w = results['img'].shape[:2] + center = self.center + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + angle = random_negative(self.angle, self.random_negative_prob) + self._rotate_img(results, angle, center, self.scale) + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) + self._rotate_bboxes(results, rotate_matrix) + self._rotate_keypoints90(results, angle) + self._rotate_masks(results, angle, center, self.scale, fill_val=0) + self._rotate_seg( + results, angle, center, self.scale, fill_val=self.seg_ignore_label) + self._filter_invalid(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(level={self.level}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'center={self.center}, ' + repr_str += f'img_fill_val={self.img_fill_val}, ' + repr_str += f'seg_ignore_label={self.seg_ignore_label}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'max_rotate_angle={self.max_rotate_angle}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py new file mode 100644 index 00000000..bd2394a8 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/formating.py @@ -0,0 +1,113 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/formating.py +""" +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC +from mmdet.datasets.builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class DefaultFormatBundleV2(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img", + "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". + These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - proposals: (1)to tensor, (2)to DataContainer + - gt_bboxes: (1)to tensor, (2)to DataContainer + - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer + - gt_labels: (1)to tensor, (2)to DataContainer + - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with \ + default bundle. + """ + + if 'img' in results: + img = results['img'] + # add default meta keys + results = self._add_default_meta_keys(results) + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_keypointss', + 'gt_labels' + ]: + if key not in results: + continue + results[key] = DC(to_tensor(results[key])) + if 'gt_masks' in results: + results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) + return results + + def _add_default_meta_keys(self, results): + """Add default meta keys. + + We set default meta keys including `pad_shape`, `scale_factor` and + `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and + `Pad` are implemented during the whole pipeline. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + results (dict): Updated result dict contains the data to convert. + """ + img = results['img'] + results.setdefault('pad_shape', img.shape) + results.setdefault('scale_factor', 1.0) + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results.setdefault( + 'img_norm_cfg', + dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False)) + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py new file mode 100644 index 00000000..b4c2a385 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/loading.py @@ -0,0 +1,225 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/loading.py +""" +import os.path as osp + +import numpy as np +import pycocotools.mask as maskUtils +from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class LoadAnnotationsV2(object): + """Load mutiple types of annotations. + + Args: + with_bbox (bool): Whether to parse and load the bbox annotation. + Default: True. + with_label (bool): Whether to parse and load the label annotation. + Default: True. + with_keypoints (bool): Whether to parse and load the keypoints annotation. + Default: False. + with_mask (bool): Whether to parse and load the mask annotation. + Default: False. + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Default: False. + poly2mask (bool): Whether to convert the instance masks from polygons + to bitmaps. Default: True. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, + with_bbox=True, + with_label=True, + with_keypoints=False, + with_mask=False, + with_seg=False, + poly2mask=True, + file_client_args=dict(backend='disk')): + self.with_bbox = with_bbox + self.with_label = with_label + self.with_keypoints = with_keypoints + self.with_mask = with_mask + self.with_seg = with_seg + self.poly2mask = poly2mask + self.file_client_args = file_client_args.copy() + self.file_client = None + + def _load_bboxes(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_bboxes'] = ann_info['bboxes'].copy() + + gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) + if gt_bboxes_ignore is not None: + results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() + results['bbox_fields'].append('gt_bboxes_ignore') + results['bbox_fields'].append('gt_bboxes') + return results + + def _load_keypoints(self, results): + """Private function to load bounding box annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box annotations. + """ + + ann_info = results['ann_info'] + results['gt_keypointss'] = ann_info['keypointss'].copy() + + results['keypoints_fields'] = ['gt_keypointss'] + return results + + def _load_labels(self, results): + """Private function to load label annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded label annotations. + """ + + results['gt_labels'] = results['ann_info']['labels'].copy() + return results + + def _poly2mask(self, mask_ann, img_h, img_w): + """Private function to convert masks represented with polygon to + bitmaps. + + Args: + mask_ann (list | dict): Polygon mask annotation input. + img_h (int): The height of output mask. + img_w (int): The width of output mask. + + Returns: + numpy.ndarray: The decode bitmap mask of shape (img_h, img_w). + """ + + if isinstance(mask_ann, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann['counts'], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # rle + rle = mask_ann + mask = maskUtils.decode(rle) + return mask + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. + """ + + polygons = [np.array(p) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + """Private function to load mask annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded mask annotations. + If ``self.poly2mask`` is set ``True``, `gt_mask` will contain + :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used. + """ + + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = results['ann_info']['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results + + def _load_semantic_seg(self, results): + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + import mmcv + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + img_bytes = self.file_client.get(filename) + results['gt_semantic_seg'] = mmcv.imfrombytes( + img_bytes, flag='unchanged').squeeze() + results['seg_fields'].append('gt_semantic_seg') + return results + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + if self.with_bbox: + results = self._load_bboxes(results) + if results is None: + return None + if self.with_label: + results = self._load_labels(results) + if self.with_keypoints: + results = self._load_keypoints(results) + if self.with_mask: + results = self._load_masks(results) + if self.with_seg: + results = self._load_semantic_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(with_bbox={self.with_bbox}, ' + repr_str += f'with_label={self.with_label}, ' + repr_str += f'with_keypoints={self.with_keypoints}, ' + repr_str += f'with_mask={self.with_mask}, ' + repr_str += f'with_seg={self.with_seg})' + repr_str += f'poly2mask={self.poly2mask})' + repr_str += f'poly2mask={self.file_client_args})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py new file mode 100755 index 00000000..270c34da --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/pipelines/transforms.py @@ -0,0 +1,737 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/pipelines/transforms.py +""" +import mmcv +import numpy as np +from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps +from mmdet.datasets.builder import PIPELINES +from numpy import random + + +@PIPELINES.register_module() +class ResizeV2(object): + """Resize images & bbox & mask &kps. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + override=False): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.override = override + self.bbox_clip_border = bbox_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + for key in results.get('img_fields', ['img']): + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results[key].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results[key], + results['scale'], + return_scale=True, + backend=self.backend) + results[key] = img + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img_shape'] = img.shape + # in case that there is no padding + results['pad_shape'] = img.shape + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_bboxes(self, results): + """Resize bounding boxes with ``results['scale_factor']``.""" + for key in results.get('bbox_fields', []): + bboxes = results[key] * results['scale_factor'] + if self.bbox_clip_border: + img_shape = results['img_shape'] + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + results[key] = bboxes + + def _resize_keypoints(self, results): + """Resize keypoints with ``results['scale_factor']``.""" + for key in results.get('keypoints_fields', []): + keypointss = results[key].copy() + factors = results['scale_factor'] + assert factors[0] == factors[2] + assert factors[1] == factors[3] + keypointss[:, :, 0] *= factors[0] + keypointss[:, :, 1] *= factors[1] + if self.bbox_clip_border: + img_shape = results['img_shape'] + keypointss[:, :, 0] = np.clip(keypointss[:, :, 0], 0, + img_shape[1]) + keypointss[:, :, 1] = np.clip(keypointss[:, :, 1], 0, + img_shape[0]) + results[key] = keypointss + + def _resize_masks(self, results): + """Resize masks with ``results['scale']``""" + for key in results.get('mask_fields', []): + if results[key] is None: + continue + if self.keep_ratio: + results[key] = results[key].rescale(results['scale']) + else: + results[key] = results[key].resize(results['img_shape'][:2]) + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results['gt_semantic_seg'] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + if 'scale_factor' in results: + img_shape = results['img'].shape[:2] + scale_factor = results['scale_factor'] + assert isinstance(scale_factor, float) + results['scale'] = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(results) + else: + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) + + self._resize_img(results) + self._resize_bboxes(results) + self._resize_keypoints(results) + self._resize_masks(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio})' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +@PIPELINES.register_module() +class RandomFlipV2(object): + """Flip the image & bbox & mask & kps. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + When random flip is enabled, ``flip_ratio``/``direction`` can either be a + float/string or tuple of float/string. There are 3 flip modes: + + - ``flip_ratio`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``flip_ratio`` . + E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``flip_ratio`` is float, ``direction`` is list of string: the image wil + be ``direction[i]``ly flipped with probability of + ``flip_ratio/len(direction)``. + E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + - ``flip_ratio`` is list of float, ``direction`` is list of string: + given ``len(flip_ratio) == len(direction)``, the image wil + be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. + E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with probability + of 0.3, vertically with probability of 0.5 + + Args: + flip_ratio (float | list[float], optional): The flipping probability. + Default: None. + direction(str | list[str], optional): The flipping direction. Options + are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. + If input is a list, the length must equal ``flip_ratio``. Each + element in ``flip_ratio`` indicates the flip probability of + corresponding direction. + """ + + def __init__(self, flip_ratio=None, direction='horizontal'): + if isinstance(flip_ratio, list): + assert mmcv.is_list_of(flip_ratio, float) + assert 0 <= sum(flip_ratio) <= 1 + elif isinstance(flip_ratio, float): + assert 0 <= flip_ratio <= 1 + elif flip_ratio is None: + pass + else: + raise ValueError('flip_ratios must be None, float, ' + 'or list of float') + self.flip_ratio = flip_ratio + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert mmcv.is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError('direction must be either str or list of str') + self.direction = direction + + if isinstance(flip_ratio, list): + assert len(self.flip_ratio) == len(self.direction) + self.count = 0 + + def bbox_flip(self, bboxes, img_shape, direction): + """Flip bboxes horizontally. + + Args: + bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) + img_shape (tuple[int]): Image shape (height, width) + direction (str): Flip direction. Options are 'horizontal', + 'vertical'. + + Returns: + numpy.ndarray: Flipped bounding boxes. + """ + + assert bboxes.shape[-1] % 4 == 0 + flipped = bboxes.copy() + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + elif direction == 'vertical': + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + elif direction == 'diagonal': + w = img_shape[1] + h = img_shape[0] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + else: + raise ValueError(f"Invalid flipping direction '{direction}'") + return flipped + + def keypoints_flip(self, keypointss, img_shape, direction): + """Flip keypoints horizontally.""" + + assert direction == 'horizontal' + assert keypointss.shape[-1] == 3 + num_kps = keypointss.shape[1] + assert num_kps in [4, 5], f'Only Support num_kps=4 or 5, got:{num_kps}' + assert keypointss.ndim == 3 + flipped = keypointss.copy() + if num_kps == 5: + flip_order = [1, 0, 2, 4, 3] + elif num_kps == 4: + flip_order = [3, 2, 1, 0] + for idx, a in enumerate(flip_order): + flipped[:, idx, :] = keypointss[:, a, :] + w = img_shape[1] + flipped[..., 0] = w - flipped[..., 0] + return flipped + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added \ + into result dict. + """ + if 'flip' not in results: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) + - 1) + [non_flip_ratio] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results['flip'] = cur_dir is not None + if 'flip_direction' not in results: + results['flip_direction'] = cur_dir + if results['flip']: + # flip image + for key in results.get('img_fields', ['img']): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + # flip bboxes + for key in results.get('bbox_fields', []): + results[key] = self.bbox_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip kps + for key in results.get('keypoints_fields', []): + results[key] = self.keypoints_flip(results[key], + results['img_shape'], + results['flip_direction']) + # flip masks + for key in results.get('mask_fields', []): + results[key] = results[key].flip(results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + + +@PIPELINES.register_module() +class RandomSquareCrop(object): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size). + + Note: + The keys for bboxes, labels and masks should be paired. That is, \ + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \ + `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`. + """ + + def __init__(self, + crop_ratio_range=None, + crop_choice=None, + bbox_clip_border=True, + big_face_ratio=0, + big_face_crop_choice=None): + + self.crop_ratio_range = crop_ratio_range + self.crop_choice = crop_choice + self.big_face_crop_choice = big_face_crop_choice + self.bbox_clip_border = bbox_clip_border + + assert (self.crop_ratio_range is None) ^ (self.crop_choice is None) + if self.crop_ratio_range is not None: + self.crop_ratio_min, self.crop_ratio_max = self.crop_ratio_range + + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + assert big_face_ratio >= 0 and big_face_ratio <= 1.0 + self.big_face_ratio = big_face_ratio + + def __call__(self, results): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images and bounding boxes cropped, \ + 'img_shape' key is updated. + """ + + if 'img_fields' in results: + assert results['img_fields'] == ['img'], \ + 'Only single img_fields is allowed' + img = results['img'] + assert 'bbox_fields' in results + assert 'gt_bboxes' in results + # try augment big face images + find_bigface = False + if np.random.random() < self.big_face_ratio: + min_size = 100 # h and w + expand_ratio = 0.3 # expand ratio of croped face alongwith both w and h + bbox = results['gt_bboxes'].copy() + lmks = results['gt_keypointss'].copy() + label = results['gt_labels'].copy() + # filter small faces + size_mask = ((bbox[:, 2] - bbox[:, 0]) > min_size) * ( + (bbox[:, 3] - bbox[:, 1]) > min_size) + bbox = bbox[size_mask] + lmks = lmks[size_mask] + label = label[size_mask] + # randomly choose a face that has no overlap with others + if len(bbox) > 0: + overlaps = bbox_overlaps(bbox, bbox) + overlaps -= np.eye(overlaps.shape[0]) + iou_mask = np.sum(overlaps, axis=1) == 0 + bbox = bbox[iou_mask] + lmks = lmks[iou_mask] + label = label[iou_mask] + if len(bbox) > 0: + choice = np.random.randint(len(bbox)) + bbox = bbox[choice] + lmks = lmks[choice] + label = [label[choice]] + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x1 = bbox[0] - w * expand_ratio + x2 = bbox[2] + w * expand_ratio + y1 = bbox[1] - h * expand_ratio + y2 = bbox[3] + h * expand_ratio + x1, x2 = np.clip([x1, x2], 0, img.shape[1]) + y1, y2 = np.clip([y1, y2], 0, img.shape[0]) + bbox -= np.tile([x1, y1], 2) + lmks -= (x1, y1, 0) + + find_bigface = True + img = img[int(y1):int(y2), int(x1):int(x2), :] + results['gt_bboxes'] = np.expand_dims(bbox, axis=0) + results['gt_keypointss'] = np.expand_dims(lmks, axis=0) + results['gt_labels'] = np.array(label) + results['img'] = img + + boxes = results['gt_bboxes'] + h, w, c = img.shape + + if self.crop_ratio_range is not None: + max_scale = self.crop_ratio_max + else: + max_scale = np.amax(self.crop_choice) + scale_retry = 0 + while True: + scale_retry += 1 + if scale_retry == 1 or max_scale > 1.0: + if self.crop_ratio_range is not None: + scale = np.random.uniform(self.crop_ratio_min, + self.crop_ratio_max) + elif self.crop_choice is not None: + scale = np.random.choice(self.crop_choice) + else: + scale = scale * 1.2 + + if find_bigface: + # select a scale from big_face_crop_choice if in big_face mode + scale = np.random.choice(self.big_face_crop_choice) + + for i in range(250): + long_side = max(w, h) + cw = int(scale * long_side) + ch = cw + + # TODO +1 + if w == cw: + left = 0 + elif w > cw: + left = random.randint(0, w - cw) + else: + left = random.randint(w - cw, 0) + if h == ch: + top = 0 + elif h > ch: + top = random.randint(0, h - ch) + else: + top = random.randint(h - ch, 0) + + patch = np.array( + (int(left), int(top), int(left + cw), int(top + ch)), + dtype=np.int32) + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + # TODO >= + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = \ + ((center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3])) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + for key in results.get('bbox_fields', []): + boxes = results[key].copy() + mask = is_center_of_bboxes_in_patch(boxes, patch) + boxes = boxes[mask] + if self.bbox_clip_border: + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 2) + + results[key] = boxes + # labels + label_key = self.bbox2label.get(key) + if label_key in results: + results[label_key] = results[label_key][mask] + + # keypoints field + if key == 'gt_bboxes': + for kps_key in results.get('keypoints_fields', []): + keypointss = results[kps_key].copy() + keypointss = keypointss[mask, :, :] + if self.bbox_clip_border: + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + max=patch[2:]) + keypointss[:, :, : + 2] = keypointss[:, :, :2].clip( + min=patch[:2]) + keypointss[:, :, 0] -= patch[0] + keypointss[:, :, 1] -= patch[1] + results[kps_key] = keypointss + + # mask fields + mask_key = self.bbox2mask.get(key) + if mask_key in results: + results[mask_key] = results[mask_key][mask.nonzero() + [0]].crop(patch) + + # adjust the img no matter whether the gt is empty before crop + rimg = np.ones((ch, cw, 3), dtype=img.dtype) * 128 + patch_from = patch.copy() + patch_from[0] = max(0, patch_from[0]) + patch_from[1] = max(0, patch_from[1]) + patch_from[2] = min(img.shape[1], patch_from[2]) + patch_from[3] = min(img.shape[0], patch_from[3]) + patch_to = patch.copy() + patch_to[0] = max(0, patch_to[0] * -1) + patch_to[1] = max(0, patch_to[1] * -1) + patch_to[2] = patch_to[0] + (patch_from[2] - patch_from[0]) + patch_to[3] = patch_to[1] + (patch_from[3] - patch_from[1]) + rimg[patch_to[1]:patch_to[3], + patch_to[0]:patch_to[2], :] = img[ + patch_from[1]:patch_from[3], + patch_from[0]:patch_from[2], :] + img = rimg + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_ious={self.min_iou}, ' + repr_str += f'crop_size={self.crop_size})' + return repr_str diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py new file mode 100755 index 00000000..40c440b9 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/datasets/retinaface.py @@ -0,0 +1,153 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/datasets/retinaface.py +""" +import numpy as np +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.custom import CustomDataset + + +@DATASETS.register_module() +class RetinaFaceDataset(CustomDataset): + + CLASSES = ('FG', ) + + def __init__(self, min_size=None, **kwargs): + self.NK = kwargs.pop('num_kps', 5) + self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} + self.min_size = min_size + self.gt_path = kwargs.get('gt_path') + super(RetinaFaceDataset, self).__init__(**kwargs) + + def _parse_ann_line(self, line): + values = [float(x) for x in line.strip().split()] + bbox = np.array(values[0:4], dtype=np.float32) + kps = np.zeros((self.NK, 3), dtype=np.float32) + ignore = False + if self.min_size is not None: + assert not self.test_mode + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + if w < self.min_size or h < self.min_size: + ignore = True + if len(values) > 4: + if len(values) > 5: + kps = np.array( + values[4:4 + self.NK * 3], dtype=np.float32).reshape( + (self.NK, 3)) + for li in range(kps.shape[0]): + if (kps[li, :] == -1).all(): + kps[li][2] = 0.0 # weight = 0, ignore + else: + assert kps[li][2] >= 0 + kps[li][2] = 1.0 # weight + else: # len(values)==5 + if not ignore: + ignore = (values[4] == 1) + else: + assert self.test_mode + + return dict(bbox=bbox, kps=kps, ignore=ignore, cat='FG') + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + 20220711@tyx: ann_file is list of img paths is supported + + Returns: + list[dict]: Annotation info from COCO api. + """ + if isinstance(ann_file, list): + data_infos = [] + for line in ann_file: + name = line + objs = [0, 0, 0, 0] + data_infos.append( + dict(filename=name, width=0, height=0, objs=objs)) + else: + name = None + bbox_map = {} + for line in open(ann_file, 'r'): + line = line.strip() + if line.startswith('#'): + value = line[1:].strip().split() + name = value[0] + width = int(value[1]) + height = int(value[2]) + + bbox_map[name] = dict(width=width, height=height, objs=[]) + continue + assert name is not None + assert name in bbox_map + bbox_map[name]['objs'].append(line) + print('origin image size', len(bbox_map)) + data_infos = [] + for name in bbox_map: + item = bbox_map[name] + width = item['width'] + height = item['height'] + vals = item['objs'] + objs = [] + for line in vals: + data = self._parse_ann_line(line) + if data is None: + continue + objs.append(data) # data is (bbox, kps, cat) + if len(objs) == 0 and not self.test_mode: + continue + data_infos.append( + dict(filename=name, width=width, height=height, objs=objs)) + return data_infos + + def get_ann_info(self, idx): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + data_info = self.data_infos[idx] + + bboxes = [] + keypointss = [] + labels = [] + bboxes_ignore = [] + labels_ignore = [] + for obj in data_info['objs']: + label = self.cat2label[obj['cat']] + bbox = obj['bbox'] + keypoints = obj['kps'] + ignore = obj['ignore'] + if ignore: + bboxes_ignore.append(bbox) + labels_ignore.append(label) + else: + bboxes.append(bbox) + labels.append(label) + keypointss.append(keypoints) + if not bboxes: + bboxes = np.zeros((0, 4)) + labels = np.zeros((0, )) + keypointss = np.zeros((0, self.NK, 3)) + else: + # bboxes = np.array(bboxes, ndmin=2) - 1 + bboxes = np.array(bboxes, ndmin=2) + labels = np.array(labels) + keypointss = np.array(keypointss, ndmin=3) + if not bboxes_ignore: + bboxes_ignore = np.zeros((0, 4)) + labels_ignore = np.zeros((0, )) + else: + bboxes_ignore = np.array(bboxes_ignore, ndmin=2) + labels_ignore = np.array(labels_ignore) + ann = dict( + bboxes=bboxes.astype(np.float32), + labels=labels.astype(np.int64), + keypointss=keypointss.astype(np.float32), + bboxes_ignore=bboxes_ignore.astype(np.float32), + labels_ignore=labels_ignore.astype(np.int64)) + return ann diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py new file mode 100755 index 00000000..bd5d5f5f --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/__init__.py @@ -0,0 +1,6 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models +""" +from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py new file mode 100755 index 00000000..5c3b190e --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones +""" +from .resnet import ResNetV1e + +__all__ = ['ResNetV1e'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py new file mode 100644 index 00000000..a5862a58 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/backbones/resnet.py @@ -0,0 +1,413 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/backbones/resnet.py +""" +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer, + constant_init, kaiming_init) +from mmcv.runner import load_checkpoint +from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +from mmdet.models.builder import BACKBONES +from mmdet.models.utils import ResLayer +from mmdet.utils import get_root_logger +from torch.nn.modules.batchnorm import _BatchNorm + + +class ResNet(nn.Module): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 0: (BasicBlock, (2, 2, 2, 2)), + 18: (BasicBlock, (2, 2, 2, 2)), + 19: (BasicBlock, (2, 4, 4, 1)), + 20: (BasicBlock, (2, 3, 2, 2)), + 22: (BasicBlock, (2, 4, 3, 1)), + 24: (BasicBlock, (2, 4, 4, 1)), + 26: (BasicBlock, (2, 4, 4, 2)), + 28: (BasicBlock, (2, 5, 4, 2)), + 29: (BasicBlock, (2, 6, 3, 2)), + 30: (BasicBlock, (2, 5, 5, 2)), + 32: (BasicBlock, (2, 6, 5, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 35: (BasicBlock, (3, 6, 4, 3)), + 38: (BasicBlock, (3, 8, 4, 3)), + 40: (BasicBlock, (3, 8, 5, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 56: (Bottleneck, (3, 8, 4, 3)), + 68: (Bottleneck, (3, 10, 6, 3)), + 74: (Bottleneck, (3, 12, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + block_cfg=None, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + no_pool33=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True): + super(ResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.no_pool33 = no_pool33 + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.zero_init_residual = zero_init_residual + if block_cfg is None: + self.block, stage_blocks = self.arch_settings[depth] + else: + self.block = BasicBlock if block_cfg[ + 'block'] == 'BasicBlock' else Bottleneck + stage_blocks = block_cfg['stage_blocks'] + assert len(stage_blocks) >= num_stages + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + if block_cfg is not None and 'stage_planes' in block_cfg: + stage_planes = block_cfg['stage_planes'] + else: + stage_planes = [base_channels * 2**i for i in range(num_stages)] + + # print('resnet cfg:', stage_blocks, stage_planes) + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = stage_planes[i] + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be: + + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->conv3->yyy->zzz1->zzz2 + + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. + stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + if self.no_pool33: + assert self.deep_stem + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) and hasattr( + m.conv2, 'conv_offset'): + constant_init(m.conv2.conv_offset, 0) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1e(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + + Compared with ResNetV1d, ResNetV1e change maxpooling from 3x3 to 2x2 pad=1 + """ + + def __init__(self, **kwargs): + super(ResNetV1e, self).__init__( + deep_stem=True, avg_down=True, no_pool33=True, **kwargs) diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py new file mode 100755 index 00000000..9ba63b68 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/dense_heads +""" +from .scrfd_head import SCRFDHead + +__all__ = ['SCRFDHead'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py new file mode 100755 index 00000000..77ec99cf --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/dense_heads/scrfd_head.py @@ -0,0 +1,1070 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/dense_heads/scrfd_head.py +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Scale, + bias_init_with_prob, constant_init, kaiming_init, + normal_init) +from mmcv.runner import force_fp32 +from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps, + build_assigner, build_sampler, distance2bbox, + images_to_levels, multi_apply, reduce_mean, unmap) +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.dense_heads.anchor_head import AnchorHead + +from ....mmdet_patch.core.bbox import distance2kps, kps2distance +from ....mmdet_patch.core.post_processing import multiclass_nms + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1) + x = F.linear(x, self.project.type_as(x)).reshape(-1, 4) + return x + + +@HEADS.register_module() +class SCRFDHead(AnchorHead): + """Generalized Focal Loss: Learning Qualified and Distributed Bounding + Boxes for Dense Object Detection. + + GFL head structure is similar with ATSS, however GFL uses + 1) joint representation for classification and localization quality, and + 2) flexible General distribution for bounding box locations, + which are supervised by + Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively + + https://arxiv.org/abs/2006.04388 + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of conv layers in cls and reg tower. + Default: 4. + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='GN', num_groups=32, requires_grad=True). + loss_qfl (dict): Config of Quality Focal Loss (QFL). + reg_max (int): Max value of integral set :math: `{0, ..., reg_max}` + in QFL setting. Default: 16. + Example: + >>> self = GFLHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_quality_score, bbox_pred = self.forward(feats) + >>> assert len(cls_quality_score) == len(self.scales) + """ + + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + feat_mults=None, + conv_cfg=None, + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), + loss_dfl=None, + reg_max=8, + cls_reg_share=False, + strides_share=True, + scale_mode=1, + dw_conv=False, + use_kps=False, + num_kps=5, + loss_kps=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.1), + **kwargs): + self.stacked_convs = stacked_convs + self.feat_mults = feat_mults + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.reg_max = reg_max + self.cls_reg_share = cls_reg_share + self.strides_share = strides_share + self.scale_mode = scale_mode + self.use_dfl = True + self.dw_conv = dw_conv + self.NK = num_kps + self.extra_flops = 0.0 + if loss_dfl is None or not loss_dfl: + self.use_dfl = False + self.use_scale = False + self.use_kps = use_kps + if self.scale_mode > 0 and (self.strides_share + or self.scale_mode == 2): + self.use_scale = True + super(SCRFDHead, self).__init__(num_classes, in_channels, **kwargs) + + self.sampling = False + if self.train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + # SSD sampling=False so use PseudoSampler + sampler_cfg = dict(type='PseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + + self.integral = Integral(self.reg_max) + if self.use_dfl: + self.loss_dfl = build_loss(loss_dfl) + self.loss_kps = build_loss(loss_kps) + self.loss_kps_std = 1.0 + self.train_step = 0 + self.pos_count = {} + self.gtgroup_count = {} + for stride in self.anchor_generator.strides: + self.pos_count[stride[0]] = 0 + + def _get_conv_module(self, in_channel, out_channel): + if not self.dw_conv: + conv = ConvModule( + in_channel, + out_channel, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg) + else: + conv = DepthwiseSeparableConvModule( + in_channel, + out_channel, + 3, + stride=1, + padding=1, + pw_norm_cfg=self.norm_cfg, + dw_norm_cfg=self.norm_cfg) + return conv + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + conv_strides = [0] if self.strides_share else \ + self.anchor_generator.strides + self.cls_stride_convs = nn.ModuleDict() + self.reg_stride_convs = nn.ModuleDict() + self.stride_cls = nn.ModuleDict() + self.stride_reg = nn.ModuleDict() + if self.use_kps: + self.stride_kps = nn.ModuleDict() + for stride_idx, conv_stride in enumerate(conv_strides): + key = str(conv_stride) + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + stacked_convs = self.stacked_convs[stride_idx] if \ + isinstance(self.stacked_convs, (list, tuple)) else \ + self.stacked_convs + feat_mult = self.feat_mults[stride_idx] if \ + self.feat_mults is not None else 1 + feat_ch = int(self.feat_channels * feat_mult) + last_feat_ch = 0 + for i in range(stacked_convs): + chn = self.in_channels if i == 0 else last_feat_ch + cls_convs.append(self._get_conv_module(chn, feat_ch)) + if not self.cls_reg_share: + reg_convs.append(self._get_conv_module(chn, feat_ch)) + last_feat_ch = feat_ch + self.cls_stride_convs[key] = cls_convs + self.reg_stride_convs[key] = reg_convs + self.stride_cls[key] = nn.Conv2d( + feat_ch, + self.cls_out_channels * self.num_anchors, + 3, + padding=1) + if not self.use_dfl: + self.stride_reg[key] = nn.Conv2d( + feat_ch, 4 * self.num_anchors, 3, padding=1) + else: + self.stride_reg[key] = nn.Conv2d( + feat_ch, + 4 * (self.reg_max + 1) * self.num_anchors, + 3, + padding=1) + if self.use_kps: + self.stride_kps[key] = nn.Conv2d( + feat_ch, self.NK * 2 * self.num_anchors, 3, padding=1) + if self.use_scale: + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.anchor_generator.strides]) + else: + self.scales = [None for _ in self.anchor_generator.strides] + + def init_weights(self): + """Initialize weights of the head.""" + for stride, cls_convs in self.cls_stride_convs.items(): + for m in cls_convs: + if not self.dw_conv: + try: + normal_init(m.conv, std=0.01) + except Exception: + pass + else: + normal_init(m.depthwise_conv.conv, std=0.01) + normal_init(m.pointwise_conv.conv, std=0.01) + for stride, reg_convs in self.reg_stride_convs.items(): + for m in reg_convs: + if not self.dw_conv: + normal_init(m.conv, std=0.01) + else: + normal_init(m.depthwise_conv.conv, std=0.01) + normal_init(m.pointwise_conv.conv, std=0.01) + bias_cls = -4.595 + for stride, conv in self.stride_cls.items(): + normal_init(conv, std=0.01, bias=bias_cls) + for stride, conv in self.stride_reg.items(): + normal_init(conv, std=0.01) + if self.use_kps: + for stride, conv in self.stride_kps.items(): + normal_init(conv, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification and quality (IoU) + joint scores for all scale levels, each is a 4D-tensor, + the channel number is num_classes. + bbox_preds (list[Tensor]): Box distribution logits for all + scale levels, each is a 4D-tensor, the channel number is + 4*(n+1), n is max value of integral set. + """ + return multi_apply(self.forward_single, feats, self.scales, + self.anchor_generator.strides) + + def forward_single(self, x, scale, stride): + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + + Returns: + tuple: + cls_score (Tensor): Cls and quality joint scores for a single + scale level the channel number is num_classes. + bbox_pred (Tensor): Box distribution logits for a single scale + level, the channel number is 4*(n+1), n is max value of + integral set. + """ + cls_feat = x + reg_feat = x + cls_convs = self.cls_stride_convs[ + '0'] if self.strides_share else self.cls_stride_convs[str(stride)] + for cls_conv in cls_convs: + cls_feat = cls_conv(cls_feat) + if not self.cls_reg_share: + reg_convs = self.reg_stride_convs[ + '0'] if self.strides_share else self.reg_stride_convs[str( + stride)] + for reg_conv in reg_convs: + reg_feat = reg_conv(reg_feat) + else: + reg_feat = cls_feat + cls_pred_module = self.stride_cls[ + '0'] if self.strides_share else self.stride_cls[str(stride)] + cls_score = cls_pred_module(cls_feat) + reg_pred_module = self.stride_reg[ + '0'] if self.strides_share else self.stride_reg[str(stride)] + _bbox_pred = reg_pred_module(reg_feat) + if self.use_scale: + bbox_pred = scale(_bbox_pred) + else: + bbox_pred = _bbox_pred + if self.use_kps: + kps_pred_module = self.stride_kps[ + '0'] if self.strides_share else self.stride_kps[str(stride)] + kps_pred = kps_pred_module(reg_feat) + else: + kps_pred = bbox_pred.new_zeros( + (bbox_pred.shape[0], self.NK * 2, bbox_pred.shape[2], + bbox_pred.shape[3])) + if torch.onnx.is_in_onnx_export(): + assert not self.use_dfl + print('in-onnx-export', cls_score.shape, bbox_pred.shape) + # Add output batch dim, based on pull request #1593 + batch_size = cls_score.shape[0] + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + kps_pred = kps_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, self.NK * 2) + return cls_score, bbox_pred, kps_pred + + def forward_train(self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_keypointss=None, + gt_bboxes_ignore=None, + proposal_cfg=None, + **kwargs): + """ + Args: + x (list[Tensor]): Features from FPN. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + proposal_cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used + + Returns: + tuple: + losses: (dict[str, Tensor]): A dictionary of loss components. + proposal_list (list[Tensor]): Proposals of each image. + """ + outs = self(x) + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, gt_keypointss, + img_metas) + losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + if proposal_cfg is None: + return losses + else: + proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) + return losses, proposal_list + + def get_anchors(self, featmap_sizes, img_metas, device='cuda'): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): Device for returned tensors + + Returns: + tuple: + anchor_list (list[Tensor]): Anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. + """ + num_imgs = len(img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = self.anchor_generator.valid_flags( + featmap_sizes, img_meta['pad_shape'], device) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def anchor_center(self, anchors): + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Anchor centers with shape (N, 2), "xy" format. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_single(self, anchors, cls_score, bbox_pred, kps_pred, labels, + label_weights, bbox_targets, kps_targets, kps_weights, + stride, num_total_samples): + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Cls and quality joint scores for each scale + level has shape (N, num_classes, H, W). + bbox_pred (Tensor): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor wight + shape (N, num_total_anchors, 4). + stride (tuple): Stride in this scale level. + num_total_samples (int): Number of positive samples that is + reduced over all GPUs. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + use_qscore = True + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + if not self.use_dfl: + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + else: + bbox_pred = bbox_pred.permute(0, 2, 3, 1) + bbox_pred = bbox_pred.reshape(-1, 4 * (self.reg_max + 1)) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + + if self.use_kps: + kps_pred = kps_pred.permute(0, 2, 3, 1).reshape(-1, self.NK * 2) + kps_targets = kps_targets.reshape((-1, self.NK * 2)) + kps_weights = kps_weights.reshape((-1, self.NK * 2)) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + score = label_weights.new_zeros(labels.shape) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0] + + weight_targets = cls_score.detach().sigmoid() + weight_targets = weight_targets.max(dim=1)[0][pos_inds] + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + + if self.use_dfl: + pos_bbox_pred_corners = self.integral(pos_bbox_pred) + pos_decode_bbox_pred = distance2bbox(pos_anchor_centers, + pos_bbox_pred_corners) + else: + pos_decode_bbox_pred = distance2bbox(pos_anchor_centers, + pos_bbox_pred) + if self.use_kps: + pos_kps_targets = kps_targets[pos_inds] + pos_kps_pred = kps_pred[pos_inds] + pos_kps_weights = kps_weights.max( + dim=1)[0][pos_inds] * weight_targets + pos_kps_weights = pos_kps_weights.reshape((-1, 1)) + pos_decode_kps_targets = kps2distance( + pos_anchor_centers, pos_kps_targets / stride[0]) + pos_decode_kps_pred = pos_kps_pred + if use_qscore: + score[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + else: + score[pos_inds] = 1.0 + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=weight_targets, + avg_factor=1.0) + + if self.use_kps: + loss_kps = self.loss_kps( + pos_decode_kps_pred * self.loss_kps_std, + pos_decode_kps_targets * self.loss_kps_std, + weight=pos_kps_weights, + avg_factor=1.0) + else: + loss_kps = kps_pred.sum() * 0 + + # dfl loss + if self.use_dfl: + pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1) + target_corners = bbox2distance(pos_anchor_centers, + pos_decode_bbox_targets, + self.reg_max).reshape(-1) + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets[:, None].expand(-1, 4).reshape(-1), + avg_factor=4.0) + else: + loss_dfl = bbox_pred.sum() * 0 + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + loss_kps = kps_pred.sum() * 0 + weight_targets = torch.tensor(0).cuda() + + loss_cls = self.loss_cls( + cls_score, (labels, score), + weight=label_weights, + avg_factor=num_total_samples) + return loss_cls, loss_bbox, loss_dfl, loss_kps, weight_targets.sum() + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + kps_preds, + gt_bboxes, + gt_labels, + gt_keypointss, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Cls and quality scores for each scale + level has shape (N, num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for each scale + level with shape (N, 4*(n+1), H, W), n is max value of integral + set. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.anchor_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + gt_keypointss, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + if cls_reg_targets is None: + return None + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, keypoints_targets_list, keypoints_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + + num_total_samples = reduce_mean( + torch.tensor(num_total_pos, dtype=torch.float, + device=device)).item() + num_total_samples = max(num_total_samples, 1.0) + + losses_cls, losses_bbox, losses_dfl, losses_kps,\ + avg_factor = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + kps_preds, + labels_list, + label_weights_list, + bbox_targets_list, + keypoints_targets_list, + keypoints_weights_list, + self.anchor_generator.strides, + num_total_samples=num_total_samples) + + avg_factor = sum(avg_factor) + avg_factor = reduce_mean(avg_factor).item() + losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox)) + losses = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + if self.use_kps: + losses_kps = list(map(lambda x: x / avg_factor, losses_kps)) + losses['loss_kps'] = losses_kps + if self.use_dfl: + losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl)) + losses['loss_dfl'] = losses_dfl + return losses + + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'kps_preds')) + def get_bboxes(self, + cls_scores, + bbox_preds, + kps_preds, + img_metas, + cfg=None, + rescale=False, + with_nms=True): + """Transform network output for a batch into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class labelof the + corresponding box. + + Example: + >>> import mmcv + >>> self = AnchorHead( + >>> num_classes=9, + >>> in_channels=1, + >>> anchor_generator=dict( + >>> type='AnchorGenerator', + >>> scales=[8], + >>> ratios=[0.5, 1.0, 2.0], + >>> strides=[4,])) + >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}] + >>> cfg = mmcv.Config(dict( + >>> score_thr=0.00, + >>> nms=dict(type='nms', iou_thr=1.0), + >>> max_per_img=10)) + >>> feat = torch.rand(1, 1, 3, 3) + >>> cls_score, bbox_pred = self.forward_single(feat) + >>> # note the input lists are over different levels, not images + >>> cls_scores, bbox_preds = [cls_score], [bbox_pred] + >>> result_list = self.get_bboxes(cls_scores, bbox_preds, + >>> img_metas, cfg) + >>> det_bboxes, det_labels = result_list[0] + >>> assert len(result_list) == 1 + >>> assert det_bboxes.shape[1] == 5 + >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img + """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device=device) + + result_list = [] + # bbox_preds and kps_preds are list of 3 tensor, each tensor is NCHW + # corresponding to a stage, C is 8 for bbox and 20 for kps + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + if self.use_kps: + kps_pred_list = [ + kps_preds[i][img_id].detach() for i in range(num_levels) + ] + else: + kps_pred_list = [None for i in range(num_levels)] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + if with_nms: + # some heads don't support with_nms argument + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + kps_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale) + else: + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + kps_pred_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale, + with_nms) + result_list.append(proposals) + return result_list + + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + kps_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False, + with_nms=True): + """Transform outputs for a single batch item into labeled boxes. + + Args: + cls_scores (list[Tensor]): Box scores for a single scale level + has shape (num_classes, H, W). + bbox_preds (list[Tensor]): Box distribution logits for a single + scale level with shape (4*(n+1), H, W), n is max value of + integral set. + mlvl_anchors (list[Tensor]): Box reference for a single scale level + with shape (num_total_anchors, 4). + img_shape (tuple[int]): Shape of the input image, + (height, width, 3). + scale_factor (ndarray): Scale factor of the image arange as + (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple(Tensor): + det_bboxes (Tensor): Bbox predictions in shape (N, 5), where + the first 4 columns are bounding box positions + (tl_x, tl_y, br_x, br_y) and the 5-th column is a score + between 0 and 1. + det_labels (Tensor): A (N,) tensor where each item is the + predicted class label of the corresponding box. + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_kps = [] + for cls_score, bbox_pred, kps_pred, stride, anchors in zip( + cls_scores, bbox_preds, kps_preds, + self.anchor_generator.strides, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + assert stride[0] == stride[1] + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 0) + if self.use_dfl: + bbox_pred = self.integral(bbox_pred) * stride[0] + else: + bbox_pred = bbox_pred.reshape((-1, 4)) * stride[0] + if kps_pred is not None: + kps_pred = kps_pred.permute(1, 2, 0) + if self.use_dfl: + kps_pred = self.integral(kps_pred) * stride[0] + else: + kps_pred = kps_pred.reshape((-1, self.NK * 2)) * stride[0] + + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = scores.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + if kps_pred is not None: + kps_pred = kps_pred[topk_inds, :] + + bboxes = distance2bbox( + self.anchor_center(anchors), bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + if kps_pred is not None: + kps = distance2kps(self.anchor_center(anchors), kps_pred) + mlvl_kps.append(kps) + + mlvl_bboxes = torch.cat(mlvl_bboxes) + if mlvl_kps is not None: + mlvl_kps = torch.cat(mlvl_kps) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + if mlvl_kps is not None: + scale_factor2 = torch.tensor( + [scale_factor[0], scale_factor[1]] * self.NK) + mlvl_kps /= scale_factor2.to(mlvl_kps.device) + + mlvl_scores = torch.cat(mlvl_scores) + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + + if with_nms: + det_bboxes, det_labels, det_kps = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + multi_kps=mlvl_kps) + if det_kps is not None: + return det_bboxes, det_labels, det_kps + else: + return det_bboxes, det_labels + else: + if mlvl_kps is not None: + return mlvl_bboxes, mlvl_scores, mlvl_kps + else: + return mlvl_bboxes, mlvl_scores + + def get_targets(self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_keypointss_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Get targets for GFL head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. + """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + if gt_keypointss_list is None: + gt_keypointss_list = [None for _ in range(num_imgs)] + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, all_keypoints_targets, all_keypoints_weights, + pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + gt_keypointss_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + keypoints_targets_list = images_to_levels(all_keypoints_targets, + num_level_anchors) + keypoints_weights_list = images_to_levels(all_keypoints_weights, + num_level_anchors) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, keypoints_targets_list, + keypoints_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, + flat_anchors, + valid_flags, + num_level_anchors, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + gt_keypointss, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors Tensor): Number of anchors of each scale level. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + anchors (Tensor): All anchors in the image with shape (N, 4). + labels (Tensor): Labels of all anchors in the image with shape + (N,). + label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4). + pos_inds (Tensor): Indices of postive anchor with shape + (num_pos,). + neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + if self.assigner.__class__.__name__ == 'ATSSAssigner': + assign_result = self.assigner.assign(anchors, + num_level_anchors_inside, + gt_bboxes, gt_bboxes_ignore, + gt_labels) + else: + assign_result = self.assigner.assign(anchors, gt_bboxes, + gt_bboxes_ignore, gt_labels) + + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + kps_targets = anchors.new_zeros(size=(anchors.shape[0], self.NK * 2)) + kps_weights = anchors.new_zeros(size=(anchors.shape[0], self.NK * 2)) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + if self.use_kps: + pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds + kps_targets[pos_inds, :] = gt_keypointss[ + pos_assigned_gt_inds, :, :2].reshape((-1, self.NK * 2)) + kps_weights[pos_inds, :] = torch.mean( + gt_keypointss[pos_assigned_gt_inds, :, 2], + dim=1, + keepdims=True) + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + if self.use_kps: + kps_targets = unmap(kps_targets, num_total_anchors, + inside_flags) + kps_weights = unmap(kps_weights, num_total_anchors, + inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + kps_targets, kps_weights, pos_inds, neg_inds) + + def get_num_level_anchors_inside(self, num_level_anchors, inside_flags): + split_inside_flags = torch.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside + + def aug_test(self, feats, img_metas, rescale=False): + """Test function with test time augmentation. + + Args: + feats (list[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains features for all images in the batch. + img_metas (list[list[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. each dict has image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[ndarray]: bbox results of each class + """ + return self.aug_test_bboxes(feats, img_metas, rescale=rescale) diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py new file mode 100755 index 00000000..7935606a --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/__init__.py @@ -0,0 +1,7 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors +""" +from .scrfd import SCRFD + +__all__ = ['SCRFD'] diff --git a/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py new file mode 100755 index 00000000..18b46be1 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/mmdet_patch/models/detectors/scrfd.py @@ -0,0 +1,150 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/models/detectors/scrfd.py +""" +import torch +from mmdet.models.builder import DETECTORS +from mmdet.models.detectors.single_stage import SingleStageDetector + +from ....mmdet_patch.core.bbox import bbox2result + + +@DETECTORS.register_module() +class SCRFD(SingleStageDetector): + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(SCRFD, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_keypointss=None, + gt_bboxes_ignore=None): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + super(SingleStageDetector, self).forward_train(img, img_metas) + x = self.extract_feat(img) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_keypointss, + gt_bboxes_ignore) + return losses + + def simple_test(self, + img, + img_metas, + rescale=False, + repeat_head=1, + output_kps_var=0, + output_results=1): + """Test function without test time augmentation. + + Args: + imgs (list[torch.Tensor]): List of multiple images + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + repeat_head (int): repeat inference times in head + output_kps_var (int): whether output kps var to calculate quality + output_results (int): 0: nothing 1: bbox 2: both bbox and kps + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. + The outer list corresponds to each image. The inner list + corresponds to each class. + """ + x = self.extract_feat(img) + assert repeat_head >= 1 + kps_out0 = [] + kps_out1 = [] + kps_out2 = [] + for i in range(repeat_head): + outs = self.bbox_head(x) + kps_out0 += [outs[2][0].detach().cpu().numpy()] + kps_out1 += [outs[2][1].detach().cpu().numpy()] + kps_out2 += [outs[2][2].detach().cpu().numpy()] + if output_kps_var: + var0 = np.var(np.vstack(kps_out0), axis=0).mean() + var1 = np.var(np.vstack(kps_out1), axis=0).mean() + var2 = np.var(np.vstack(kps_out2), axis=0).mean() + var = np.mean([var0, var1, var2]) + else: + var = None + + if output_results > 0: + if torch.onnx.is_in_onnx_export(): + print('single_stage.py in-onnx-export') + print(outs.__class__) + cls_score, bbox_pred, kps_pred = outs + for c in cls_score: + print(c.shape) + for c in bbox_pred: + print(c.shape) + if self.bbox_head.use_kps: + for c in kps_pred: + print(c.shape) + return (cls_score, bbox_pred, kps_pred) + else: + return (cls_score, bbox_pred) + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) + + # return kps if use_kps + if len(bbox_list[0]) == 2: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + elif len(bbox_list[0]) == 3: + if output_results == 2: + bbox_results = [ + bbox2result( + det_bboxes, + det_labels, + self.bbox_head.num_classes, + kps=det_kps, + num_kps=self.bbox_head.NK) + for det_bboxes, det_labels, det_kps in bbox_list + ] + elif output_results == 1: + bbox_results = [ + bbox2result(det_bboxes, det_labels, + self.bbox_head.num_classes) + for det_bboxes, det_labels, _ in bbox_list + ] + else: + bbox_results = None + if var is not None: + return bbox_results, var + else: + return bbox_results + + def feature_test(self, img): + x = self.extract_feat(img) + outs = self.bbox_head(x) + return outs diff --git a/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py new file mode 100644 index 00000000..59611604 --- /dev/null +++ b/modelscope/models/cv/face_detection/scrfd/scrfd_detect.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from copy import deepcopy +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ScrfdDetect'] + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.scrfd) +class ScrfdDetect(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the face detection model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + from mmcv import Config + from mmcv.parallel import MMDataParallel + from mmcv.runner import load_checkpoint + from mmdet.models import build_detector + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + cfg = Config.fromfile(osp.join(model_dir, 'mmcv_scrfd.py')) + ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + cfg.model.test_cfg.score_thr = kwargs.get('score_thr', 0.3) + detector = build_detector(cfg.model) + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + load_checkpoint(detector, ckpt_path, map_location=device) + detector = MMDataParallel(detector, device_ids=[0]) + detector.eval() + self.detector = detector + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector( + return_loss=False, + rescale=True, + img=[input['img'][0].unsqueeze(0)], + img_metas=[[dict(input['img_metas'][0].data)]], + output_results=2) + assert result is not None + result = result[0][0] + bboxes = result[:, :4].tolist() + kpss = result[:, 5:].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: kpss + } + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/models/cv/face_detection/ulfd_slim/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py new file mode 100644 index 00000000..af1e7b42 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .detection import UlfdFaceDetector diff --git a/modelscope/models/cv/face_detection/ulfd_slim/detection.py b/modelscope/models/cv/face_detection/ulfd_slim/detection.py new file mode 100755 index 00000000..c0e2da6e --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/detection.py @@ -0,0 +1,44 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .vision.ssd.fd_config import define_img_size +from .vision.ssd.mb_tiny_fd import (create_mb_tiny_fd, + create_mb_tiny_fd_predictor) + +define_img_size(640) + + +@MODELS.register_module(Tasks.face_detection, module_name=Models.ulfd) +class UlfdFaceDetector(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + self.net = create_mb_tiny_fd(2, is_test=True, device=device) + self.predictor = create_mb_tiny_fd_predictor( + self.net, candidate_size=1500, device=device) + self.net.load(model_path) + self.net = self.net.to(device) + + def forward(self, input): + img_raw = input['img'] + img = np.array(img_raw.cpu().detach()) + img = img[:, :, ::-1] + prob_th = 0.85 + keep_top_k = 750 + boxes, labels, probs = self.predictor.predict(img, keep_top_k, prob_th) + return boxes, probs diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py new file mode 100644 index 00000000..46d3b890 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/box_utils.py @@ -0,0 +1,124 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import math + +import torch + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + _, indexes = scores.sort(descending=True) + indexes = indexes[:candidate_size] + while len(indexes) > 0: + current = indexes[0] + picked.append(current.item()) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[1:] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + current_box.unsqueeze(0), + ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def nms(box_scores, + nms_method=None, + score_threshold=None, + iou_threshold=None, + sigma=0.5, + top_k=-1, + candidate_size=200): + return hard_nms( + box_scores, iou_threshold, top_k, candidate_size=candidate_size) + + +def generate_priors(feature_map_list, + shrinkage_list, + image_size, + min_boxes, + clamp=True) -> torch.Tensor: + priors = [] + for index in range(0, len(feature_map_list[0])): + scale_w = image_size[0] / shrinkage_list[0][index] + scale_h = image_size[1] / shrinkage_list[1][index] + for j in range(0, feature_map_list[1][index]): + for i in range(0, feature_map_list[0][index]): + x_center = (i + 0.5) / scale_w + y_center = (j + 0.5) / scale_h + + for min_box in min_boxes[index]: + w = min_box / image_size[0] + h = min_box / image_size[1] + priors.append([x_center, y_center, w, h]) + priors = torch.tensor(priors) + if clamp: + torch.clamp(priors, 0.0, 1.0, out=priors) + return priors + + +def convert_locations_to_boxes(locations, priors, center_variance, + size_variance): + # priors can have one dimension less. + if priors.dim() + 1 == locations.dim(): + priors = priors.unsqueeze(0) + a = locations[..., :2] * center_variance * priors[..., + 2:] + priors[..., :2] + b = torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:] + + return torch.cat([a, b], dim=locations.dim() - 1) + + +def center_form_to_corner_form(locations): + a = locations[..., :2] - locations[..., 2:] / 2 + b = locations[..., :2] + locations[..., 2:] / 2 + return torch.cat([a, b], locations.dim() - 1) + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. + """ + overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def area_of(left_top, right_bottom) -> torch.Tensor: + """Compute the areas of rectangles given two corners. + + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + + Returns: + area (N): return the area. + """ + hw = torch.clamp(right_bottom - left_top, min=0.0) + return hw[..., 0] * hw[..., 1] diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py new file mode 100644 index 00000000..8bbcef41 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/mb_tiny.py @@ -0,0 +1,49 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import torch.nn as nn +import torch.nn.functional as F + + +class Mb_Tiny(nn.Module): + + def __init__(self, num_classes=2): + super(Mb_Tiny, self).__init__() + self.base_channel = 8 * 2 + + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), nn.ReLU(inplace=True)) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + self.model = nn.Sequential( + conv_bn(3, self.base_channel, 2), # 160*120 + conv_dw(self.base_channel, self.base_channel * 2, 1), + conv_dw(self.base_channel * 2, self.base_channel * 2, 2), # 80*60 + conv_dw(self.base_channel * 2, self.base_channel * 2, 1), + conv_dw(self.base_channel * 2, self.base_channel * 4, 2), # 40*30 + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 4, 1), + conv_dw(self.base_channel * 4, self.base_channel * 8, 2), # 20*15 + conv_dw(self.base_channel * 8, self.base_channel * 8, 1), + conv_dw(self.base_channel * 8, self.base_channel * 8, 1), + conv_dw(self.base_channel * 8, self.base_channel * 16, 2), # 10*8 + conv_dw(self.base_channel * 16, self.base_channel * 16, 1)) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + x = self.model(x) + x = F.avg_pool2d(x, 7) + x = x.view(-1, 1024) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py new file mode 100644 index 00000000..9251d67f --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/data_preprocessing.py @@ -0,0 +1,18 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +from ..transforms import Compose, Resize, SubtractMeans, ToTensor + + +class PredictionTransform: + + def __init__(self, size, mean=0.0, std=1.0): + self.transform = Compose([ + Resize(size), + SubtractMeans(mean), lambda img, boxes=None, labels=None: + (img / std, boxes, labels), + ToTensor() + ]) + + def __call__(self, image): + image, _, _ = self.transform(image) + return image diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py new file mode 100644 index 00000000..495a2fcd --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/fd_config.py @@ -0,0 +1,49 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import numpy as np + +from ..box_utils import generate_priors + +image_mean_test = image_mean = np.array([127, 127, 127]) +image_std = 128.0 +iou_threshold = 0.3 +center_variance = 0.1 +size_variance = 0.2 + +min_boxes = [[10, 16, 24], [32, 48], [64, 96], [128, 192, 256]] +shrinkage_list = [] +image_size = [320, 240] # default input size 320*240 +feature_map_w_h_list = [[40, 20, 10, 5], [30, 15, 8, + 4]] # default feature map size +priors = [] + + +def define_img_size(size): + global image_size, feature_map_w_h_list, priors + img_size_dict = { + 128: [128, 96], + 160: [160, 120], + 320: [320, 240], + 480: [480, 360], + 640: [640, 480], + 1280: [1280, 960] + } + image_size = img_size_dict[size] + + feature_map_w_h_list_dict = { + 128: [[16, 8, 4, 2], [12, 6, 3, 2]], + 160: [[20, 10, 5, 3], [15, 8, 4, 2]], + 320: [[40, 20, 10, 5], [30, 15, 8, 4]], + 480: [[60, 30, 15, 8], [45, 23, 12, 6]], + 640: [[80, 40, 20, 10], [60, 30, 15, 8]], + 1280: [[160, 80, 40, 20], [120, 60, 30, 15]] + } + feature_map_w_h_list = feature_map_w_h_list_dict[size] + + for i in range(0, len(image_size)): + item_list = [] + for k in range(0, len(feature_map_w_h_list[i])): + item_list.append(image_size[i] / feature_map_w_h_list[i][k]) + shrinkage_list.append(item_list) + priors = generate_priors(feature_map_w_h_list, shrinkage_list, image_size, + min_boxes) diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py new file mode 100644 index 00000000..91ed268d --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/mb_tiny_fd.py @@ -0,0 +1,124 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +from torch.nn import Conv2d, ModuleList, ReLU, Sequential + +from ..mb_tiny import Mb_Tiny +from . import fd_config as config +from .predictor import Predictor +from .ssd import SSD + + +def SeperableConv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0): + """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d. + """ + return Sequential( + Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding), + ReLU(), + Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1), + ) + + +def create_mb_tiny_fd(num_classes, is_test=False, device='cuda'): + base_net = Mb_Tiny(2) + base_net_model = base_net.model # disable dropout layer + + source_layer_indexes = [8, 11, 13] + extras = ModuleList([ + Sequential( + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=base_net.base_channel * 4, + kernel_size=1), ReLU(), + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=base_net.base_channel * 16, + kernel_size=3, + stride=2, + padding=1), ReLU()) + ]) + + regression_headers = ModuleList([ + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=3 * 4, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 8, + out_channels=2 * 4, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 16, + out_channels=2 * 4, + kernel_size=3, + padding=1), + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=3 * 4, + kernel_size=3, + padding=1) + ]) + + classification_headers = ModuleList([ + SeperableConv2d( + in_channels=base_net.base_channel * 4, + out_channels=3 * num_classes, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 8, + out_channels=2 * num_classes, + kernel_size=3, + padding=1), + SeperableConv2d( + in_channels=base_net.base_channel * 16, + out_channels=2 * num_classes, + kernel_size=3, + padding=1), + Conv2d( + in_channels=base_net.base_channel * 16, + out_channels=3 * num_classes, + kernel_size=3, + padding=1) + ]) + + return SSD( + num_classes, + base_net_model, + source_layer_indexes, + extras, + classification_headers, + regression_headers, + is_test=is_test, + config=config, + device=device) + + +def create_mb_tiny_fd_predictor(net, + candidate_size=200, + nms_method=None, + sigma=0.5, + device=None): + predictor = Predictor( + net, + config.image_size, + config.image_mean_test, + config.image_std, + nms_method=nms_method, + iou_threshold=config.iou_threshold, + candidate_size=candidate_size, + sigma=sigma, + device=device) + return predictor diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py new file mode 100644 index 00000000..f71820a5 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/predictor.py @@ -0,0 +1,80 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import torch + +from .. import box_utils +from .data_preprocessing import PredictionTransform + + +class Predictor: + + def __init__(self, + net, + size, + mean=0.0, + std=1.0, + nms_method=None, + iou_threshold=0.3, + filter_threshold=0.85, + candidate_size=200, + sigma=0.5, + device=None): + self.net = net + self.transform = PredictionTransform(size, mean, std) + self.iou_threshold = iou_threshold + self.filter_threshold = filter_threshold + self.candidate_size = candidate_size + self.nms_method = nms_method + + self.sigma = sigma + if device: + self.device = device + else: + self.device = torch.device( + 'cuda:0' if torch.cuda.is_available() else 'cpu') + + self.net.to(self.device) + self.net.eval() + + def predict(self, image, top_k=-1, prob_threshold=None): + height, width, _ = image.shape + image = self.transform(image) + images = image.unsqueeze(0) + images = images.to(self.device) + with torch.no_grad(): + for i in range(1): + scores, boxes = self.net.forward(images) + boxes = boxes[0] + scores = scores[0] + if not prob_threshold: + prob_threshold = self.filter_threshold + # this version of nms is slower on GPU, so we move data to CPU. + picked_box_probs = [] + picked_labels = [] + for class_index in range(1, scores.size(1)): + probs = scores[:, class_index] + mask = probs > prob_threshold + probs = probs[mask] + if probs.size(0) == 0: + continue + subset_boxes = boxes[mask, :] + box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1) + box_probs = box_utils.nms( + box_probs, + self.nms_method, + score_threshold=prob_threshold, + iou_threshold=self.iou_threshold, + sigma=self.sigma, + top_k=top_k, + candidate_size=self.candidate_size) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * box_probs.size(0)) + if not picked_box_probs: + return torch.tensor([]), torch.tensor([]), torch.tensor([]) + picked_box_probs = torch.cat(picked_box_probs) + picked_box_probs[:, 0] *= width + picked_box_probs[:, 1] *= height + picked_box_probs[:, 2] *= width + picked_box_probs[:, 3] *= height + return picked_box_probs[:, :4], torch.tensor( + picked_labels), picked_box_probs[:, 4] diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py new file mode 100644 index 00000000..08ff93a4 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/ssd/ssd.py @@ -0,0 +1,129 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +from collections import namedtuple +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .. import box_utils + +GraphPath = namedtuple('GraphPath', ['s0', 'name', 's1']) + + +class SSD(nn.Module): + + def __init__(self, + num_classes: int, + base_net: nn.ModuleList, + source_layer_indexes: List[int], + extras: nn.ModuleList, + classification_headers: nn.ModuleList, + regression_headers: nn.ModuleList, + is_test=False, + config=None, + device=None): + """Compose a SSD model using the given components. + """ + super(SSD, self).__init__() + + self.num_classes = num_classes + self.base_net = base_net + self.source_layer_indexes = source_layer_indexes + self.extras = extras + self.classification_headers = classification_headers + self.regression_headers = regression_headers + self.is_test = is_test + self.config = config + + # register layers in source_layer_indexes by adding them to a module list + self.source_layer_add_ons = nn.ModuleList([ + t[1] for t in source_layer_indexes + if isinstance(t, tuple) and not isinstance(t, GraphPath) + ]) + if device: + self.device = device + else: + self.device = torch.device( + 'cuda:0' if torch.cuda.is_available() else 'cpu') + if is_test: + self.config = config + self.priors = config.priors.to(self.device) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + confidences = [] + locations = [] + start_layer_index = 0 + header_index = 0 + end_layer_index = 0 + for end_layer_index in self.source_layer_indexes: + if isinstance(end_layer_index, GraphPath): + path = end_layer_index + end_layer_index = end_layer_index.s0 + added_layer = None + elif isinstance(end_layer_index, tuple): + added_layer = end_layer_index[1] + end_layer_index = end_layer_index[0] + path = None + else: + added_layer = None + path = None + for layer in self.base_net[start_layer_index:end_layer_index]: + x = layer(x) + if added_layer: + y = added_layer(x) + else: + y = x + if path: + sub = getattr(self.base_net[end_layer_index], path.name) + for layer in sub[:path.s1]: + x = layer(x) + y = x + for layer in sub[path.s1:]: + x = layer(x) + end_layer_index += 1 + start_layer_index = end_layer_index + confidence, location = self.compute_header(header_index, y) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + for layer in self.base_net[end_layer_index:]: + x = layer(x) + + for layer in self.extras: + x = layer(x) + confidence, location = self.compute_header(header_index, x) + header_index += 1 + confidences.append(confidence) + locations.append(location) + + confidences = torch.cat(confidences, 1) + locations = torch.cat(locations, 1) + + if self.is_test: + confidences = F.softmax(confidences, dim=2) + boxes = box_utils.convert_locations_to_boxes( + locations, self.priors, self.config.center_variance, + self.config.size_variance) + boxes = box_utils.center_form_to_corner_form(boxes) + return confidences, boxes + else: + return confidences, locations + + def compute_header(self, i, x): + confidence = self.classification_headers[i](x) + confidence = confidence.permute(0, 2, 3, 1).contiguous() + confidence = confidence.view(confidence.size(0), -1, self.num_classes) + + location = self.regression_headers[i](x) + location = location.permute(0, 2, 3, 1).contiguous() + location = location.view(location.size(0), -1, 4) + + return confidence, location + + def load(self, model): + self.load_state_dict( + torch.load(model, map_location=lambda storage, loc: storage)) diff --git a/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py new file mode 100644 index 00000000..7c5331f1 --- /dev/null +++ b/modelscope/models/cv/face_detection/ulfd_slim/vision/transforms.py @@ -0,0 +1,56 @@ +# The implementation is based on ULFD, available at +# https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB +import types + +import cv2 +import numpy as np +import torch +from numpy import random + + +class Compose(object): + """Composes several augmentations together. + Args: + transforms (List[Transform]): list of transforms to compose. + Example: + >>> augmentations.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, boxes=None, labels=None): + for t in self.transforms: + img, boxes, labels = t(img, boxes, labels) + return img, boxes, labels + + +class SubtractMeans(object): + + def __init__(self, mean): + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + image = image.astype(np.float32) + image -= self.mean + return image.astype(np.float32), boxes, labels + + +class Resize(object): + + def __init__(self, size=(300, 300)): + self.size = size + + def __call__(self, image, boxes=None, labels=None): + image = cv2.resize(image, (self.size[0], self.size[1])) + return image, boxes, labels + + +class ToTensor(object): + + def __call__(self, cvimage, boxes=None, labels=None): + return torch.from_numpy(cvimage.astype(np.float32)).permute( + 2, 0, 1), boxes, labels diff --git a/modelscope/models/cv/face_emotion/__init__.py b/modelscope/models/cv/face_emotion/__init__.py new file mode 100644 index 00000000..2a13ea42 --- /dev/null +++ b/modelscope/models/cv/face_emotion/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .emotion_model import EfficientNetForFaceEmotion + +else: + _import_structure = {'emotion_model': ['EfficientNetForFaceEmotion']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_emotion/efficient/__init__.py b/modelscope/models/cv/face_emotion/efficient/__init__.py new file mode 100644 index 00000000..e8fc91a4 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/__init__.py @@ -0,0 +1,6 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +from .model import VALID_MODELS, EfficientNet +from .utils import (BlockArgs, BlockDecoder, GlobalParams, efficientnet, + get_model_params) diff --git a/modelscope/models/cv/face_emotion/efficient/model.py b/modelscope/models/cv/face_emotion/efficient/model.py new file mode 100644 index 00000000..db303016 --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/model.py @@ -0,0 +1,380 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import (MemoryEfficientSwish, Swish, calculate_output_image_size, + drop_connect, efficientnet_params, get_model_params, + get_same_padding_conv2d, load_pretrained_weights, + round_filters, round_repeats) + +VALID_MODELS = ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', + 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', + 'efficientnet-b6', 'efficientnet-b7', 'efficientnet-b8', + 'efficientnet-l2') + + +class MBConvBlock(nn.Module): + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio + is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip + + inp = self._block_args.input_filters + oup = self._block_args.input_filters * self._block_args.expand_ratio + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d( + in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, + kernel_size=k, + stride=s, + bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max( + 1, + int(self._block_args.input_filters + * self._block_args.se_ratio)) + self._se_reduce = Conv2d( + in_channels=oup, + out_channels=num_squeezed_channels, + kernel_size=1) + self._se_expand = Conv2d( + in_channels=num_squeezed_channels, + out_channels=oup, + kernel_size=1) + + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d( + in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d( + num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + Returns: + Output of this block after processing. + """ + + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + x = self._project_conv(x) + x = self._bn2(x) + + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect( + x, p=drop_connect_rate, training=self.training) + x = x + inputs + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + in_channels = 3 + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, + self._global_params), + output_filters=round_filters(block_args.output_filters, + self._global_params), + num_repeat=round_repeats(block_args.num_repeat, + self._global_params)) + + self._blocks.append( + MBConvBlock( + block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, + block_args.stride) + if block_args.num_repeat > 1: + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock( + block_args, self._global_params, + image_size=image_size)) + + in_channels = block_args.output_filters + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d( + in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + Args: + inputs (tensor): Input tensor. + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len( + self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + prev_x = x + + x = self._swish(self._bn1(self._conv_head(x))) + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + Args: + inputs (tensor): Input tensor. + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + x = self._swish(self._bn0(self._conv_stem(inputs))) + + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + Args: + inputs (tensor): Input tensor. + Returns: + Output of this model after processing. + """ + x = self.extract_features(inputs) + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + An efficientnet model. + """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, + override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained(cls, + model_name, + weights_path=None, + advprop=False, + in_channels=3, + num_classes=1000, + **override_params): + """Create an efficientnet model according to name. + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name( + model_name, num_classes=num_classes, **override_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + Args: + model_name (str): Name for efficientnet. + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + Args: + model_name (str): Name for efficientnet. + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError('model_name should be one of: ' + + ', '.join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d( + image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/modelscope/models/cv/face_emotion/efficient/utils.py b/modelscope/models/cv/face_emotion/efficient/utils.py new file mode 100644 index 00000000..6cae70fc --- /dev/null +++ b/modelscope/models/cv/face_emotion/efficient/utils.py @@ -0,0 +1,559 @@ +# The implementation here is modified based on EfficientNet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/lukemelas/EfficientNet-PyTorch + +import collections +import math +import re +from functools import partial + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top' +]) + +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', 'input_filters', + 'output_filters', 'se_ratio', 'id_skip' +]) + +GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields) + +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + + class Swish(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(x) + + +class SwishImplementation(torch.autograd.Function): + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], + dtype=inputs.dtype, + device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + Args: + x (int, tuple or list): Data size. + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size( + input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. + """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, + dilation, groups, bias) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + a1 = (oh - 1) * self.stride[0] + pad_h = max(a1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + a2 = (ow - 1) * self.stride[1] + pad_w = max(a2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + image_size=None, + **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, + **kwargs) + self.stride = self.stride if len( + self.stride) == 2 else [self.stride[0]] * 2 + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + b1 = (oh - 1) * self.stride[0] + pad_h = max(b1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + b2 = (ow - 1) * self.stride[1] + pad_w = max(b2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: + image_size (int or tuple): Size of the image. + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, + kernel_size, + stride, + padding=0, + dilation=1, + return_indices=False, + ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, + return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + c1 = (oh - 1) * self.stride[0] + pad_h = max(c1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + c2 = (ow - 1) * self.stride[1] + pad_w = max(c2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, + int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance( + self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance( + self.dilation, int) else self.dilation + + assert image_size is not None + ih, iw = (image_size, + image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + d1 = (oh - 1) * self.stride[0] + pad_h = max(d1 + (kh - 1) * self.dilation[0] + 1 - ih, 0) + d2 = (ow - 1) * self.stride[1] + pad_w = max(d2 + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d( + (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: + BlockArgs: The namedtuple defined at the top of this file. + """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) + or (len(options['s']) == 2 + and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + Args: + block (namedtuple): A BlockArgs type argument. + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + Args: + model_name (str): Model name to be queried. + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + 'efficientnet-b0': (1.0, 1.0, 112, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, + depth_coefficient=None, + image_size=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + num_classes=1000, + include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + Meaning as the name suggests. + Returns: + blocks_args, global_params. + """ + + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet( + width_coefficient=w, + depth_coefficient=d, + dropout_rate=p, + image_size=s) + else: + raise NotImplementedError( + 'model name is not pre-defined: {}'.format(model_name)) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +def load_pretrained_weights(model, + model_name, + weights_path=None, + load_fc=True, + advprop=False, + verbose=True): + """Loads pretrained weights from weights path or download using url. + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set([ + '_fc.weight', '_fc.bias' + ]), 'Missing keys when loading pretrained weights: {}'.format( + ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format( + ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/modelscope/models/cv/face_emotion/emotion_infer.py b/modelscope/models/cv/face_emotion/emotion_infer.py new file mode 100644 index 00000000..618822ff --- /dev/null +++ b/modelscope/models/cv/face_emotion/emotion_infer.py @@ -0,0 +1,67 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import torch +from PIL import Image +from torch import nn +from torchvision import transforms + +from modelscope.utils.logger import get_logger +from .face_alignment.face_align import face_detection_PIL_v2 + +logger = get_logger() + + +def transform_PIL(img_pil): + val_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + return val_transforms(img_pil) + + +index2AU = [1, 2, 4, 6, 7, 10, 12, 15, 23, 24, 25, 26] +emotion_list = [ + 'Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise' +] + + +def inference(image, model, face_model, score_thre=0.5, GPU=0): + image = image.cpu().numpy() + image = Image.fromarray(image) + face, bbox = face_detection_PIL_v2(image, face_model) + if bbox is None: + logger.warn('no face detected!') + result = {'emotion_result': None, 'box': None} + return result + + face = transform_PIL(face) + face = face.unsqueeze(0) + if torch.cuda.is_available(): + face = face.cuda(GPU) + logits_AU, logits_emotion = model(face) + logits_AU = torch.sigmoid(logits_AU) + logits_emotion = nn.functional.softmax(logits_emotion, 1) + + _, index_list = logits_emotion.max(1) + emotion_index = index_list[0].data.item() + prob = logits_emotion[0][emotion_index] + if prob > score_thre and emotion_index != 3: + cur_emotion = emotion_list[emotion_index] + else: + cur_emotion = 'Neutral' + + logits_AU = logits_AU[0] + au_ouput = torch.zeros_like(logits_AU) + au_ouput[logits_AU >= score_thre] = 1 + au_ouput[logits_AU < score_thre] = 0 + + au_ouput = au_ouput.int() + + cur_au_list = [] + for idx in range(au_ouput.shape[0]): + if au_ouput[idx] == 1: + au = index2AU[idx] + cur_au_list.append(au) + cur_au_list.sort() + result = (cur_emotion, bbox) + return result diff --git a/modelscope/models/cv/face_emotion/emotion_model.py b/modelscope/models/cv/face_emotion/emotion_model.py new file mode 100644 index 00000000..f8df9c37 --- /dev/null +++ b/modelscope/models/cv/face_emotion/emotion_model.py @@ -0,0 +1,96 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import sys + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.face_emotion.efficient import EfficientNet +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@MODELS.register_module(Tasks.face_emotion, module_name=Models.face_emotion) +class EfficientNetForFaceEmotion(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = FaceEmotionModel( + name='efficientnet-b0', num_embed=512, num_au=12, num_emotion=7) + + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=self.device) + + state_dict = pretrained_params['model'] + new_state = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] + new_state[k] = v + + self.model.load_state_dict(new_state) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + logits_au, logits_emotion = self.model(x) + return logits_au, logits_emotion + + +class FaceEmotionModel(nn.Module): + + def __init__(self, + name='efficientnet-b0', + num_embed=512, + num_au=12, + num_emotion=7): + super(FaceEmotionModel, self).__init__() + self.backbone = EfficientNet.from_pretrained( + name, weights_path=None, advprop=True) + self.average_pool = nn.AdaptiveAvgPool2d(1) + self.embed = nn.Linear(self.backbone._fc.weight.data.shape[1], + num_embed) + self.features = nn.BatchNorm1d(num_embed) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + self.fc_au = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_au), + ) + self.fc_emotion = nn.Sequential( + nn.Dropout(0.6), + nn.Linear(num_embed, num_emotion), + ) + + def feat_single_img(self, x): + x = self.backbone.extract_features(x) + x = self.average_pool(x) + x = x.flatten(1) + x = self.embed(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.feat_single_img(x) + logits_au = self.fc_au(x) + att_au = torch.sigmoid(logits_au).unsqueeze(-1) + x = x.unsqueeze(1) + emotion_vec_list = torch.matmul(att_au, x) + emotion_vec = emotion_vec_list.sum(1) + logits_emotion = self.fc_emotion(emotion_vec) + return logits_au, logits_emotion diff --git a/modelscope/models/cv/face_emotion/face_alignment/__init__.py b/modelscope/models/cv/face_emotion/face_alignment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_emotion/face_alignment/face.py b/modelscope/models/cv/face_emotion/face_alignment/face.py new file mode 100644 index 00000000..a362bddc --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face.py @@ -0,0 +1,79 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os + +import cv2 +import numpy as np +import tensorflow as tf + + +def init(mod): + PATH_TO_CKPT = mod + net = tf.Graph() + with net.as_default(): + od_graph_def = tf.GraphDef() + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + sess = tf.Session(graph=net, config=config) + return sess, net + + +def filter_bboxes_confs(shape, + imgsBboxes, + imgsConfs, + single=False, + thresh=0.5): + [w, h] = shape + if single: + bboxes, confs = [], [] + for y in range(len(imgsBboxes)): + if imgsConfs[y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int(w * x2), int( + h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[y]) + return bboxes, confs + else: + retImgsBboxes, retImgsConfs = [], [] + for x in range(len(imgsBboxes)): + bboxes, confs = [], [] + for y in range(len(imgsBboxes[x])): + if imgsConfs[x][y] >= thresh: + [x1, y1, x2, y2] = list(imgsBboxes[x][y]) + x1, y1, x2, y2 = int(w * x1), int(h * y1), int( + w * x2), int(h * y2) + bboxes.append([y1, x1, y2, x2]) + confs.append(imgsConfs[x][y]) + retImgsBboxes.append(bboxes) + retImgsConfs.append(confs) + return retImgsBboxes, retImgsConfs + + +def detect(im, sess, net): + image_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + image_np_expanded = np.expand_dims(image_np, axis=0) + image_tensor = net.get_tensor_by_name('image_tensor:0') + bboxes = net.get_tensor_by_name('detection_boxes:0') + dConfs = net.get_tensor_by_name('detection_scores:0') + classes = net.get_tensor_by_name('detection_classes:0') + num_detections = net.get_tensor_by_name('num_detections:0') + (bboxes, dConfs, classes, + num_detections) = sess.run([bboxes, dConfs, classes, num_detections], + feed_dict={image_tensor: image_np_expanded}) + w, h, _ = im.shape + bboxes, confs = filter_bboxes_confs([w, h], bboxes[0], dConfs[0], True) + return bboxes, confs + + +class FaceDetector: + + def __init__(self, mod): + self.sess, self.net = init(mod) + + def do_detect(self, im): + bboxes, confs = detect(im, self.sess, self.net) + return bboxes, confs diff --git a/modelscope/models/cv/face_emotion/face_alignment/face_align.py b/modelscope/models/cv/face_emotion/face_alignment/face_align.py new file mode 100644 index 00000000..71282b12 --- /dev/null +++ b/modelscope/models/cv/face_emotion/face_alignment/face_align.py @@ -0,0 +1,59 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import sys + +import cv2 +import numpy as np +from PIL import Image, ImageFile + +from .face import FaceDetector + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def adjust_bx_v2(box, w, h): + x1, y1, x2, y2 = box[0], box[1], box[2], box[3] + box_w = x2 - x1 + box_h = y2 - y1 + delta = abs(box_w - box_h) + if box_w > box_h: + if y1 >= delta: + y1 = y1 - delta + else: + delta_y1 = y1 + y1 = 0 + delta_y2 = delta - delta_y1 + y2 = y2 + delta_y2 if y2 < h - delta_y2 else h - 1 + else: + if x1 >= delta / 2 and x2 <= w - delta / 2: + x1 = x1 - delta / 2 + x2 = x2 + delta / 2 + elif x1 < delta / 2 and x2 <= w - delta / 2: + delta_x1 = x1 + x1 = 0 + delta_x2 = delta - delta_x1 + x2 = x2 + delta_x2 if x2 < w - delta_x2 else w - 1 + elif x1 >= delta / 2 and x2 > w - delta / 2: + delta_x2 = w - x2 + x2 = w - 1 + delta_x1 = delta - x1 + x1 = x1 - delta_x1 if x1 >= delta_x1 else 0 + + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + return [x1, y1, x2, y2] + + +def face_detection_PIL_v2(image, face_model): + crop_size = 112 + face_detector = FaceDetector(face_model) + img = np.array(image) + h, w = img.shape[0:2] + bxs, conf = face_detector.do_detect(img) + bx = bxs[0] + bx = adjust_bx_v2(bx, w, h) + x1, y1, x2, y2 = bx + image = img[y1:y2, x1:x2, :] + img = Image.fromarray(image) + img = img.resize((crop_size, crop_size)) + bx = tuple(bx) + return img, bx diff --git a/modelscope/models/cv/face_generation/__init__.py b/modelscope/models/cv/face_generation/__init__.py new file mode 100644 index 00000000..35a63f1c --- /dev/null +++ b/modelscope/models/cv/face_generation/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .stylegan2 import Generator + +else: + _import_structure = { + 'stylegan2': ['Generator'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_generation/op/__init__.py b/modelscope/models/cv/face_generation/op/__init__.py new file mode 100755 index 00000000..d0918d92 --- /dev/null +++ b/modelscope/models/cv/face_generation/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/modelscope/models/cv/face_generation/op/conv2d_gradfix.py b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py new file mode 100755 index 00000000..a3aba91f --- /dev/null +++ b/modelscope/models/cv/face_generation/op/conv2d_gradfix.py @@ -0,0 +1,228 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/conv2d_gradfix.py +import contextlib +import warnings + +import torch +from torch import autograd +from torch.nn import functional as F + +enabled = True +weight_gradients_disabled = False + + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + + +def conv2d(input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1): + if could_use_op(input): + return conv2d_gradfix( + transpose=False, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=0, + dilation=dilation, + groups=groups, + ).apply(input, weight, bias) + + return F.conv2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + +def conv_transpose2d( + input, + weight, + bias=None, + stride=1, + padding=0, + output_padding=0, + groups=1, + dilation=1, +): + if could_use_op(input): + return conv2d_gradfix( + transpose=True, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ).apply(input, weight, bias) + + return F.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + ) + + +def could_use_op(input): + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + + if input.device.type != 'cuda': + return False + + warnings.warn( + f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().' + ) + + return False + + +def ensure_tuple(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim + + return xs + + +conv2d_gradfix_cache = dict() + + +def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, + dilation, groups): + ndim = 2 + weight_shape = tuple(weight_shape) + stride = ensure_tuple(stride, ndim) + padding = ensure_tuple(padding, ndim) + output_padding = ensure_tuple(output_padding, ndim) + dilation = ensure_tuple(dilation, ndim) + + key = (transpose, weight_shape, stride, padding, output_padding, dilation, + groups) + if key in conv2d_gradfix_cache: + return conv2d_gradfix_cache[key] + + common_kwargs = dict( + stride=stride, padding=padding, dilation=dilation, groups=groups) + + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + + a = input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] + return [ + a - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + class Conv2d(autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias): + if not transpose: + out = F.conv2d( + input=input, weight=weight, bias=bias, **common_kwargs) + + else: + out = F.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + output_padding=output_padding, + **common_kwargs, + ) + + ctx.save_for_backward(input, weight) + + return out + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input, grad_weight, grad_bias = None, None, None + + if ctx.needs_input_grad[0]: + p = calc_output_padding( + input_shape=input.shape, output_shape=grad_output.shape) + grad_input = conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs, + ).apply(grad_output, weight, None) + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum((0, 2, 3)) + + return grad_input, grad_weight, grad_bias + + class Conv2dGradWeight(autograd.Function): + + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation( + 'aten::cudnn_convolution_backward_weight' if not transpose else + 'aten::cudnn_convolution_transpose_backward_weight') + flags = [ + torch.backends.cudnn.benchmark, + torch.backends.cudnn.deterministic, + torch.backends.cudnn.allow_tf32, + ] + grad_weight = op( + weight_shape, + grad_output, + input, + padding, + stride, + dilation, + groups, + *flags, + ) + ctx.save_for_backward(grad_output, input) + + return grad_weight + + @staticmethod + def backward(ctx, grad_grad_weight): + grad_output, input = ctx.saved_tensors + grad_grad_output, grad_grad_input = None, None + + if ctx.needs_input_grad[0]: + grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) + + if ctx.needs_input_grad[1]: + p = calc_output_padding( + input_shape=input.shape, output_shape=grad_output.shape) + grad_grad_input = conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs, + ).apply(grad_output, grad_grad_weight, None) + + return grad_grad_output, grad_grad_input + + conv2d_gradfix_cache[key] = Conv2d + + return Conv2d diff --git a/modelscope/models/cv/face_generation/op/fused_act.py b/modelscope/models/cv/face_generation/op/fused_act.py new file mode 100755 index 00000000..a24f5972 --- /dev/null +++ b/modelscope/models/cv/face_generation/op/fused_act.py @@ -0,0 +1,115 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# t https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F + +def_lib = False + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, bias, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act(grad_output.contiguous(), empty, out, + 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + if bias: + grad_bias = grad_input.sum(dim).detach() + + else: + grad_bias = empty + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input.contiguous(), + gradgrad_bias, + out, + 3, + 1, + ctx.negative_slope, + ctx.scale, + ) + + return gradgrad_out, None, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + + ctx.bias = bias is not None + + if bias is None: + bias = empty + + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, + scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale) + + if not ctx.bias: + grad_bias = None + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5): + super().__init__() + + if bias: + self.bias = nn.Parameter(torch.zeros(channel)) + + else: + self.bias = None + + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, + self.scale) + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): + if not def_lib: + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return (F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), + negative_slope=0.2) * scale) + + else: + return F.leaky_relu(input, negative_slope=0.2) * scale diff --git a/modelscope/models/cv/face_generation/op/upfirdn2d.py b/modelscope/models/cv/face_generation/op/upfirdn2d.py new file mode 100755 index 00000000..95c987af --- /dev/null +++ b/modelscope/models/cv/face_generation/op/upfirdn2d.py @@ -0,0 +1,199 @@ +# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License +# at https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py +import os +from collections import abc + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +def_lib = False + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, + in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], + in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], + ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], + ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, + pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if not isinstance(up, abc.Iterable): + up = (up, up) + + if not isinstance(down, abc.Iterable): + down = (down, down) + + if len(pad) == 2: + pad = (pad[0], pad[1], pad[0], pad[1]) + + if not def_lib: + out = upfirdn2d_native(input, kernel, *up, *down, *pad) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, + [0, 0, + max(pad_x0, 0), + max(pad_x1, 0), + max(pad_y0, 0), + max(pad_y1, 0)]) + out = out[:, + max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0)] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + + return out.view(-1, channel, out_h, out_w) diff --git a/modelscope/models/cv/face_generation/stylegan2.py b/modelscope/models/cv/face_generation/stylegan2.py new file mode 100755 index 00000000..4c650f54 --- /dev/null +++ b/modelscope/models/cv/face_generation/stylegan2.py @@ -0,0 +1,733 @@ +# The implementation is adopted from stylegan2-pytorch, +# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +import functools +import math +import operator +import random + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F + +from .op import FusedLeakyReLU, conv2d_gradfix, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt( + torch.mean(input**2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor**2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + bias=True): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = conv2d_gradfix.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + + def __init__(self, + in_dim, + out_dim, + bias=True, + bias_init=0, + lr_mul=1, + activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ModulatedConv2d(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + fused=True, + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur( + blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size**2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + self.fused = fused + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})') + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + if not self.fused: + weight = self.scale * self.weight.squeeze(0) + style = self.modulation(style) + + if self.demodulate: + w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, + 1) + dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() + + input = input * style.reshape(batch, in_channel, 1, 1) + + if self.upsample: + weight = weight.transpose(0, 1) + out = conv2d_gradfix.conv_transpose2d( + input, weight, padding=0, stride=2) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) + + else: + out = conv2d_gradfix.conv2d( + input, weight, padding=self.padding) + + if self.demodulate: + out = out * dcoefs.view(batch, -1, 1, 1) + + return out + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view(batch * self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view(batch, self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(batch * in_channel, + self.out_channel, + self.kernel_size, + self.kernel_size) + out = conv2d_gradfix.conv_transpose2d( + input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = conv2d_gradfix.conv2d( + input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = conv2d_gradfix.conv2d( + input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + + def __init__(self, + in_channel, + style_dim, + upsample=True, + blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d( + in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, + style_dim, + lr_mul=lr_mlp, + activation='fused_lrelu')) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], + self.channels[4], + 3, + style_dim, + blur_kernel=blur_kernel) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2**res, 2**res] + self.noises.register_buffer(f'noise_{layer_idx}', + torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + )) + + self.convs.append( + StyledConv( + out_channel, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel)) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') + for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + + truncation * (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat( + 1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + )) + + if activate: + layers.append(FusedLeakyReLU(out_channel, bias=bias)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, + out_channel, + 1, + downsample=True, + activate=False, + bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2**(i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, + channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/modelscope/models/cv/face_human_hand_detection/__init__.py b/modelscope/models/cv/face_human_hand_detection/__init__.py new file mode 100644 index 00000000..33a5fd2f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .det_infer import NanoDetForFaceHumanHandDetection + +else: + _import_structure = {'det_infer': ['NanoDetForFaceHumanHandDetection']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/face_human_hand_detection/det_infer.py b/modelscope/models/cv/face_human_hand_detection/det_infer.py new file mode 100644 index 00000000..6822bd9f --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/det_infer.py @@ -0,0 +1,138 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .one_stage_detector import OneStageDetector + +logger = get_logger() + + +def load_model_weight(model_dir, device): + checkpoint = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + state_dict = checkpoint['state_dict'].copy() + for k in checkpoint['state_dict']: + if k.startswith('avg_model.'): + v = state_dict.pop(k) + state_dict[k[4:]] = v + + return state_dict + + +@MODELS.register_module( + Tasks.face_human_hand_detection, + module_name=Models.face_human_hand_detection) +class NanoDetForFaceHumanHandDetection(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = OneStageDetector() + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU ') + else: + self.device = 'cpu' + logger.info('Use CPU') + + self.state_dict = load_model_weight(model_dir, self.device) + self.model.load_state_dict(self.state_dict, strict=False) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + pred_result = self.model.inference(x) + return pred_result + + +def naive_collate(batch): + elem = batch[0] + if isinstance(elem, dict): + return {key: naive_collate([d[key] for d in batch]) for key in elem} + else: + return batch + + +def get_resize_matrix(raw_shape, dst_shape): + + r_w, r_h = raw_shape + d_w, d_h = dst_shape + Rs = np.eye(3) + + Rs[0, 0] *= d_w / r_w + Rs[1, 1] *= d_h / r_h + return Rs + + +def color_aug_and_norm(meta, mean, std): + img = meta['img'].astype(np.float32) / 255 + mean = np.array(mean, dtype=np.float32).reshape(1, 1, 3) / 255 + std = np.array(std, dtype=np.float32).reshape(1, 1, 3) / 255 + img = (img - mean) / std + meta['img'] = img + return meta + + +def img_process(meta, mean, std): + raw_img = meta['img'] + height = raw_img.shape[0] + width = raw_img.shape[1] + dst_shape = [320, 320] + M = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + ResizeM = get_resize_matrix((width, height), dst_shape) + M = ResizeM @ M + img = cv2.warpPerspective(raw_img, M, dsize=tuple(dst_shape)) + meta['img'] = img + meta['warp_matrix'] = M + meta = color_aug_and_norm(meta, mean, std) + return meta + + +def overlay_bbox_cv(dets, class_names, score_thresh): + all_box = [] + for label in dets: + for bbox in dets[label]: + score = bbox[-1] + if score > score_thresh: + x0, y0, x1, y1 = [int(i) for i in bbox[:4]] + all_box.append([label, x0, y0, x1, y1, score]) + all_box.sort(key=lambda v: v[5]) + return all_box + + +mean = [103.53, 116.28, 123.675] +std = [57.375, 57.12, 58.395] +class_names = ['person', 'face', 'hand'] + + +def inference(model, device, img): + img = img.cpu().numpy() + img_info = {'id': 0} + height, width = img.shape[:2] + img_info['height'] = height + img_info['width'] = width + meta = dict(img_info=img_info, raw_img=img, img=img) + + meta = img_process(meta, mean, std) + meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).to(device) + meta = naive_collate([meta]) + meta['img'] = (meta['img'][0]).reshape(1, 3, 320, 320) + with torch.no_grad(): + res = model(meta) + result = overlay_bbox_cv(res[0], class_names, score_thresh=0.35) + cls_list, bbox_list, score_list = [], [], [] + for pred in result: + cls_list.append(pred[0]) + bbox_list.append([pred[1], pred[2], pred[3], pred[4]]) + score_list.append(pred[5]) + return cls_list, bbox_list, score_list diff --git a/modelscope/models/cv/face_human_hand_detection/ghost_pan.py b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py new file mode 100644 index 00000000..e00de407 --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/ghost_pan.py @@ -0,0 +1,395 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import math + +import torch +import torch.nn as nn + +from .utils import ConvModule, DepthwiseConvModule, act_layers + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) + else: + return F.relu6(x + 3.0) / 6.0 + + +class SqueezeExcite(nn.Module): + + def __init__(self, + in_chs, + se_ratio=0.25, + reduced_base_chs=None, + activation='ReLU', + gate_fn=hard_sigmoid, + divisor=4, + **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, + divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layers(activation) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class GhostModule(nn.Module): + + def __init__(self, + inp, + oup, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + activation='ReLU'): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False), + nn.BatchNorm2d(init_channels), + act_layers(activation) if activation else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + act_layers(activation) if activation else nn.Sequential(), + ) + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out + + +class GhostBottleneck(nn.Module): + """Ghost bottleneck w/ optional SE""" + + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + activation='ReLU', + se_ratio=0.0, + ): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0.0 + self.stride = stride + + # Point-wise expansion + self.ghost1 = GhostModule(in_chs, mid_chs, activation=activation) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=mid_chs, + bias=False, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + self.ghost2 = GhostModule(mid_chs, out_chs, activation=None) + + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=in_chs, + bias=False, + ), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + x = self.ghost1(x) + + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + + if self.se is not None: + x = self.se(x) + + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + +class GhostBlocks(nn.Module): + """Stack of GhostBottleneck used in GhostPAN. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expand (int): Expand ratio of GhostBottleneck. Default: 1. + kernel_size (int): Kernel size of depthwise convolution. Default: 5. + num_blocks (int): Number of GhostBottlecneck blocks. Default: 1. + use_res (bool): Whether to use residual connection. Default: False. + activation (str): Name of activation function. Default: LeakyReLU. + """ + + def __init__( + self, + in_channels, + out_channels, + expand=1, + kernel_size=5, + num_blocks=1, + use_res=False, + activation='LeakyReLU', + ): + super(GhostBlocks, self).__init__() + self.use_res = use_res + if use_res: + self.reduce_conv = ConvModule( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + activation=activation, + ) + blocks = [] + for _ in range(num_blocks): + blocks.append( + GhostBottleneck( + in_channels, + int(out_channels * expand), + out_channels, + dw_kernel_size=kernel_size, + activation=activation, + )) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + out = self.blocks(x) + if self.use_res: + out = out + self.reduce_conv(x) + return out + + +class GhostPAN(nn.Module): + """Path Aggregation Network with Ghost block. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3 + use_depthwise (bool): Whether to depthwise separable convolution in + blocks. Default: False + kernel_size (int): Kernel size of depthwise convolution. Default: 5. + expand (int): Expand ratio of GhostBottleneck. Default: 1. + num_blocks (int): Number of GhostBottlecneck blocks. Default: 1. + use_res (bool): Whether to use residual connection. Default: False. + num_extra_level (int): Number of extra conv layers for more feature levels. + Default: 0. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(scale_factor=2, mode='nearest')` + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN') + activation (str): Activation layer name. + Default: LeakyReLU. + """ + + def __init__( + self, + in_channels, + out_channels, + use_depthwise=False, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=0, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU', + ): + super(GhostPAN, self).__init__() + assert num_extra_level >= 0 + assert num_blocks >= 1 + self.in_channels = in_channels + self.out_channels = out_channels + + conv = DepthwiseConvModule if use_depthwise else ConvModule + + # build top-down blocks + self.upsample = nn.Upsample(**upsample_cfg) + self.reduce_layers = nn.ModuleList() + for idx in range(len(in_channels)): + self.reduce_layers.append( + ConvModule( + in_channels[idx], + out_channels, + 1, + norm_cfg=norm_cfg, + activation=activation, + )) + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.bottom_up_blocks.append( + GhostBlocks( + out_channels * 2, + out_channels, + expand, + kernel_size=kernel_size, + num_blocks=num_blocks, + use_res=use_res, + activation=activation, + )) + + # extra layers + self.extra_lvl_in_conv = nn.ModuleList() + self.extra_lvl_out_conv = nn.ModuleList() + for i in range(num_extra_level): + self.extra_lvl_in_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + self.extra_lvl_out_conv.append( + conv( + out_channels, + out_channels, + kernel_size, + stride=2, + padding=kernel_size // 2, + norm_cfg=norm_cfg, + activation=activation, + )) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + Returns: + tuple[Tensor]: multi level features. + """ + assert len(inputs) == len(self.in_channels) + inputs = [ + reduce(input_x) + for input_x, reduce in zip(inputs, self.reduce_layers) + ] + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1)) + outs.append(out) + + # extra layers + for extra_in_layer, extra_out_layer in zip(self.extra_lvl_in_conv, + self.extra_lvl_out_conv): + outs.append(extra_in_layer(inputs[-1]) + extra_out_layer(outs[-1])) + + return tuple(outs) diff --git a/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py new file mode 100644 index 00000000..7f5b50ec --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/nanodet_plus_head.py @@ -0,0 +1,427 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import nms + +from .utils import ConvModule, DepthwiseConvModule + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + This layer calculates the target location by :math: `sum{P(y_i) * y_i}`, + P(y_i) denotes the softmax vector that represents the discrete distribution + y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max} + Args: + reg_max (int): The maximal value of the discrete set. Default: 16. You + may want to reset it according to your new dataset or related + settings. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + Args: + x (Tensor): Features of the regression head, shape (N, 4*(n+1)), + n is self.reg_max. + Returns: + x (Tensor): Integral result of box locations, i.e., distance + offsets from the box center in four directions, shape (N, 4). + """ + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + x = F.linear(x, self.project.type_as(x)).reshape(*shape[:-1], 4) + return x + + +def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False): + """Performs non-maximum suppression in a batched fashion. + Modified from https://github.com/pytorch/vision/blob + /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + Arguments: + boxes (torch.Tensor): boxes in shape (N, 4). + scores (torch.Tensor): scores in shape (N, ). + idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). + nms_cfg (dict): specify nms type and other parameters like iou_thr. + Possible keys includes the following. + - iou_thr (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. + Returns: + tuple: kept dets and indice. + """ + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + nms_cfg_.pop('type', 'nms') + split_thr = nms_cfg_.pop('split_thr', 10000) + if len(boxes_for_nms) < split_thr: + keep = nms(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + scores = scores[keep] + else: + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + for id in torch.unique(idxs): + mask = (idxs == id).nonzero(as_tuple=False).view(-1) + keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + + keep = total_mask.nonzero(as_tuple=False).view(-1) + keep = keep[scores[keep].argsort(descending=True)] + boxes = boxes[keep] + scores = scores[keep] + + return torch.cat([boxes, scores[:, None]], -1), keep + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + nms_cfg, + max_num=-1, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. + score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. + """ + num_classes = multi_scores.size(1) - 1 + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores[:, :-1] + + valid_mask = scores > score_thr + + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + + if torch.onnx.is_in_onnx_export(): + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + return bboxes, labels + + dets, keep = batched_nms(bboxes, scores, labels, nms_cfg) + + if max_num > 0: + dets = dets[:max_num] + keep = keep[:max_num] + + return dets, labels[keep] + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def warp_boxes(boxes, M, width, height): + n = len(boxes) + if n: + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) + xy = xy @ M.T + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + +class NanoDetPlusHead(nn.Module): + """Detection head used in NanoDet-Plus. + + Args: + num_classes (int): Number of categories excluding the background + category. + loss (dict): Loss config. + input_channel (int): Number of channels of the input feature. + feat_channels (int): Number of channels of the feature. + Default: 96. + stacked_convs (int): Number of conv layers in the stacked convs. + Default: 2. + kernel_size (int): Size of the convolving kernel. Default: 5. + strides (list[int]): Strides of input multi-level feature maps. + Default: [8, 16, 32]. + conv_type (str): Type of the convolution. + Default: "DWConv". + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + reg_max (int): The maximal value of the discrete set. Default: 7. + activation (str): Type of activation function. Default: "LeakyReLU". + assigner_cfg (dict): Config dict of the assigner. Default: dict(topk=13). + """ + + def __init__(self, + num_classes, + input_channel, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13), + **kwargs): + super(NanoDetPlusHead, self).__init__() + self.num_classes = num_classes + self.in_channels = input_channel + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.kernel_size = kernel_size + self.strides = strides + self.reg_max = reg_max + self.activation = activation + self.ConvModule = ConvModule if conv_type == 'Conv' else DepthwiseConvModule + + self.norm_cfg = norm_cfg + self.distribution_project = Integral(self.reg_max) + + self._init_layers() + + def _init_layers(self): + self.cls_convs = nn.ModuleList() + for _ in self.strides: + cls_convs = self._buid_not_shared_head() + self.cls_convs.append(cls_convs) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels, + self.num_classes + 4 * (self.reg_max + 1), + 1, + padding=0, + ) for _ in self.strides + ]) + + def _buid_not_shared_head(self): + cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + self.ConvModule( + chn, + self.feat_channels, + self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + bias=self.norm_cfg is None, + activation=self.activation, + )) + return cls_convs + + def forward(self, feats): + if torch.onnx.is_in_onnx_export(): + return self._forward_onnx(feats) + outputs = [] + for feat, cls_convs, gfl_cls in zip( + feats, + self.cls_convs, + self.gfl_cls, + ): + for conv in cls_convs: + feat = conv(feat) + output = gfl_cls(feat) + outputs.append(output.flatten(start_dim=2)) + outputs = torch.cat(outputs, dim=2).permute(0, 2, 1) + return outputs + + def post_process(self, preds, meta): + """Prediction results post processing. Decode bboxes and rescale + to original image size. + Args: + preds (Tensor): Prediction output. + meta (dict): Meta info. + """ + cls_scores, bbox_preds = preds.split( + [self.num_classes, 4 * (self.reg_max + 1)], dim=-1) + result_list = self.get_bboxes(cls_scores, bbox_preds, meta) + det_results = {} + warp_matrixes = ( + meta['warp_matrix'] + if isinstance(meta['warp_matrix'], list) else meta['warp_matrix']) + img_heights = ( + meta['img_info']['height'].cpu().numpy() if isinstance( + meta['img_info']['height'], torch.Tensor) else + meta['img_info']['height']) + img_widths = ( + meta['img_info']['width'].cpu().numpy() if isinstance( + meta['img_info']['width'], torch.Tensor) else + meta['img_info']['width']) + img_ids = ( + meta['img_info']['id'].cpu().numpy() if isinstance( + meta['img_info']['id'], torch.Tensor) else + meta['img_info']['id']) + + for result, img_width, img_height, img_id, warp_matrix in zip( + result_list, img_widths, img_heights, img_ids, warp_matrixes): + det_result = {} + det_bboxes, det_labels = result + det_bboxes = det_bboxes.detach().cpu().numpy() + det_bboxes[:, :4] = warp_boxes(det_bboxes[:, :4], + np.linalg.inv(warp_matrix), + img_width, img_height) + classes = det_labels.detach().cpu().numpy() + for i in range(self.num_classes): + inds = classes == i + det_result[i] = np.concatenate( + [ + det_bboxes[inds, :4].astype(np.float32), + det_bboxes[inds, 4:5].astype(np.float32), + ], + axis=1, + ).tolist() + det_results[img_id] = det_result + return det_results + + def get_bboxes(self, cls_preds, reg_preds, img_metas): + """Decode the outputs to bboxes. + Args: + cls_preds (Tensor): Shape (num_imgs, num_points, num_classes). + reg_preds (Tensor): Shape (num_imgs, num_points, 4 * (regmax + 1)). + img_metas (dict): Dict of image info. + + Returns: + results_list (list[tuple]): List of detection bboxes and labels. + """ + device = cls_preds.device + b = cls_preds.shape[0] + input_height, input_width = img_metas['img'].shape[2:] + input_shape = (input_height, input_width) + + featmap_sizes = [(math.ceil(input_height / stride), + math.ceil(input_width) / stride) + for stride in self.strides] + mlvl_center_priors = [ + self.get_single_level_center_priors( + b, + featmap_sizes[i], + stride, + dtype=torch.float32, + device=device, + ) for i, stride in enumerate(self.strides) + ] + center_priors = torch.cat(mlvl_center_priors, dim=1) + dis_preds = self.distribution_project(reg_preds) * center_priors[..., + 2, + None] + bboxes = distance2bbox( + center_priors[..., :2], dis_preds, max_shape=input_shape) + scores = cls_preds.sigmoid() + result_list = [] + for i in range(b): + score, bbox = scores[i], bboxes[i] + padding = score.new_zeros(score.shape[0], 1) + score = torch.cat([score, padding], dim=1) + results = multiclass_nms( + bbox, + score, + score_thr=0.05, + nms_cfg=dict(type='nms', iou_threshold=0.6), + max_num=100, + ) + result_list.append(results) + return result_list + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + """Generate centers of a single stage feature map. + Args: + batch_size (int): Number of images in one batch. + featmap_size (tuple[int]): height and width of the feature map + stride (int): down sample stride of the feature map + dtype (obj:`torch.dtype`): data type of the tensors + device (obj:`torch.device`): device of the tensors + Return: + priors (Tensor): center priors of a single level feature map. + """ + h, w = featmap_size + x_range = (torch.arange(w, dtype=dtype, device=device)) * stride + y_range = (torch.arange(h, dtype=dtype, device=device)) * stride + y, x = torch.meshgrid(y_range, x_range) + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + proiors = torch.stack([x, y, strides, strides], dim=-1) + return proiors.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py new file mode 100644 index 00000000..0d1cd15d --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/one_stage_detector.py @@ -0,0 +1,61 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .ghost_pan import GhostPAN +from .nanodet_plus_head import NanoDetPlusHead +from .shufflenetv2 import ShuffleNetV2 + + +class OneStageDetector(nn.Module): + + def __init__(self): + super(OneStageDetector, self).__init__() + self.backbone = ShuffleNetV2( + model_size='1.0x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='LeakyReLU', + pretrain=False) + self.fpn = GhostPAN( + in_channels=[116, 232, 464], + out_channels=96, + use_depthwise=True, + kernel_size=5, + expand=1, + num_blocks=1, + use_res=False, + num_extra_level=1, + upsample_cfg=dict(scale_factor=2, mode='bilinear'), + norm_cfg=dict(type='BN'), + activation='LeakyReLU') + self.head = NanoDetPlusHead( + num_classes=3, + input_channel=96, + feat_channels=96, + stacked_convs=2, + kernel_size=5, + strides=[8, 16, 32, 64], + conv_type='DWConv', + norm_cfg=dict(type='BN'), + reg_max=7, + activation='LeakyReLU', + assigner_cfg=dict(topk=13)) + self.epoch = 0 + + def forward(self, x): + x = self.backbone(x) + if hasattr(self, 'fpn'): + x = self.fpn(x) + if hasattr(self, 'head'): + x = self.head(x) + return x + + def inference(self, meta): + with torch.no_grad(): + preds = self(meta['img']) + results = self.head.post_process(preds, meta) + return results diff --git a/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py new file mode 100644 index 00000000..7f4dfc2a --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/shufflenetv2.py @@ -0,0 +1,182 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +from .utils import act_layers + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + x = x.view(batchsize, -1, height, width) + + return x + + +class ShuffleV2Block(nn.Module): + + def __init__(self, inp, oup, stride, activation='ReLU'): + super(ShuffleV2Block, self).__init__() + + if not (1 <= stride <= 3): + raise ValueError('illegal stride value') + self.stride = stride + + branch_features = oup // 2 + assert (self.stride != 1) or (inp == branch_features << 1) + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv( + inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d( + inp, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + self.depthwise_conv( + branch_features, + branch_features, + kernel_size=3, + stride=self.stride, + padding=1, + ), + nn.BatchNorm2d(branch_features), + nn.Conv2d( + branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + act_layers(activation), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d( + i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out = channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + + def __init__( + self, + model_size='1.5x', + out_stages=(2, 3, 4), + with_last_conv=False, + kernal_size=3, + activation='ReLU', + pretrain=True, + ): + super(ShuffleNetV2, self).__init__() + assert set(out_stages).issubset((2, 3, 4)) + + print('model size is ', model_size) + + self.stage_repeats = [4, 8, 4] + self.model_size = model_size + self.out_stages = out_stages + self.with_last_conv = with_last_conv + self.kernal_size = kernal_size + self.activation = activation + if model_size == '0.5x': + self._stage_out_channels = [24, 48, 96, 192, 1024] + elif model_size == '1.0x': + self._stage_out_channels = [24, 116, 232, 464, 1024] + elif model_size == '1.5x': + self._stage_out_channels = [24, 176, 352, 704, 1024] + elif model_size == '2.0x': + self._stage_out_channels = [24, 244, 488, 976, 2048] + else: + raise NotImplementedError + + # building first layer + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip( + stage_names, self.stage_repeats, self._stage_out_channels[1:]): + seq = [ + ShuffleV2Block( + input_channels, output_channels, 2, activation=activation) + ] + for i in range(repeats - 1): + seq.append( + ShuffleV2Block( + output_channels, + output_channels, + 1, + activation=activation)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + output_channels = self._stage_out_channels[-1] + if self.with_last_conv: + conv5 = nn.Sequential( + nn.Conv2d( + input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + act_layers(activation), + ) + self.stage4.add_module('conv5', conv5) + + def forward(self, x): + x = self.conv1(x) + x = self.maxpool(x) + output = [] + + for i in range(2, 5): + stage = getattr(self, 'stage{}'.format(i)) + x = stage(x) + if i in self.out_stages: + output.append(x) + return tuple(output) diff --git a/modelscope/models/cv/face_human_hand_detection/utils.py b/modelscope/models/cv/face_human_hand_detection/utils.py new file mode 100644 index 00000000..f989c164 --- /dev/null +++ b/modelscope/models/cv/face_human_hand_detection/utils.py @@ -0,0 +1,277 @@ +# The implementation here is modified based on nanodet, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/RangiLyu/nanodet + +import torch +import torch.nn as nn + +activations = { + 'ReLU': nn.ReLU, + 'LeakyReLU': nn.LeakyReLU, + 'ReLU6': nn.ReLU6, + 'SELU': nn.SELU, + 'ELU': nn.ELU, + 'GELU': nn.GELU, + 'PReLU': nn.PReLU, + 'SiLU': nn.SiLU, + 'HardSwish': nn.Hardswish, + 'Hardswish': nn.Hardswish, + None: nn.Identity, +} + + +def act_layers(name): + assert name in activations.keys() + if name == 'LeakyReLU': + return nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif name == 'GELU': + return nn.GELU() + elif name == 'PReLU': + return nn.PReLU() + else: + return activations[name](inplace=True) + + +norm_cfg = { + 'BN': ('bn', nn.BatchNorm2d), + 'SyncBN': ('bn', nn.SyncBatchNorm), + 'GN': ('gn', nn.GroupNorm), +} + + +def build_norm_layer(cfg, num_features, postfix=''): + """Build normalization layer + + Args: + cfg (dict): cfg should contain: + type (str): identify norm layer type. + layer args: args needed to instantiate a norm layer. + requires_grad (bool): [optional] whether stop gradient updates + num_features (int): number of channels from input. + postfix (int, str): appended into norm abbreviation to + create named layer. + + Returns: + name (str): abbreviation + postfix + layer (nn.Module): created norm layer + """ + assert isinstance(cfg, dict) and 'type' in cfg + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type not in norm_cfg: + raise KeyError('Unrecognized norm type {}'.format(layer_type)) + else: + abbr, norm_layer = norm_cfg[layer_type] + if norm_layer is None: + raise NotImplementedError + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop('requires_grad', True) + cfg_.setdefault('eps', 1e-5) + if layer_type != 'GN': + layer = norm_layer(num_features, **cfg_) + if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): + layer._specify_ddp_gpu_num(1) + else: + assert 'num_groups' in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer + + +class ConvModule(nn.Module): + """A conv block that contains conv/norm/activation layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + conv_cfg (dict): Config dict for convolution layer. + norm_cfg (dict): Config dict for normalization layer. + activation (str): activation layer, "ReLU" by default. + inplace (bool): Whether to use inplace mode for activation. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias='auto', + conv_cfg=None, + norm_cfg=None, + activation='ReLU', + inplace=True, + order=('conv', 'norm', 'act'), + ): + super(ConvModule, self).__init__() + assert conv_cfg is None or isinstance(conv_cfg, dict) + assert norm_cfg is None or isinstance(norm_cfg, dict) + assert activation is None or isinstance(activation, str) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == {'conv', 'norm', 'act'} + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = self.conv.padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_norm: + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) + self.add_module(self.norm_name, norm) + else: + self.norm_name = None + + if self.activation: + self.act = act_layers(self.activation) + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def forward(self, x, norm=True): + for layer in self.order: + if layer == 'conv': + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and self.activation: + x = self.act(x) + return x + + +class DepthwiseConvModule(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias='auto', + norm_cfg=dict(type='BN'), + activation='ReLU', + inplace=True, + order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'), + ): + super(DepthwiseConvModule, self).__init__() + assert activation is None or isinstance(activation, str) + self.activation = activation + self.inplace = inplace + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 6 + assert set(order) == { + 'depthwise', + 'dwnorm', + 'act', + 'pointwise', + 'pwnorm', + 'act', + } + + self.with_norm = norm_cfg is not None + if bias == 'auto': + bias = False if self.with_norm else True + self.with_bias = bias + + if self.with_norm and self.with_bias: + warnings.warn('ConvModule has norm and bias at the same time') + + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + ) + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias) + + self.in_channels = self.depthwise.in_channels + self.out_channels = self.pointwise.out_channels + self.kernel_size = self.depthwise.kernel_size + self.stride = self.depthwise.stride + self.padding = self.depthwise.padding + self.dilation = self.depthwise.dilation + self.transposed = self.depthwise.transposed + self.output_padding = self.depthwise.output_padding + + if self.with_norm: + _, self.dwnorm = build_norm_layer(norm_cfg, in_channels) + _, self.pwnorm = build_norm_layer(norm_cfg, out_channels) + + if self.activation: + self.act = act_layers(self.activation) + + def forward(self, x, norm=True): + for layer_name in self.order: + if layer_name != 'act': + layer = self.__getattr__(layer_name) + x = layer(x) + elif layer_name == 'act' and self.activation: + x = self.act(x) + return x diff --git a/modelscope/models/cv/face_recognition/__init__.py b/modelscope/models/cv/face_recognition/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_recognition/align_face.py b/modelscope/models/cv/face_recognition/align_face.py new file mode 100644 index 00000000..0477375a --- /dev/null +++ b/modelscope/models/cv/face_recognition/align_face.py @@ -0,0 +1,54 @@ +""" +The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at +https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py +""" +import cv2 +import numpy as np +from skimage import transform as trans + + +def align_face(image, size, lmks): + dst_w = size[1] + dst_h = size[0] + # landmark calculation of dst images + base_w = 96 + base_h = 112 + assert (dst_w >= base_w) + assert (dst_h >= base_h) + base_lmk = [ + 30.2946, 51.6963, 65.5318, 51.5014, 48.0252, 71.7366, 33.5493, 92.3655, + 62.7299, 92.2041 + ] + + dst_lmk = np.array(base_lmk).reshape((5, 2)).astype(np.float32) + if dst_w != base_w: + slide = (dst_w - base_w) / 2 + dst_lmk[:, 0] += slide + + if dst_h != base_h: + slide = (dst_h - base_h) / 2 + dst_lmk[:, 1] += slide + + src_lmk = lmks + # using skimage method + tform = trans.SimilarityTransform() + tform.estimate(src_lmk, dst_lmk) + t = tform.params[0:2, :] + + assert (image.shape[2] == 3) + + dst_image = cv2.warpAffine(image.copy(), t, (dst_w, dst_h)) + dst_pts = GetAffinePoints(src_lmk, t) + return dst_image, dst_pts + + +def GetAffinePoints(pts_in, trans): + pts_out = pts_in.copy() + assert (pts_in.shape[1] == 2) + + for k in range(pts_in.shape[0]): + pts_out[k, 0] = pts_in[k, 0] * trans[0, 0] + pts_in[k, 1] * trans[ + 0, 1] + trans[0, 2] + pts_out[k, 1] = pts_in[k, 0] * trans[1, 0] + pts_in[k, 1] * trans[ + 1, 1] + trans[1, 2] + return pts_out diff --git a/modelscope/models/cv/face_recognition/torchkit/__init__.py b/modelscope/models/cv/face_recognition/torchkit/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py b/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py new file mode 100755 index 00000000..afe89963 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/__init__.py @@ -0,0 +1,33 @@ +# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at +# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone +from .model_irse import (IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, + IR_SE_101, IR_SE_152, IR_SE_200) +from .model_resnet import ResNet_50, ResNet_101, ResNet_152 + +_model_dict = { + 'ResNet_50': ResNet_50, + 'ResNet_101': ResNet_101, + 'ResNet_152': ResNet_152, + 'IR_18': IR_18, + 'IR_34': IR_34, + 'IR_50': IR_50, + 'IR_101': IR_101, + 'IR_152': IR_152, + 'IR_200': IR_200, + 'IR_SE_50': IR_SE_50, + 'IR_SE_101': IR_SE_101, + 'IR_SE_152': IR_SE_152, + 'IR_SE_200': IR_SE_200 +} + + +def get_model(key): + """ Get different backbone network by key, + support ResNet50, ResNet_101, ResNet_152 + IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, + IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200. + """ + if key in _model_dict.keys(): + return _model_dict[key] + else: + raise KeyError('not support model {}'.format(key)) diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/common.py b/modelscope/models/cv/face_recognition/torchkit/backbone/common.py new file mode 100755 index 00000000..a1683225 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/common.py @@ -0,0 +1,70 @@ +# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at +# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/common.py +import torch +import torch.nn as nn +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Linear, Module, ReLU, + Sigmoid) + + +def initialize_weights(modules): + """ Weight initilize, conv2d and linear is initialized with kaiming_normal + """ + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + +class Flatten(Module): + """ Flat tensor + """ + + def forward(self, input): + return input.view(input.size(0), -1) + + +class SEModule(Module): + """ SE block + """ + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + + nn.init.xavier_uniform_(self.fc1.weight.data) + + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + + return module_input * x diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py b/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py new file mode 100755 index 00000000..1982ca05 --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/model_irse.py @@ -0,0 +1,279 @@ +# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at +# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py +from collections import namedtuple + +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + MaxPool2d, Module, PReLU, Sequential) + +from .common import Flatten, SEModule, initialize_weights + + +class BasicBlockIR(Module): + """ BasicBlock for IRNet + """ + + def __init__(self, in_channel, depth, stride): + super(BasicBlockIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(depth), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BottleneckIR(Module): + """ BasicBlock with bottleneck for IRNet + """ + + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + reduction_channel = depth // 4 + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d( + in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), + BatchNorm2d(reduction_channel), PReLU(reduction_channel), + Conv2d( + reduction_channel, + reduction_channel, (3, 3), (1, 1), + 1, + bias=False), BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BasicBlockIRSE(BasicBlockIR): + + def __init__(self, in_channel, depth, stride): + super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module('se_block', SEModule(depth, 16)) + + +class BottleneckIRSE(BottleneckIR): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module('se_block', SEModule(depth, 16)) + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + '''A named tuple describing a ResNet block.''' + + +def get_block(in_channel, depth, num_units, stride=2): + + return [Bottleneck(in_channel, depth, stride)] +\ + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 18: + blocks = [ + get_block(in_channel=64, depth=64, num_units=2), + get_block(in_channel=64, depth=128, num_units=2), + get_block(in_channel=128, depth=256, num_units=2), + get_block(in_channel=256, depth=512, num_units=2) + ] + elif num_layers == 34: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=6), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=8), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + elif num_layers == 200: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=24), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + + return blocks + + +class Backbone(Module): + + def __init__(self, input_size, num_layers, mode='ir'): + """ Args: + input_size: input_size of backbone + num_layers: num_layers of backbone + mode: support ir or irse + """ + super(Backbone, self).__init__() + assert input_size[0] in [112, 224], \ + 'input_size should be [112, 112] or [224, 224]' + assert num_layers in [18, 34, 50, 100, 152, 200], \ + 'num_layers should be 18, 34, 50, 100 or 152' + assert mode in ['ir', 'ir_se'], \ + 'mode should be ir or ir_se' + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + blocks = get_blocks(num_layers) + if num_layers <= 100: + if mode == 'ir': + unit_module = BasicBlockIR + elif mode == 'ir_se': + unit_module = BasicBlockIRSE + output_channel = 512 + else: + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + output_channel = 2048 + + if input_size[0] == 112: + self.output_layer = Sequential( + BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 7 * 7, 512), + BatchNorm1d(512, affine=False)) + else: + self.output_layer = Sequential( + BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 14 * 14, 512), + BatchNorm1d(512, affine=False)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + initialize_weights(self.modules()) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return x + + +def IR_18(input_size): + """ Constructs a ir-18 model. + """ + model = Backbone(input_size, 18, 'ir') + + return model + + +def IR_34(input_size): + """ Constructs a ir-34 model. + """ + model = Backbone(input_size, 34, 'ir') + + return model + + +def IR_50(input_size): + """ Constructs a ir-50 model. + """ + model = Backbone(input_size, 50, 'ir') + + return model + + +def IR_101(input_size): + """ Constructs a ir-101 model. + """ + model = Backbone(input_size, 100, 'ir') + + return model + + +def IR_152(input_size): + """ Constructs a ir-152 model. + """ + model = Backbone(input_size, 152, 'ir') + + return model + + +def IR_200(input_size): + """ Constructs a ir-200 model. + """ + model = Backbone(input_size, 200, 'ir') + + return model + + +def IR_SE_50(input_size): + """ Constructs a ir_se-50 model. + """ + model = Backbone(input_size, 50, 'ir_se') + + return model + + +def IR_SE_101(input_size): + """ Constructs a ir_se-101 model. + """ + model = Backbone(input_size, 100, 'ir_se') + + return model + + +def IR_SE_152(input_size): + """ Constructs a ir_se-152 model. + """ + model = Backbone(input_size, 152, 'ir_se') + + return model + + +def IR_SE_200(input_size): + """ Constructs a ir_se-200 model. + """ + model = Backbone(input_size, 200, 'ir_se') + + return model diff --git a/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py b/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py new file mode 100755 index 00000000..568e24ff --- /dev/null +++ b/modelscope/models/cv/face_recognition/torchkit/backbone/model_resnet.py @@ -0,0 +1,162 @@ +# The implementation is adopted from TFace,made pubicly available under the Apache-2.0 license at +# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_resnet.py +import torch.nn as nn +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + MaxPool2d, Module, ReLU, Sequential) + +from .common import initialize_weights + + +def conv3x3(in_planes, out_planes, stride=1): + """ 3x3 convolution with padding + """ + return Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """ 1x1 convolution + """ + return Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class Bottleneck(Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = BatchNorm2d(planes) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.bn3 = BatchNorm2d(planes * self.expansion) + self.relu = ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(Module): + """ ResNet backbone + """ + + def __init__(self, input_size, block, layers, zero_init_residual=True): + """ Args: + input_size: input_size of backbone + block: block function + layers: layers in each block + """ + super(ResNet, self).__init__() + assert input_size[0] in [112, 224],\ + 'input_size should be [112, 112] or [224, 224]' + self.inplanes = 64 + self.conv1 = Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm2d(64) + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + self.bn_o1 = BatchNorm2d(2048) + self.dropout = Dropout() + if input_size[0] == 112: + self.fc = Linear(2048 * 4 * 4, 512) + else: + self.fc = Linear(2048 * 7 * 7, 512) + self.bn_o2 = BatchNorm1d(512) + + initialize_weights(self.modules) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.bn_o1(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + x = self.bn_o2(x) + + return x + + +def ResNet_50(input_size, **kwargs): + """ Constructs a ResNet-50 model. + """ + model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs) + + return model + + +def ResNet_101(input_size, **kwargs): + """ Constructs a ResNet-101 model. + """ + model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs) + + return model + + +def ResNet_152(input_size, **kwargs): + """ Constructs a ResNet-152 model. + """ + model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs) + + return model diff --git a/modelscope/models/cv/facial_expression_recognition/__init__.py b/modelscope/models/cv/facial_expression_recognition/__init__.py new file mode 100644 index 00000000..35a15d18 --- /dev/null +++ b/modelscope/models/cv/facial_expression_recognition/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .fer import FacialExpressionRecognition + +else: + _import_structure = {'fer': ['FacialExpressionRecognition']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/facial_expression_recognition/fer/__init__.py b/modelscope/models/cv/facial_expression_recognition/fer/__init__.py new file mode 100644 index 00000000..2546035b --- /dev/null +++ b/modelscope/models/cv/facial_expression_recognition/fer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .facial_expression_recognition import FacialExpressionRecognition diff --git a/modelscope/models/cv/facial_expression_recognition/fer/facial_expression_recognition.py b/modelscope/models/cv/facial_expression_recognition/fer/facial_expression_recognition.py new file mode 100644 index 00000000..c5eb71a1 --- /dev/null +++ b/modelscope/models/cv/facial_expression_recognition/fer/facial_expression_recognition.py @@ -0,0 +1,72 @@ +# The implementation is based on Facial-Expression-Recognition, available at +# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +from PIL import Image +from torch.autograd import Variable + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from . import transforms +from .vgg import VGG + + +@MODELS.register_module( + Tasks.facial_expression_recognition, module_name=Models.fer) +class FacialExpressionRecognition(TorchModel): + + def __init__(self, model_path, device='cuda'): + super().__init__(model_path) + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + self.cfg_path = model_path.replace(ModelFile.TORCH_MODEL_FILE, + ModelFile.CONFIGURATION) + self.net = VGG('VGG19', cfg_path=self.cfg_path) + self.load_model() + self.net = self.net.to(device) + self.transform_test = transforms.Compose([ + transforms.TenCrop(44), + transforms.Lambda(lambda crops: torch.stack( + [transforms.ToTensor()(crop) for crop in crops])), + ]) + + self.mean = np.array([[104, 117, 123]]) + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + self.net.load_state_dict(pretrained_dict['net'], strict=True) + self.net.eval() + + def forward(self, input): + img = input['img'] + img = cv2.cvtColor(img.cpu().numpy(), cv2.COLOR_BGR2GRAY) + img = cv2.resize(img, (48, 48)) + img = img[:, :, np.newaxis] + img = np.concatenate((img, img, img), axis=2) + + img = Image.fromarray(np.uint8(img)) + inputs = self.transform_test(img) + + ncrops, c, h, w = inputs.shape + + inputs = inputs.view(-1, c, h, w) + inputs = inputs.to(self.device) + inputs = Variable(inputs, volatile=True) + outputs = self.net(inputs) + + outputs_avg = outputs.view(ncrops, -1).mean(0) # avg over crops + + score = F.softmax(outputs_avg) + _, predicted = torch.max(outputs_avg.data, 0) + + return score, predicted diff --git a/modelscope/models/cv/facial_expression_recognition/fer/transforms.py b/modelscope/models/cv/facial_expression_recognition/fer/transforms.py new file mode 100644 index 00000000..a1448c49 --- /dev/null +++ b/modelscope/models/cv/facial_expression_recognition/fer/transforms.py @@ -0,0 +1,118 @@ +# The implementation is based on Facial-Expression-Recognition, available at +# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch +import numbers +import types + +import numpy as np +import torch +from PIL import Image + + +def to_tensor(pic): + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float().div(255) + else: + return img + + +def center_crop(img, output_size): + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + w, h = img.size + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return img.crop((j, i, j + tw, i + th)) + + +def five_crop(img, size): + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len( + size) == 2, 'Please provide only two dimensions (h, w) for size.' + + w, h = img.size + crop_h, crop_w = size + if crop_w > w or crop_h > h: + raise ValueError( + 'Requested crop size {} is bigger than input size {}'.format( + size, (h, w))) + tl = img.crop((0, 0, crop_w, crop_h)) + tr = img.crop((w - crop_w, 0, w, crop_h)) + bl = img.crop((0, h - crop_h, crop_w, h)) + br = img.crop((w - crop_w, h - crop_h, w, h)) + center = center_crop(img, (crop_h, crop_w)) + return (tl, tr, bl, br, center) + + +class TenCrop(object): + + def __init__(self, size, vertical_flip=False): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len( + size + ) == 2, 'Please provide only two dimensions (h, w) for size.' + self.size = size + self.vertical_flip = vertical_flip + + def __call__(self, img): + first_five = five_crop(img, self.size) + + if self.vertical_flip: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + else: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + second_five = five_crop(img, self.size) + + return first_five + second_five + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + +class ToTensor(object): + + def __call__(self, pic): + return to_tensor(pic) + + +class Lambda(object): + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) diff --git a/modelscope/models/cv/facial_expression_recognition/fer/vgg.py b/modelscope/models/cv/facial_expression_recognition/fer/vgg.py new file mode 100644 index 00000000..8120b6cc --- /dev/null +++ b/modelscope/models/cv/facial_expression_recognition/fer/vgg.py @@ -0,0 +1,40 @@ +# The implementation is based on Facial-Expression-Recognition, available at +# https://github.com/WuJie1010/Facial-Expression-Recognition.Pytorch +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from modelscope.utils.config import Config + + +class VGG(nn.Module): + + def __init__(self, vgg_name, cfg_path): + super(VGG, self).__init__() + model_cfg = Config.from_file(cfg_path)['models'] + self.features = self._make_layers(model_cfg[vgg_name]) + self.classifier = nn.Linear(512, 7) + + def forward(self, x): + out = self.features(x) + out = out.view(out.size(0), -1) + out = F.dropout(out, p=0.5, training=self.training) + out = self.classifier(out) + return out + + def _make_layers(self, cfg): + layers = [] + in_channels = 3 + for x in cfg: + if x == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + layers += [ + nn.Conv2d(in_channels, x, kernel_size=3, padding=1), + nn.BatchNorm2d(x), + nn.ReLU(inplace=True) + ] + in_channels = x + layers += [nn.AvgPool2d(kernel_size=1, stride=1)] + return nn.Sequential(*layers) diff --git a/modelscope/models/cv/hand_2d_keypoints/__init__.py b/modelscope/models/cv/hand_2d_keypoints/__init__.py new file mode 100644 index 00000000..2b06f19a --- /dev/null +++ b/modelscope/models/cv/hand_2d_keypoints/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_2d_keypoints import Hand2dKeyPoints + +else: + _import_structure = {'hand_2d_keypoints': ['Hand2dKeyPoints']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py b/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py new file mode 100644 index 00000000..15a97c30 --- /dev/null +++ b/modelscope/models/cv/hand_2d_keypoints/hand_2d_keypoints.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.pose import TopDown + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.hand_2d_keypoints, module_name=Models.hand_2d_keypoints) +class Hand2dKeyPoints(EasyCVBaseModel, TopDown): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + TopDown.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/hand_static/__init__.py b/modelscope/models/cv/hand_static/__init__.py new file mode 100644 index 00000000..654d2acb --- /dev/null +++ b/modelscope/models/cv/hand_static/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_model import HandStatic + +else: + _import_structure = {'hand_model': ['HandStatic']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/hand_static/hand_model.py b/modelscope/models/cv/hand_static/hand_model.py new file mode 100644 index 00000000..7a8a323e --- /dev/null +++ b/modelscope/models/cv/hand_static/hand_model.py @@ -0,0 +1,93 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import sys + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn +from torchvision import transforms + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .networks import StaticGestureNet + +logger = get_logger() + +map_idx = { + 0: 'unrecog', + 1: 'one', + 2: 'two', + 3: 'bixin', + 4: 'yaogun', + 5: 'zan', + 6: 'fist', + 7: 'ok', + 8: 'tuoju', + 9: 'd_bixin', + 10: 'd_fist_left', + 11: 'd_fist_right', + 12: 'd_hand', + 13: 'fashe', + 14: 'five', + 15: 'nohand' +} + +img_size = [112, 112] + +spatial_transform = transforms.Compose([ + transforms.Resize(img_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +]) + + +@MODELS.register_module(Tasks.hand_static, module_name=Models.hand_static) +class HandStatic(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = StaticGestureNet() + if torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + self.params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=self.device) + + self.model.load_state_dict(self.params) + self.model.to(self.device) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and self.device == 'cuda': + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + def forward(self, x): + pred_result = self.model(x) + return pred_result + + +def infer(img, model, device): + img = img.cpu().numpy() + img = Image.fromarray(img) + clip = spatial_transform(img) + clip = clip.unsqueeze(0).to(device).float() + outputs = model(clip) + predicted = int(outputs.max(1)[1]) + pred_result = map_idx.get(predicted) + logger.info('pred result: {}'.format(pred_result)) + + return pred_result diff --git a/modelscope/models/cv/hand_static/networks.py b/modelscope/models/cv/hand_static/networks.py new file mode 100644 index 00000000..6cf46f5d --- /dev/null +++ b/modelscope/models/cv/hand_static/networks.py @@ -0,0 +1,358 @@ +""" HandStatic +The implementation here is modified based on MobileFaceNet, +originally Apache 2.0 License and publicly avaialbe at https://github.com/xuexingyu24/MobileFaceNet_Tutorial_Pytorch +""" + +import os + +import torch +import torch.nn as nn +import torchvision +import torchvision.models as models +from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, + Dropout, Linear, MaxPool2d, Module, PReLU, ReLU, + Sequential, Sigmoid) + + +class StaticGestureNet(torch.nn.Module): + + def __init__(self, train=True): + super().__init__() + + model = MobileFaceNet(512) + self.feature_extractor = model + self.fc_layer = torch.nn.Sequential( + nn.Linear(512, 128), nn.Softplus(), nn.Linear(128, 15)) + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs): + out = self.feature_extractor(inputs) + out = self.fc_layer(out) + out = self.sigmoid(out) + return out + + +class Flatten(Module): + + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class SEModule(Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class BottleneckIR(Module): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class BottleneckIRSE(Module): + + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + return blocks + + +class Backbone(Module): + + def __init__(self, num_layers, drop_ratio, mode='ir'): + super(Backbone, self).__init__() + assert num_layers in [50, 100, + 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 7 * 7, 512), BatchNorm1d(512)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +class ConvBlock(Module): + + def __init__(self, + in_c, + out_c, + kernel=(1, 1), + stride=(1, 1), + padding=(0, 0), + groups=1): + super(ConvBlock, self).__init__() + self.conv = Conv2d( + in_c, + out_channels=out_c, + kernel_size=kernel, + groups=groups, + stride=stride, + padding=padding, + bias=False) + self.bn = BatchNorm2d(out_c) + self.prelu = PReLU(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.prelu(x) + return x + + +class LinearBlock(Module): + + def __init__(self, + in_c, + out_c, + kernel=(1, 1), + stride=(1, 1), + padding=(0, 0), + groups=1): + super(LinearBlock, self).__init__() + self.conv = Conv2d( + in_c, + out_channels=out_c, + kernel_size=kernel, + groups=groups, + stride=stride, + padding=padding, + bias=False) + self.bn = BatchNorm2d(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DepthWise(Module): + + def __init__(self, + in_c, + out_c, + residual=False, + kernel=(3, 3), + stride=(2, 2), + padding=(1, 1), + groups=1): + super(DepthWise, self).__init__() + self.conv = ConvBlock( + in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + self.conv_dw = ConvBlock( + groups, + groups, + groups=groups, + kernel=kernel, + padding=padding, + stride=stride) + self.project = LinearBlock( + groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + self.residual = residual + + def forward(self, x): + if self.residual: + short_cut = x + x = self.conv(x) + x = self.conv_dw(x) + x = self.project(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + + def __init__(self, + c, + num_block, + groups, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append( + DepthWise( + c, + c, + residual=True, + kernel=kernel, + padding=padding, + stride=stride, + groups=groups)) + self.model = Sequential(*modules) + + def forward(self, x): + return self.model(x) + + +class MobileFaceNet(Module): + + def __init__(self, embedding_size): + super(MobileFaceNet, self).__init__() + self.conv1 = ConvBlock( + 3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) + self.conv2_dw = ConvBlock( + 64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) + self.conv_23 = DepthWise( + 64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128) + self.conv_3 = Residual( + 64, + num_block=4, + groups=128, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_34 = DepthWise( + 64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256) + self.conv_4 = Residual( + 128, + num_block=6, + groups=256, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_45 = DepthWise( + 128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512) + self.conv_5 = Residual( + 128, + num_block=2, + groups=256, + kernel=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.conv_6_sep = ConvBlock( + 128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.conv_6_dw = LinearBlock( + 512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)) + self.conv_6_flatten = Flatten() + self.linear = Linear(512, embedding_size, bias=False) + self.bn = BatchNorm1d(embedding_size) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2_dw(out) + out = self.conv_23(out) + out = self.conv_3(out) + out = self.conv_34(out) + out = self.conv_4(out) + out = self.conv_45(out) + out = self.conv_5(out) + out = self.conv_6_sep(out) + out = self.conv_6_dw(out) + out = self.conv_6_flatten(out) + out = self.linear(out) + return l2_norm(out) diff --git a/modelscope/models/cv/human_wholebody_keypoint/__init__.py b/modelscope/models/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..30e23457 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint import HumanWholeBodyKeypoint + +else: + _import_structure = { + 'human_wholebody_keypoint': ['HumanWholeBodyKeypoint'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py new file mode 100644 index 00000000..dd3c0290 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py @@ -0,0 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.pose.top_down import TopDown + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Models.human_wholebody_keypoint) +class HumanWholeBodyKeypoint(EasyCVBaseModel, TopDown): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + TopDown.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/image_body_reshaping/__init__.py b/modelscope/models/cv/image_body_reshaping/__init__.py new file mode 100644 index 00000000..a04f110d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_body_reshaping import ImageBodyReshaping + +else: + _import_structure = {'image_body_reshaping': ['ImageBodyReshaping']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py new file mode 100644 index 00000000..4aed8d98 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .model import FlowGenerator +from .person_info import PersonInfo +from .pose_estimator.body import Body +from .slim_utils import image_warp_grid1, resize_on_long_side + +logger = get_logger() + +__all__ = ['ImageBodyReshaping'] + + +@MODELS.register_module( + Tasks.image_body_reshaping, module_name=Models.image_body_reshaping) +class ImageBodyReshaping(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image body reshaping model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + self.degree = 1.0 + self.reshape_model = FlowGenerator(n_channels=16).to(self.device) + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoints = torch.load(model_path, map_location=torch.device('cpu')) + self.reshape_model.load_state_dict( + checkpoints['state_dict'], strict=True) + self.reshape_model.eval() + logger.info('load body reshaping model done') + + pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth') + self.pose_esti = Body(pose_model_ckpt, self.device) + logger.info('load pose model done') + + def pred_joints(self, img): + if img is None: + return None + small_src, resize_scale = resize_on_long_side(img, 300) + body_joints = self.pose_esti(small_src) + + if body_joints.shape[0] >= 1: + body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale + + return body_joints + + def pred_flow(self, img): + + body_joints = self.pred_joints(img) + small_size = 1200 + + if img.shape[0] > small_size or img.shape[1] > small_size: + _img, _scale = resize_on_long_side(img, small_size) + body_joints[:, :, :2] = body_joints[:, :, :2] * _scale + else: + _img = img + + # We only reshape one person + if body_joints.shape[0] < 1 or body_joints.shape[0] > 1: + return None + + person = PersonInfo(body_joints[0]) + + with torch.no_grad(): + person_pred = person.pred_flow(_img, self.reshape_model, + self.device) + + flow = np.dstack((person_pred['rDx'], person_pred['rDy'])) + + scale = img.shape[0] * 1.0 / flow.shape[0] + + flow = cv2.resize(flow, (img.shape[1], img.shape[0])) + flow *= scale + + return flow + + def warp(self, src_img, flow): + + X_flow = flow[..., 0] + Y_flow = flow[..., 1] + + X_flow = np.ascontiguousarray(X_flow) + Y_flow = np.ascontiguousarray(Y_flow) + + pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0) + return pred + + def inference(self, img): + img = img.cpu().numpy() + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + flow = self.pred_flow(img) + + if flow is None: + return img + + assert flow.shape[:2] == img.shape[:2] + + mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8) + mag -= 3 + mag[mag <= 0] = 0 + + x, y = cv2.polarToCart(mag, ang, angleInDegrees=False) + flow = np.dstack((x, y)) + + flow *= self.degree + pred = self.warp(img, flow) + out_img = np.clip(pred, 0, 255) + logger.info('model inference done') + + return out_img.astype(np.uint8) diff --git a/modelscope/models/cv/image_body_reshaping/model.py b/modelscope/models/cv/image_body_reshaping/model.py new file mode 100644 index 00000000..174428a1 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/model.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + + def __init__(self, in_ch, out_ch): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class SASA(nn.Module): + + def __init__(self, in_dim): + super(SASA, self).__init__() + self.chanel_in = in_dim + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.mag_conv = nn.Conv2d( + in_channels=5, out_channels=in_dim // 32, kernel_size=1) + + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + self.sigmoid = nn.Sigmoid() + + def structure_encoder(self, paf_mag, target_height, target_width): + torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True) + torso_mask = torch.clamp(torso_mask, 0, 1) + + arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True) + arms_mask = torch.clamp(arms_mask, 0, 1) + + legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True) + legs_mask = torch.clamp(legs_mask, 0, 1) + + fg_mask = paf_mag[:, 12, :, :].unsqueeze(1) + bg_mask = 1 - fg_mask + Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask), + dim=1) + Y = F.interpolate(Y, size=(target_height, target_width), mode='area') + return Y + + def forward(self, X, PAF_mag): + """extract self-attention features. + Args: + X : input feature maps( B x C x H x W) + PAF_mag : ( B x C x H x W), 1 denotes connectivity, 0 denotes non-connectivity + + Returns: + out : self attention value + input feature + Y: B X N X N (N is Width*Height) + """ + + m_batchsize, C, height, width = X.size() + + Y = self.structure_encoder(PAF_mag, height, width) + + connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1, + width * height) + affinity = torch.bmm( + connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec) + affinity_centered = affinity - torch.mean(affinity) + affinity_sigmoid = self.sigmoid(affinity_centered) + + proj_query = self.query_conv(X).view(m_batchsize, -1, + width * height).permute(0, 2, 1) + proj_key = self.key_conv(X).view(m_batchsize, -1, width * height) + selfatten_map = torch.bmm(proj_query, proj_key) + selfatten_centered = selfatten_map - torch.mean( + selfatten_map) # centering + selfatten_sigmoid = self.sigmoid(selfatten_centered) + + SASA_map = selfatten_sigmoid * affinity_sigmoid + + proj_value = self.value_conv(X).view(m_batchsize, -1, width * height) + + out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1)) + out = out.view(m_batchsize, C, height, width) + + out = self.gamma * out + X + return out, Y + + +class FlowGenerator(nn.Module): + + def __init__(self, n_channels, deep_supervision=False): + super(FlowGenerator, self).__init__() + self.deep_supervision = deep_supervision + + self.Encoder = nn.Sequential( + ConvLayer(n_channels, 64), + ConvLayer(64, 64), + nn.MaxPool2d(2), + ConvLayer(64, 128), + ConvLayer(128, 128), + nn.MaxPool2d(2), + ConvLayer(128, 256), + ConvLayer(256, 256), + nn.MaxPool2d(2), + ConvLayer(256, 512), + ConvLayer(512, 512), + nn.MaxPool2d(2), + ConvLayer(512, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ) + + self.SASA = SASA(in_dim=1024) + + self.Decoder = nn.Sequential( + ConvLayer(1024, 1024), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(1024, 512), + ConvLayer(512, 512), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(512, 256), + ConvLayer(256, 256), + ConvLayer(256, 128), + ConvLayer(128, 64), + ConvLayer(64, 32), + nn.Conv2d(32, 2, kernel_size=1, padding=0), + nn.Tanh(), + nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), + ) + + dilation_ksize = 17 + self.dilation = torch.nn.MaxPool2d( + kernel_size=dilation_ksize, + stride=1, + padding=int((dilation_ksize - 1) / 2)) + + def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2): + n, c, h, w = x.size() + yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)]) + xv = xv.float() / (w - 1) * 2.0 - 1 + yv = yv.float() / (h - 1) * 2.0 - 1 + grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0) + grid = grid.to(flow.device) + grid_x = grid + 2 * flow * coff + warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode) + return warp_x + + def forward(self, img, skeleton_map, coef=0.2): + """extract self-attention features. + Args: + img : input numpy image + skeleton_map : skeleton map of input image + coef: warp degree + + Returns: + warp_x : warped image + flow: predicted flow + """ + + img_concat = torch.cat((img, skeleton_map), dim=1) + X = self.Encoder(img_concat) + + _, _, height, width = X.size() + + # directly get PAF magnitude from skeleton maps via dilation + PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5) + + out, Y = self.SASA(X, PAF_mag) + flow = self.Decoder(out) + + flow = flow.permute(0, 2, 3, 1) # [n, 2, h, w] ==> [n, h, w, 2] + + warp_x = self.warp(img, flow, coff=coef) + warp_x = torch.clamp(warp_x, min=-1.0, max=1.0) + + return warp_x, flow diff --git a/modelscope/models/cv/image_body_reshaping/person_info.py b/modelscope/models/cv/image_body_reshaping/person_info.py new file mode 100644 index 00000000..509a2ce3 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/person_info.py @@ -0,0 +1,339 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy + +import cv2 +import numpy as np +import torch + +from .slim_utils import (enlarge_box_tblr, gen_skeleton_map, + get_map_fusion_map_cuda, get_mask_bbox, + resize_on_long_side) + + +class PersonInfo(object): + + def __init__(self, joints): + self.joints = joints + self.flow = None + self.pad_boder = False + self.height_expand = 0 + self.width_expand = 0 + self.coeff = 0.2 + self.network_input_W = 256 + self.network_input_H = 256 + self.divider = 20 + self.flow_scales = ['upper_2'] + + def update_attribute(self, pad_boder, height_expand, width_expand): + self.pad_boder = pad_boder + self.height_expand = height_expand + self.width_expand = width_expand + if pad_boder: + self.joints[:, 0] += width_expand + self.joints[:, 1] += height_expand + + def pred_flow(self, img, flow_net, device): + with torch.no_grad(): + if img is None: + print('image is none') + self.flow = None + + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if self.pad_boder: + height_expand = self.height_expand + width_expand = self.width_expand + pad_img = cv2.copyMakeBorder( + img, + height_expand, + height_expand, + width_expand, + width_expand, + cv2.BORDER_CONSTANT, + value=(127, 127, 127)) + + else: + height_expand = 0 + width_expand = 0 + pad_img = img.copy() + + canvas = np.zeros( + shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32) + + self.human_joint_box = self.__joint_to_body_box() + + self.human_box = enlarge_box_tblr( + self.human_joint_box, pad_img, ratio=0.25) + human_box_height = self.human_box[1] - self.human_box[0] + human_box_width = self.human_box[3] - self.human_box[2] + + self.leg_joint_box = self.__joint_to_leg_box() + self.leg_box = enlarge_box_tblr( + self.leg_joint_box, pad_img, ratio=0.25) + + self.arm_joint_box = self.__joint_to_arm_box() + self.arm_box = enlarge_box_tblr( + self.arm_joint_box, pad_img, ratio=0.1) + + x_flows = [] + y_flows = [] + multi_bbox = [] + + for scale in self.flow_scales: # better for metric + scale_value = float(scale.split('_')[-1]) + + arm_box = copy.deepcopy(self.arm_box) + + if arm_box[0] is None: + arm_box = self.human_box + + arm_box_height = arm_box[1] - arm_box[0] + arm_box_width = arm_box[3] - arm_box[2] + + roi_bbox = None + + if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1: + roi_bbox = self.human_box + else: + arm_box = enlarge_box_tblr( + arm_box, pad_img, ratio=scale_value) + if scale == 'upper_0.2': + arm_box[0] = min(arm_box[0], int(self.joints[0][1])) + if scale.startswith('upper'): + roi_bbox = [ + max(self.human_box[0], arm_box[0]), + min(self.human_box[1], arm_box[1]), + max(self.human_box[2], arm_box[2]), + min(self.human_box[3], arm_box[3]) + ] + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + elif scale.startswith('lower'): + roi_bbox = [ + max(self.human_box[0], self.leg_box[0]), + min(self.human_box[1], self.leg_box[1]), + max(self.human_box[2], self.leg_box[2]), + min(self.human_box[3], self.leg_box[3]) + ] + + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + skel_map, roi_bbox = gen_skeleton_map( + self.joints, 'depth', input_roi_box=roi_bbox) + + if roi_bbox is None: + continue + + if skel_map.dtype != np.float32: + skel_map = skel_map.astype(np.float32) + + skel_map -= 1.0 # [0,2] ->[-1,1] + + multi_bbox.append(roi_bbox) + + roi_bbox_height = roi_bbox[1] - roi_bbox[0] + roi_bbox_width = roi_bbox[3] - roi_bbox[2] + + assert skel_map.shape[0] == roi_bbox_height + assert skel_map.shape[1] == roi_bbox_width + roi_height_pad = roi_bbox_height // self.divider + roi_width_pad = roi_bbox_width // self.divider + paded_roi_h = roi_bbox_height + 2 * roi_height_pad + paded_roi_w = roi_bbox_width + 2 * roi_width_pad + + roi_height_pad_joint = skel_map.shape[0] // self.divider + roi_width_pad_joint = skel_map.shape[1] // self.divider + skel_map = np.pad( + skel_map, + ((roi_height_pad_joint, roi_height_pad_joint), + (roi_width_pad_joint, roi_width_pad_joint), (0, 0)), + 'constant', + constant_values=-1) + + skel_map_resized = cv2.resize( + skel_map, (self.network_input_W, self.network_input_H)) + + skel_map_resized[skel_map_resized < 0] = -1.0 + skel_map_resized[skel_map_resized > -0.5] = 1.0 + skel_map_transformed = torch.from_numpy( + skel_map_resized.transpose((2, 0, 1))) + + roi_npy = pad_img[roi_bbox[0]:roi_bbox[1], + roi_bbox[2]:roi_bbox[3], :].copy() + if roi_npy.dtype != np.float32: + roi_npy = roi_npy.astype(np.float32) + + roi_npy = np.pad(roi_npy, + ((roi_height_pad, roi_height_pad), + (roi_width_pad, roi_width_pad), (0, 0)), + 'edge') + + roi_npy = roi_npy[:, :, ::-1] + + roi_npy = cv2.resize( + roi_npy, (self.network_input_W, self.network_input_H)) + + roi_npy *= 1.0 / 255 + roi_npy -= 0.5 + roi_npy *= 2 + + rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1))) + + rgb_tensor = rgb_tensor.unsqueeze(0).to(device) + skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device) + warped_img_val, flow_field_val = flow_net( + rgb_tensor, skel_map_tensor + ) # inference, connectivity_mask [1,12,16,16] + flow_field_val = flow_field_val.detach().squeeze().cpu().numpy( + ) + + flow_field_val = cv2.resize( + flow_field_val, (paded_roi_w, paded_roi_h), + interpolation=cv2.INTER_LINEAR) + flow_field_val[..., 0] = flow_field_val[ + ..., 0] * paded_roi_w * 0.5 * 2 * self.coeff + flow_field_val[..., 1] = flow_field_val[ + ..., 1] * paded_roi_h * 0.5 * 2 * self.coeff + + # remove pad areas + flow_field_val = flow_field_val[ + roi_height_pad:flow_field_val.shape[0] - roi_height_pad, + roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :] + + diffuse_width = max(roi_bbox_width // 3, 1) + diffuse_height = max(roi_bbox_height // 3, 1) + assert roi_bbox_width == flow_field_val.shape[1] + assert roi_bbox_height == flow_field_val.shape[0] + + origin_flow = np.zeros( + (pad_img.shape[0] + 2 * diffuse_height, + pad_img.shape[1] + 2 * diffuse_width, 2), + dtype=np.float32) + + flow_field_val = np.pad(flow_field_val, + ((diffuse_height, diffuse_height), + (diffuse_width, diffuse_width), + (0, 0)), 'linear_ramp') + + origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height, + roi_bbox[2]:roi_bbox[3] + + 2 * diffuse_width] = flow_field_val + + origin_flow = origin_flow[diffuse_height:-diffuse_height, + diffuse_width:-diffuse_width, :] + + x_flows.append(origin_flow[..., 0]) + y_flows.append(origin_flow[..., 1]) + + if len(x_flows) == 0: + return { + 'rDx': np.zeros(canvas.shape[:2], dtype=np.float32), + 'rDy': np.zeros(canvas.shape[:2], dtype=np.float32), + 'multi_bbox': multi_bbox, + 'x_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32), + 'y_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32) + } + else: + origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow( + x_flows, y_flows, device=device) + + return { + 'rDx': origin_rDx, + 'rDy': origin_rDy, + 'multi_bbox': multi_bbox, + 'x_fusion_map': x_fusion_map, + 'y_fusion_map': y_fusion_map + } + + @staticmethod + def blend_multiscale_flow(x_flows, y_flows, device=None): + scale_num = len(x_flows) + if scale_num == 1: + return x_flows[0], y_flows[0], np.ones_like( + x_flows[0]), np.ones_like(x_flows[0]) + + origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]), + dtype=np.float32) + origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]), + dtype=np.float32) + + x_fusion_map, x_acc_map = get_map_fusion_map_cuda( + x_flows, 1, device=device) + y_fusion_map, y_acc_map = get_map_fusion_map_cuda( + y_flows, 1, device=device) + + x_flow_map = 1.0 / x_fusion_map + y_flow_map = 1.0 / y_fusion_map + + all_acc_map = x_acc_map + y_acc_map + all_acc_map = all_acc_map.astype(np.uint8) + roi_box = get_mask_bbox(all_acc_map, threshold=1) + + if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[ + 3] - roi_box[2] <= 0: + roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]] + + roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + + roi_width = roi_x_flow_map.shape[1] + roi_height = roi_x_flow_map.shape[0] + + roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320) + roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320) + + roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55)) + roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55)) + + roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height)) + roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height)) + + x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_x_flow_map + y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_y_flow_map + + for i in range(scale_num): + origin_rDx += x_flows[i] + origin_rDy += y_flows[i] + + origin_rDx *= x_flow_map + origin_rDy *= y_flow_map + + return origin_rDx, origin_rDy, x_flow_map, y_flow_map + + def __joint_to_body_box(self): + joint_left = int(np.min(self.joints, axis=0)[0]) + joint_right = int(np.max(self.joints, axis=0)[0]) + joint_top = int(np.min(self.joints, axis=0)[1]) + joint_bottom = int(np.max(self.joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_leg_box(self): + leg_joints = self.joints[8:, :] + if np.max(leg_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(leg_joints, axis=0)[0]) + joint_right = int(np.max(leg_joints, axis=0)[0]) + joint_top = int(np.min(leg_joints, axis=0)[1]) + joint_bottom = int(np.max(leg_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_arm_box(self): + arm_joints = self.joints[2:8, :] + if np.max(arm_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(arm_joints, axis=0)[0]) + joint_right = int(np.max(arm_joints, axis=0)[0]) + joint_top = int(np.min(arm_joints, axis=0)[1]) + joint_bottom = int(np.max(arm_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py new file mode 100644 index 00000000..45b02724 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py @@ -0,0 +1,272 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. + +import math + +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter + +from .model import BodyposeModel +from .util import pad_rightdown_corner, transfer + + +class Body(object): + + def __init__(self, model_path, device): + self.model = BodyposeModel().to(device) + model_dict = transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + bodyparts = 18 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = cv2.resize( + oriImg, (0, 0), + fx=scale, + fy=scale, + interpolation=cv2.INTER_CUBIC) + imageToTest_padded, pad = pad_rightdown_corner( + imageToTest, stride, padValue) + im = np.transpose( + np.float32(imageToTest_padded[:, :, :, np.newaxis]), + (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), + (1, 2, 0)) # output 1 is heatmaps + heatmap = cv2.resize( + heatmap, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + heatmap = heatmap[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] + - pad[3], :] + heatmap = cv2.resize( + heatmap, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), + (1, 2, 0)) # output 0 is PAFs + paf = cv2.resize( + paf, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + paf = paf[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = cv2.resize( + paf, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += +paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(bodyparts): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, + one_heatmap >= map_up, one_heatmap >= map_down, + one_heatmap > thre1)) + peaks = list( + zip(np.nonzero(peaks_binary)[1], + np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [ + peaks_with_score[i] + (peak_id[i], ) + for i in range(len(peak_id)) + ] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], + [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], + [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], + [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30], + [47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38], + [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list( + zip( + np.linspace( + candA[i][0], candB[j][0], num=mid_num), + np.linspace( + candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 0] + for item in range(len(startend)) + ]) + vec_y = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 1] + for item in range(len(startend)) + ]) + + score_midpts = np.multiply( + vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + temp1 = sum(score_midpts) / len(score_midpts) + temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0) + score_with_dist_prior = temp1 + temp2 + criterion1 = len(np.nonzero( + score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append([ + i, j, score_with_dist_prior, + score_with_dist_prior + candA[i][2] + + candB[j][2] + ]) + + connection_candidate = sorted( + connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] + and j not in connection[:, 4]): + connection = np.vstack( + [connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array( + [item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][ + indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + tmp1 = (subset[j1] >= 0).astype(int) + tmp2 = (subset[j2] >= 0).astype(int) + membership = (tmp1 + tmp2)[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum( + candidate[connection_all[k][i, :2].astype(int), + 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + count = subset.shape[0] + joints = np.zeros(shape=(count, bodyparts, 3)) + + for i in range(count): + for j in range(bodyparts): + joints[i, j, :3] = candidate[int(subset[i, j]), :3] + confidence = 1.0 if subset[i, j] >= 0 else 0.0 + joints[i, j, 2] *= confidence + return joints diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py new file mode 100644 index 00000000..12f6e84d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py @@ -0,0 +1,141 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. + +from collections import OrderedDict + +import torch +import torch.nn as nn + + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d( + in_channels=v[0], + out_channels=v[1], + kernel_size=v[2], + stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + + +class BodyposeModel(nn.Module): + + def __init__(self): + super(BodyposeModel, self).__init__() + + # these layers have no relu layer + no_relu_layers = [ + 'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1', + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2', + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1', + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1' + ] + blocks = {} + block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1])]) + + # Stage 1 + block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])]) + + block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py new file mode 100644 index 00000000..13a42074 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py @@ -0,0 +1,33 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. +import numpy as np + + +def pad_rightdown_corner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join( + weights_name.split('.')[1:])] + return transfered_model_weights diff --git a/modelscope/models/cv/image_body_reshaping/slim_utils.py b/modelscope/models/cv/image_body_reshaping/slim_utils.py new file mode 100644 index 00000000..23d5a741 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/slim_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import random + +import cv2 +import numba +import numpy as np +import torch + + +def resize_on_long_side(img, long_side=800): + src_height = img.shape[0] + src_width = img.shape[1] + + if src_height > src_width: + scale = long_side * 1.0 / src_height + _img = cv2.resize( + img, (int(src_width * scale), long_side), + interpolation=cv2.INTER_LINEAR) + else: + scale = long_side * 1.0 / src_width + _img = cv2.resize( + img, (long_side, int(src_height * scale)), + interpolation=cv2.INTER_LINEAR) + + return _img, scale + + +def point_in_box(pt, box): + pt_x = pt[0] + pt_y = pt[1] + + if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[ + 1] and pt_y <= box[1] + box[3]: + return True + else: + return False + + +def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True): + if roi_bbox is None or None in roi_bbox: + return [None, None, None, None] + + top = roi_bbox[0] + bottom = roi_bbox[1] + left = roi_bbox[2] + right = roi_bbox[3] + + roi_width = roi_bbox[3] - roi_bbox[2] + roi_height = roi_bbox[1] - roi_bbox[0] + right = left + roi_width + bottom = top + roi_height + + long_side = roi_width if roi_width > roi_height else roi_height + + if use_long_side: + new_left = left - int(long_side * ratio) + else: + new_left = left - int(roi_width * ratio) + new_left = 1 if new_left < 0 else new_left + + if use_long_side: + new_top = top - int(long_side * ratio) + else: + new_top = top - int(roi_height * ratio) + new_top = 1 if new_top < 0 else new_top + + if use_long_side: + new_right = right + int(long_side * ratio) + else: + new_right = right + int(roi_width * ratio) + new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right + + if use_long_side: + new_bottom = bottom + int(long_side * ratio) + else: + new_bottom = bottom + int(roi_height * ratio) + new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom + + bbox = [new_top, new_bottom, new_left, new_right] + return bbox + + +def gen_PAF(image, joints): + + assert joints.shape[0] == 18 + assert joints.shape[1] == 3 + + org_h = image.shape[0] + org_w = image.shape[1] + small_image, resize_scale = resize_on_long_side(image, 120) + + joints[:, :2] = joints[:, :2] * resize_scale + + joint_left = int(np.min(joints, axis=0)[0]) + joint_right = int(np.max(joints, axis=0)[0]) + joint_top = int(np.min(joints, axis=0)[1]) + joint_bottom = int(np.max(joints, axis=0)[1]) + + limb_width = min( + abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6 + + if limb_width % 2 == 0: + limb_width += 1 + kernel_size = limb_width + + part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12), + (12, 13), (8, 9), (9, 10)] + + map_list = [] + mask_list = [] + PAF_all = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + for c, pair in enumerate(part_orders): + idx_a_name = pair[0] + idx_b_name = pair[1] + + jointa = joints[idx_a_name] + jointb = joints[idx_b_name] + + confidence_threshold = 0.05 + if jointa[2] > confidence_threshold and jointb[ + 2] > confidence_threshold: + canvas = np.zeros( + shape=(small_image.shape[0], small_image.shape[1]), + dtype=np.uint8) + + canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])), + (int(jointb[0]), int(jointb[1])), + (255, 255, 255), 5) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (kernel_size, kernel_size)) + + canvas = cv2.dilate(canvas, kernel, 1) + canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0) + canvas = canvas.astype(np.float32) / 255 + PAF = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + PAF[..., 0] = jointb[0] - jointa[0] + PAF[..., 1] = jointb[1] - jointa[1] + mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1]) + PAF /= (np.dstack((mag, mag)) + 1e-5) + + single_PAF = PAF * np.dstack((canvas, canvas)) + map_list.append( + cv2.GaussianBlur(single_PAF, + (kernel_size * 3, kernel_size * 3), 0)) + + mask_list.append( + cv2.GaussianBlur(canvas.copy(), + (kernel_size * 3, kernel_size * 3), 0)) + PAF_all = PAF_all * (1.0 - np.dstack( + (canvas, canvas))) + single_PAF + + PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0) + PAF_all = cv2.resize( + PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR) + map_list.append(PAF_all) + return PAF_all, map_list, mask_list + + +def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None): + if type(joints) == list: + joints = np.array(joints) + assert stack_mode == 'column' or stack_mode == 'depth' + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0): + jointa = joints[a] + jointb = joints[b] + + temp1 = int((jointa[0] - x_offset) * scale) + temp2 = int((jointa[1] - y_offset) * scale) + temp3 = int((jointb[0] - x_offset) * scale) + temp4 = int((jointb[1] - y_offset) * scale) + + cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width) + + roi_box = input_roi_box + + roi_box_width = roi_box[3] - roi_box[2] + roi_box_height = roi_box[1] - roi_box[0] + short_side_length = min(roi_box_width, roi_box_height) + line_width = short_side_length // 30 + + line_width = max(line_width, 2) + + map_cube = np.zeros( + shape=(roi_box_height, roi_box_width, len(part_orders) + 1), + dtype=np.float32) + + use_line_width = min(5, line_width) + fx = use_line_width * 1.0 / line_width # fx 最大值为1 + + if fx < 0.99: + map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx) + + for c, pair in enumerate(part_orders): + tmp = map_cube[..., c].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., c] = tmp + + tmp = map_cube[..., -1].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., -1] = tmp + + map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height)) + + if stack_mode == 'depth': + return map_cube, roi_box + elif stack_mode == 'column': + joint_maps = [] + for c in range(len(part_orders) + 1): + joint_maps.append(map_cube[..., c]) + joint_map = np.column_stack(joint_maps) + + return joint_map, roi_box + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + tl = line_thickness or round( + 0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, (c1[0], c1[1] - 2), + 0, + tl / 3, [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA) + + +def draw_line(im, points, color, stroke_size=2, closed=False): + points = points.astype(np.int32) + for i in range(len(points) - 1): + cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color, + stroke_size) + if closed: + cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size) + + +def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2): + left = bbox[0] + top = bbox[1] + + right = bbox[2] + bottom = bbox[3] + + roi_width = right - left + roi_height = bottom - top + + new_left = left - int(roi_width * enlarge_ratio) + new_left = 0 if new_left < 0 else new_left + + new_top = top - int(roi_height * enlarge_ratio) + new_top = 0 if new_top < 0 else new_top + + new_right = right + int(roi_width * enlarge_ratio) + new_right = img_width if new_right > img_width else new_right + + new_bottom = bottom + int(roi_height * enlarge_ratio) + new_bottom = img_height if new_bottom > img_height else new_bottom + + bbox = [new_left, new_top, new_right, new_bottom] + + bbox = [int(x) for x in bbox] + + return bbox + + +def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')): + map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list] + map_concat = torch.stack(tuple(map_list_cuda), dim=-1) + + map_concat = torch.abs(map_concat) + + map_concat[map_concat < threshold] = 0 + map_concat[map_concat > 1e-5] = 1.0 + + sum_map = torch.sum(map_concat, dim=2) + a = torch.ones_like(sum_map) + acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map)) + + fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map) + + fusion_map = fusion_map.float() + acc_map = acc_map.float() + + fusion_map = fusion_map.cpu().numpy().astype(np.float32) + acc_map = acc_map.cpu().numpy().astype(np.float32) + + return fusion_map, acc_map + + +def gen_border_shade(height, width, height_band, width_band): + height_ratio = height_band * 1.0 / height + width_ratio = width_band * 1.0 / width + + _height_band = int(256 * height_ratio) + _width_band = int(256 * width_ratio) + + canvas = np.zeros((256, 256), dtype=np.float32) + + canvas[_height_band // 2:-_height_band // 2, + _width_band // 2:-_width_band // 2] = 1.0 + + canvas = cv2.blur(canvas, (_height_band, _width_band)) + + canvas = cv2.resize(canvas, (width, height)) + + return canvas + + +def get_mask_bbox(mask, threshold=127): + ret, mask = cv2.threshold(mask, threshold, 1, 0) + + if cv2.countNonZero(mask) == 0: + return [None, None, None, None] + + col_acc = np.sum(mask, 0) + row_acc = np.sum(mask, 1) + + col_acc = col_acc.tolist() + row_acc = row_acc.tolist() + + for x in range(len(col_acc)): + if col_acc[x] > 0: + left = x + break + + for x in range(1, len(col_acc)): + if col_acc[-x] > 0: + right = len(col_acc) - x + break + + for x in range(len(row_acc)): + if row_acc[x] > 0: + top = x + break + + for x in range(1, len(row_acc)): + if row_acc[-x] > 0: + bottom = len(row_acc[::-1]) - x + break + return [top, bottom, left, right] + + +def visualize_flow(flow): + h, w = flow.shape[:2] + hsv = np.zeros((h, w, 3), np.uint8) + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + hsv[..., 2] = 255 + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + bgr = bgr * 1.0 / 255 + return bgr.astype(np.float32) + + +def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1): + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + abandon_idxs = [0, 1, 14, 15, 16, 17] + # draw joints + for i, joint in enumerate(joints): + if i in abandon_idxs: + continue + if joint[-1] > confidence_threshold: + + cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) + if show_text: + cv2.putText(image, + str(i) + '[{:.2f}]'.format(joint[-1]), + (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + # draw link + for pair in part_orders: + if joints[pair[0]][-1] > confidence_threshold and joints[ + pair[1]][-1] > confidence_threshold: + cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])), + (int(joints[pair[1]][0]), int(joints[pair[1]][1])), color, + 2) + return image + + +def get_heatmap_cv(img, magn, max_flow_mag): + min_flow_mag = .5 + cv_magn = np.clip( + 255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7), + a_min=0, + a_max=255).astype(np.uint8) + if img.dtype != np.uint8: + img = (255 * img).astype(np.uint8) + + heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET) + heatmap_img = heatmap_img[..., ::-1] + + h, w = magn.shape + img_alpha = np.ones((h, w), dtype=np.double)[:, :, None] + heatmap_alpha = np.clip( + magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7 + heatmap_alpha[heatmap_alpha < .2]**.5 + pm_hm = heatmap_img * heatmap_alpha + pm_img = img * img_alpha + cv_out = pm_hm + pm_img * (1 - heatmap_alpha) + cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8) + + return cv_out + + +def save_heatmap_cv(img, flow, supression=2): + + flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) + flow_magn -= supression + flow_magn[flow_magn <= 0] = 0 + cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3) + return cv_out + + +@numba.jit(nopython=True, parallel=False) +def bilinear_interp(x, y, v11, v12, v21, v22): + temp1 = (v11 * (1 - y) + v12 * y) * (1 - x) + temp2 = (v21 * (1 - y) + v22 * y) * x + result = temp1 + temp2 + return result + + +@numba.jit(nopython=True, parallel=False) +def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand, + height_expand): + srcW = oriImg.shape[1] + srcH = oriImg.shape[0] + + newImg = oriImg.copy() + + for i in range(srcH): + for j in range(srcW): + _i = i + _j = j + + deltaX = rDx[_i, _j] + deltaY = rDy[_i, _j] + + nx = _j + deltaX * transRatio + ny = _i + deltaY * transRatio + + if nx >= srcW - width_expand - 1: + if nx > srcW - 1: + nx = srcW - 1 + + if ny >= srcH - height_expand - 1: + if ny > srcH - 1: + ny = srcH - 1 + + if nx < width_expand: + if nx < 0: + nx = 0 + + if ny < height_expand: + if ny < 0: + ny = 0 + + nxi = int(math.floor(nx)) + nyi = int(math.floor(ny)) + nxi1 = int(math.ceil(nx)) + nyi1 = int(math.ceil(ny)) + + for ll in range(3): + newImg[_i, _j, + ll] = bilinear_interp(ny - nyi, nx - nxi, + oriImg[nyi, nxi, + ll], oriImg[nyi, nxi1, ll], + oriImg[nyi1, nxi, + ll], oriImg[nyi1, nxi1, + ll]) + return newImg diff --git a/modelscope/models/cv/image_classification/__init__.py b/modelscope/models/cv/image_classification/__init__.py new file mode 100644 index 00000000..7afe44bb --- /dev/null +++ b/modelscope/models/cv/image_classification/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .mmcls_model import ClassificationModel + +else: + _import_structure = { + 'mmcls_model': ['ClassificationModel'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_classification/mmcls_model.py b/modelscope/models/cv/image_classification/mmcls_model.py new file mode 100644 index 00000000..a6789d0b --- /dev/null +++ b/modelscope/models/cv/image_classification/mmcls_model.py @@ -0,0 +1,42 @@ +import os + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.image_classification, module_name=Models.classification_model) +class ClassificationModel(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + import mmcv + from mmcls.models import build_classifier + + super().__init__(model_dir) + + config = os.path.join(model_dir, 'config.py') + + cfg = mmcv.Config.fromfile(config) + cfg.model.pretrained = None + self.cls_model = build_classifier(cfg.model) + + self.cfg = cfg + self.ms_model_dir = model_dir + + self.load_pretrained_checkpoint() + + def forward(self, inputs): + + return self.cls_model(**inputs) + + def load_pretrained_checkpoint(self): + import mmcv + checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth') + if os.path.exists(checkpoint_path): + checkpoint = mmcv.runner.load_checkpoint( + self.cls_model, checkpoint_path, map_location='cpu') + if 'CLASSES' in checkpoint.get('meta', {}): + self.cls_model.CLASSES = checkpoint['meta']['CLASSES'] + self.CLASSES = self.cls_model.CLASSES diff --git a/modelscope/models/cv/image_color_enhance/__init__.py b/modelscope/models/cv/image_color_enhance/__init__.py new file mode 100644 index 00000000..72f26b52 --- /dev/null +++ b/modelscope/models/cv/image_color_enhance/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_color_enhance import ImageColorEnhance + +else: + _import_structure = { + 'image_color_enhance': ['ImageColorEnhance'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_color_enhance/csrnet.py b/modelscope/models/cv/image_color_enhance/csrnet.py new file mode 100644 index 00000000..502abf88 --- /dev/null +++ b/modelscope/models/cv/image_color_enhance/csrnet.py @@ -0,0 +1,113 @@ +# The implementation is adopted from Jingwen He, +# made publicly available at https://github.com/hejingwenhejingwen/CSRNet + +import functools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Condition(nn.Module): + + def __init__(self, in_nc=3, nf=32): + super(Condition, self).__init__() + stride = 2 + pad = 0 + self.pad = nn.ZeroPad2d(1) + self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) + self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + conv1_out = self.act(self.conv1(self.pad(x))) + conv2_out = self.act(self.conv2(self.pad(conv1_out))) + conv3_out = self.act(self.conv3(self.pad(conv2_out))) + out = torch.mean(conv3_out, dim=[2, 3], keepdim=False) + + return out + + +# 3layers with control +class CSRNet(nn.Module): + + def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32): + super(CSRNet, self).__init__() + + self.base_nf = base_nf + self.out_nc = out_nc + + self.cond_net = Condition(in_nc=in_nc, nf=cond_nf) + + self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True) + + self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True) + + self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) + self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) + self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + cond = self.cond_net(x) + + scale1 = self.cond_scale1(cond) + shift1 = self.cond_shift1(cond) + + scale2 = self.cond_scale2(cond) + shift2 = self.cond_shift2(cond) + + scale3 = self.cond_scale3(cond) + shift3 = self.cond_shift3(cond) + + out = self.conv1(x) + out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view( + -1, self.base_nf, 1, 1) + out + out = self.act(out) + + out = self.conv2(out) + out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view( + -1, self.base_nf, 1, 1) + out + out = self.act(out) + + out = self.conv3(out) + out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view( + -1, self.out_nc, 1, 1) + out + return out + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' + ) + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * F.l1_loss( + pred, target, reduction=self.reduction) diff --git a/modelscope/models/cv/image_color_enhance/image_color_enhance.py b/modelscope/models/cv/image_color_enhance/image_color_enhance.py new file mode 100644 index 00000000..0bd74197 --- /dev/null +++ b/modelscope/models/cv/image_color_enhance/image_color_enhance.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from copy import deepcopy +from typing import Dict, Union + +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .csrnet import CSRNet, L1Loss + +logger = get_logger() + +__all__ = ['ImageColorEnhance'] + + +@MODELS.register_module( + Tasks.image_color_enhancement, module_name=Models.csrnet) +class ImageColorEnhance(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image color enhance model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.loss = L1Loss() + self.model = CSRNet() + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device) + + self.model = self.load_pretrained(self.model, model_path) + + if self.training: + self.model.train() + else: + self.model.eval() + + def load_pretrained(self, net, load_path, strict=True, param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _evaluate_postprocess(self, src: Tensor, + target: Tensor) -> Dict[str, list]: + preds = self.model(src) + preds = list(torch.split(preds, 1, 0)) + targets = list(torch.split(target, 1, 0)) + + preds = [(pred.data * 255.).squeeze(0).type(torch.uint8).permute( + 1, 2, 0).cpu().numpy() for pred in preds] + targets = [(target.data * 255.).squeeze(0).type(torch.uint8).permute( + 1, 2, 0).cpu().numpy() for target in targets] + + return {'pred': preds, 'target': targets} + + def _train_forward(self, src: Tensor, target: Tensor) -> Dict[str, Tensor]: + preds = self.model(src) + return {'loss': self.loss(preds, target)} + + def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: + return {'outputs': self.model(src).clamp(0, 1)} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + if self.training: + return self._train_forward(**input) + elif 'target' in input: + return self._evaluate_postprocess(**input) + else: + return self._inference_forward(**input) diff --git a/modelscope/models/cv/image_colorization/__init__.py b/modelscope/models/cv/image_colorization/__init__.py new file mode 100644 index 00000000..9dbb07a5 --- /dev/null +++ b/modelscope/models/cv/image_colorization/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .unet import DynamicUnetWide, DynamicUnetDeep + from .utils import NormType + +else: + _import_structure = { + 'unet': ['DynamicUnetWide', 'DynamicUnetDeep'], + 'utils': ['NormType'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_colorization/unet.py b/modelscope/models/cv/image_colorization/unet.py new file mode 100644 index 00000000..19f6ab62 --- /dev/null +++ b/modelscope/models/cv/image_colorization/unet.py @@ -0,0 +1,302 @@ +# The implementation here is modified based on DeOldify, originally MIT License +# and publicly available at https://github.com/jantic/DeOldify/blob/master/deoldify/unet.py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .utils import (MergeLayer, NormType, PixelShuffle_ICNR, SelfAttention, + SequentialEx, SigmoidRange, dummy_eval, hook_outputs, + in_channels, model_sizes, relu, res_block) + +__all__ = ['DynamicUnetDeep', 'DynamicUnetWide'] + + +def custom_conv_layer( + ni, + nf, + ks=3, + stride=1, + padding=None, + bias=None, + is_1d=False, + norm_type=NormType.Batch, + use_activ=True, + leaky=None, + transpose=False, + init=nn.init.kaiming_normal_, + self_attention=False, + extra_bn=False, +): + 'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.' + if padding is None: + padding = (ks - 1) // 2 if not transpose else 0 + bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn is True + if bias is None: + bias = not bn + conv_func = nn.ConvTranspose2d if transpose is True else nn.Conv1d + conv_func = conv_func if is_1d else nn.Conv2d + conv = conv_func( + ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding) + if norm_type == NormType.Weight: + conv = weight_norm(conv) + elif norm_type == NormType.Spectral: + conv = spectral_norm(conv) + + layers = [conv] + if use_activ: + layers.append(relu(True, leaky=leaky)) + if bn: + layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf)) + if self_attention: + layers.append(SelfAttention(nf)) + return nn.Sequential(*layers) + + +def _get_sfs_idxs(sizes): + 'Get the indexes of the layers where the size of the activation changes.' + feature_szs = [size[-1] for size in sizes] + sfs_idxs = list( + np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]) + if feature_szs[0] != feature_szs[1]: + sfs_idxs = [0] + sfs_idxs + return sfs_idxs + + +class CustomPixelShuffle_ICNR(nn.Module): + 'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.' + + def __init__(self, ni, nf=None, scale=2, blur=False, leaky=None, **kwargs): + super().__init__() + nf = ni if nf is None else nf + self.conv = custom_conv_layer( + ni, nf * (scale**2), ks=1, use_activ=False, **kwargs) + self.shuf = nn.PixelShuffle(scale) + # Blurring over (h*w) kernel + # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts" + # - https://arxiv.org/abs/1806.02658 + self.pad = nn.ReplicationPad2d((1, 0, 1, 0)) + self.blur = nn.AvgPool2d(2, stride=1) + self.relu = relu(True, leaky=leaky) + + def forward(self, x): + x = self.shuf(self.relu(self.conv(x))) + return self.blur(self.pad(x)) if self.blur else x + + +class UnetBlockDeep(nn.Module): + 'A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.' + + def __init__(self, + up_in_c, + x_in_c, + hook, + final_div=True, + blur=False, + leaky=None, + self_attention=False, + nf_factor=1.0, + **kwargs): + super().__init__() + self.hook = hook + self.shuf = CustomPixelShuffle_ICNR( + up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs) + self.bn = nn.BatchNorm2d(x_in_c) + ni = up_in_c // 2 + x_in_c + nf = int((ni if final_div else ni // 2) * nf_factor) + self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs) + self.conv2 = custom_conv_layer( + nf, nf, leaky=leaky, self_attention=self_attention, **kwargs) + self.relu = relu(leaky=leaky) + + def forward(self, up_in): + s = self.hook.stored + up_out = self.shuf(up_in) + ssh = s.shape[-2:] + if ssh != up_out.shape[-2:]: + up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest') + cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1)) + return self.conv2(self.conv1(cat_x)) + + +class DynamicUnetDeep(SequentialEx): + 'Create a U-Net from a given architecture.' + + def __init__(self, + encoder, + n_classes, + blur=False, + blur_final=True, + self_attention=False, + y_range=None, + last_cross=True, + bottle=False, + norm_type=NormType.Batch, + nf_factor=1.0, + **kwargs): + extra_bn = norm_type == NormType.Spectral + imsize = (256, 256) + sfs_szs = model_sizes(encoder, size=imsize) + sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs))) + self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False) + x = dummy_eval(encoder, imsize).detach() + + ni = sfs_szs[-1][1] + middle_conv = nn.Sequential( + custom_conv_layer( + ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + custom_conv_layer( + ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + ).eval() + x = middle_conv(x) + layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv] + + for i, idx in enumerate(sfs_idxs): + not_final = i != len(sfs_idxs) - 1 + up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1]) + sa = self_attention and (i == len(sfs_idxs) - 3) + unet_block = UnetBlockDeep( + up_in_c, + x_in_c, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + nf_factor=nf_factor, + **kwargs).eval() + layers.append(unet_block) + x = unet_block(x) + + ni = x.shape[1] + if imsize != sfs_szs[0][-2:]: + layers.append(PixelShuffle_ICNR(ni, **kwargs)) + if last_cross: + layers.append(MergeLayer(dense=True)) + ni += in_channels(encoder) + layers.append( + res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) + layers += [ + custom_conv_layer( + ni, n_classes, ks=1, use_activ=False, norm_type=norm_type) + ] + if y_range is not None: + layers.append(SigmoidRange(*y_range)) + super().__init__(*layers) + + def __del__(self): + if hasattr(self, 'sfs'): + self.sfs.remove() + + +# ------------------------------------------------------ +class UnetBlockWide(nn.Module): + 'A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.' + + def __init__(self, + up_in_c, + x_in_c, + n_out, + hook, + final_div=True, + blur=False, + leaky=None, + self_attention=False, + **kwargs): + super().__init__() + self.hook = hook + up_out = x_out = n_out // 2 + self.shuf = CustomPixelShuffle_ICNR( + up_in_c, up_out, blur=blur, leaky=leaky, **kwargs) + self.bn = nn.BatchNorm2d(x_in_c) + ni = up_out + x_in_c + self.conv = custom_conv_layer( + ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs) + self.relu = relu(leaky=leaky) + + def forward(self, up_in): + s = self.hook.stored + up_out = self.shuf(up_in) + ssh = s.shape[-2:] + if ssh != up_out.shape[-2:]: + up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest') + cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1)) + return self.conv(cat_x) + + +class DynamicUnetWide(SequentialEx): + 'Create a U-Net from a given architecture.' + + def __init__(self, + encoder, + n_classes, + blur=False, + blur_final=True, + self_attention=False, + y_range=None, + last_cross=True, + bottle=False, + norm_type=NormType.Batch, + nf_factor=1, + **kwargs): + + nf = 512 * nf_factor + extra_bn = norm_type == NormType.Spectral + imsize = (256, 256) + sfs_szs = model_sizes(encoder, size=imsize) + sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs))) + self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False) + x = dummy_eval(encoder, imsize).detach() + + ni = sfs_szs[-1][1] + middle_conv = nn.Sequential( + custom_conv_layer( + ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + custom_conv_layer( + ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + ).eval() + x = middle_conv(x) + layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv] + + for i, idx in enumerate(sfs_idxs): + not_final = i != len(sfs_idxs) - 1 + up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1]) + sa = self_attention and (i == len(sfs_idxs) - 3) + + n_out = nf if not_final else nf // 2 + + unet_block = UnetBlockWide( + up_in_c, + x_in_c, + n_out, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + **kwargs).eval() + layers.append(unet_block) + x = unet_block(x) + + ni = x.shape[1] + if imsize != sfs_szs[0][-2:]: + layers.append(PixelShuffle_ICNR(ni, **kwargs)) + if last_cross: + layers.append(MergeLayer(dense=True)) + ni += in_channels(encoder) + layers.append( + res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) + layers += [ + custom_conv_layer( + ni, n_classes, ks=1, use_activ=False, norm_type=norm_type) + ] + if y_range is not None: + layers.append(SigmoidRange(*y_range)) + super().__init__(*layers) + + def __del__(self): + if hasattr(self, 'sfs'): + self.sfs.remove() diff --git a/modelscope/models/cv/image_colorization/utils.py b/modelscope/models/cv/image_colorization/utils.py new file mode 100644 index 00000000..b8968aa0 --- /dev/null +++ b/modelscope/models/cv/image_colorization/utils.py @@ -0,0 +1,350 @@ +# The implementation here is modified based on DeOldify, originally MIT License and +# publicly available at https://github.com/jantic/DeOldify/blob/master/fastai/callbacks/hooks.py +import functools +from enum import Enum + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm, weight_norm + +NormType = Enum('NormType', + 'Batch BatchZero Weight Spectral Group Instance SpectralGN') + + +def is_listy(x): + return isinstance(x, (tuple, list)) + + +class Hook(): + 'Create a hook on `m` with `hook_func`.' + + def __init__(self, m, hook_func, is_forward=True, detach=True): + self.hook_func, self.detach, self.stored = hook_func, detach, None + f = m.register_forward_hook if is_forward else m.register_backward_hook + self.hook = f(self.hook_fn) + self.removed = False + + def hook_fn(self, module, input, output): + 'Applies `hook_func` to `module`, `input`, `output`.' + if self.detach: + input = (o.detach() + for o in input) if is_listy(input) else input.detach() + output = ( + o.detach() + for o in output) if is_listy(output) else output.detach() + self.stored = self.hook_func(module, input, output) + + def remove(self): + 'Remove the hook from the model.' + if not self.removed: + self.hook.remove() + self.removed = True + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +class Hooks(): + 'Create several hooks on the modules in `ms` with `hook_func`.' + + def __init__(self, ms, hook_func, is_forward=True, detach=True): + self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms] + + def __getitem__(self, i): + return self.hooks[i] + + def __len__(self): + return len(self.hooks) + + def __iter__(self): + return iter(self.hooks) + + @property + def stored(self): + return [o.stored for o in self] + + def remove(self): + 'Remove the hooks from the model.' + for h in self.hooks: + h.remove() + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +def _hook_inner(m, i, o): + return o if isinstance(o, torch.Tensor) else o if is_listy(o) else list(o) + + +def hook_outputs(modules, detach=True, grad=False): + 'Return `Hooks` that store activations of all `modules` in `self.stored`' + return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) + + +def one_param(m): + 'Return the first parameter of `m`.' + return next(m.parameters()) + + +def dummy_batch(m, size=(64, 64)): + 'Create a dummy batch to go through `m` with `size`.' + ch_in = in_channels(m) + return one_param(m).new(1, ch_in, + *size).requires_grad_(False).uniform_(-1., 1.) + + +def dummy_eval(m, size=(64, 64)): + 'Pass a `dummy_batch` in evaluation mode in `m` with `size`.' + return m.eval()(dummy_batch(m, size)) + + +def model_sizes(m, size=(64, 64)): + 'Pass a dummy input through the model `m` to get the various sizes of activations.' + with hook_outputs(m) as hooks: + dummy_eval(m, size) + return [o.stored.shape for o in hooks] + + +class PrePostInitMeta(type): + 'A metaclass that calls optional `__pre_init__` and `__post_init__` methods' + + def __new__(cls, name, bases, dct): + x = super().__new__(cls, name, bases, dct) + old_init = x.__init__ + + def _pass(self): + pass + + @functools.wraps(old_init) + def _init(self, *args, **kwargs): + self.__pre_init__() + old_init(self, *args, **kwargs) + self.__post_init__() + + x.__init__ = _init + if not hasattr(x, '__pre_init__'): + x.__pre_init__ = _pass + if not hasattr(x, '__post_init__'): + x.__post_init__ = _pass + return x + + +class Module(nn.Module, metaclass=PrePostInitMeta): + 'Same as `nn.Module`, but no need for subclasses to call `super().__init__`' + + def __pre_init__(self): + super().__init__() + + def __init__(self): + pass + + +def children(m): + 'Get children of `m`.' + return list(m.children()) + + +def num_children(m): + 'Get number of children modules in `m`.' + return len(children(m)) + + +def children_and_parameters(m: nn.Module): + 'Return the children of `m` and its direct parameters not registered in modules.' + children = list(m.children()) + children_p = sum([[id(p) for p in c.parameters()] for c in m.children()], + []) + for p in m.parameters(): + if id(p) not in children_p: + children.append(ParameterModule(p)) + return children + + +def flatten_model(m): + if num_children(m): + mapped = map(flatten_model, children_and_parameters(m)) + return sum(mapped, []) + else: + return [m] + + +def in_channels(m): + 'Return the shape of the first weight layer in `m`.' + for layer in flatten_model(m): + if hasattr(layer, 'weight'): + return layer.weight.shape[1] + raise Exception('No weight layer') + + +def relu(inplace: bool = False, leaky: float = None): + 'Return a relu activation, maybe `leaky` and `inplace`.' + return nn.LeakyReLU( + inplace=inplace, + negative_slope=leaky) if leaky is not None else nn.ReLU( + inplace=inplace) + + +def conv_layer(ni, + nf, + ks=3, + stride=1, + padding=None, + bias=None, + is_1d=False, + norm_type=NormType.Batch, + use_activ=True, + leaky=None, + transpose=False, + init=nn.init.kaiming_normal_, + self_attention=False): + 'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.' + if padding is None: + padding = (ks - 1) // 2 if not transpose else 0 + bn = norm_type in (NormType.Batch, NormType.BatchZero) + if bias is None: + bias = not bn + conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d + conv = conv_func( + ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding) + if norm_type == NormType.Weight: + conv = weight_norm(conv) + elif norm_type == NormType.Spectral: + conv = spectral_norm(conv) + layers = [conv] + if use_activ: + layers.append(relu(True, leaky=leaky)) + if bn: + layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf)) + if self_attention: + layers.append(SelfAttention(nf)) + return nn.Sequential(*layers) + + +def res_block(nf, + dense=False, + norm_type=NormType.Batch, + bottle=False, + **conv_kwargs): + 'Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`.' + norm2 = norm_type + if not dense and (norm_type == NormType.Batch): + norm2 = NormType.BatchZero + nf_inner = nf // 2 if bottle else nf + return SequentialEx( + conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs), + conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), + MergeLayer(dense)) + + +def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False): + 'Create and initialize a `nn.Conv1d` layer with spectral normalization.' + conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) + nn.init.kaiming_normal_(conv.weight) + if bias: + conv.bias.data.zero_() + return spectral_norm(conv) + + +class SelfAttention(Module): + 'Self attention layer for nd.' + + def __init__(self, n_channels): + self.query = conv1d(n_channels, n_channels // 8) + self.key = conv1d(n_channels, n_channels // 8) + self.value = conv1d(n_channels, n_channels) + self.gamma = nn.Parameter(torch.tensor([0.])) + + def forward(self, x): + 'Notation from https://arxiv.org/pdf/1805.08318.pdf' + size = x.size() + x = x.view(*size[:2], -1) + f, g, h = self.query(x), self.key(x), self.value(x) + beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1) + o = self.gamma * torch.bmm(h, beta) + x + return o.view(*size).contiguous() + + +def sigmoid_range(x, low, high): + 'Sigmoid function with range `(low, high)`' + return torch.sigmoid(x) * (high - low) + low + + +class SigmoidRange(Module): + 'Sigmoid module with range `(low,x_max)`' + + def __init__(self, low, high): + self.low, self.high = low, high + + def forward(self, x): + return sigmoid_range(x, self.low, self.high) + + +class SequentialEx(Module): + 'Like `nn.Sequential`, but with ModuleList semantics, and can access module input' + + def __init__(self, *layers): + self.layers = nn.ModuleList(layers) + + def forward(self, x): + res = x + for layer in self.layers: + res.orig = x + nres = layer(res) + res.orig = None + res = nres + return res + + def __getitem__(self, i): + return self.layers[i] + + def append(self, layer): + return self.layers.append(layer) + + def extend(self, layer): + return self.layers.extend(layer) + + def insert(self, i, layer): + return self.layers.insert(i, layer) + + +class MergeLayer(Module): + 'Merge a shortcut with the result of the module by adding them or concatenating thme if `dense=True`.' + + def __init__(self, dense: bool = False): + self.dense = dense + + def forward(self, x): + return torch.cat([x, x.orig], dim=1) if self.dense else (x + x.orig) + + +class PixelShuffle_ICNR(Module): + 'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.' + + def __init__(self, + ni: int, + nf: int = None, + scale: int = 2, + blur: bool = False, + norm_type=NormType.Weight, + leaky: float = None): + nf = ni if nf is None else nf + self.conv = conv_layer( + ni, nf * (scale**2), ks=1, norm_type=norm_type, use_activ=False) + self.shuf = nn.PixelShuffle(scale) + # Blurring over (h*w) kernel + # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts" + # - https://arxiv.org/abs/1806.02658 + self.pad = nn.ReplicationPad2d((1, 0, 1, 0)) + self.blur = nn.AvgPool2d(2, stride=1) + self.relu = relu(True, leaky=leaky) + + def forward(self, x): + x = self.shuf(self.relu(self.conv(x))) + return self.blur(self.pad(x)) if self.blur else x diff --git a/modelscope/models/cv/image_denoise/__init__.py b/modelscope/models/cv/image_denoise/__init__.py new file mode 100644 index 00000000..aa925daf --- /dev/null +++ b/modelscope/models/cv/image_denoise/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .nafnet_for_image_denoise import NAFNetForImageDenoise + +else: + _import_structure = {'nafnet_for_image_denoise': ['NAFNetForImageDenoise']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py new file mode 100644 index 00000000..c4de0729 --- /dev/null +++ b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py @@ -0,0 +1,238 @@ +# ------------------------------------------------------------------------ +# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .arch_util import LayerNorm2d + + +class SimpleGate(nn.Module): + + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class NAFBlock(nn.Module): + + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=dw_channel // 2, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter( + torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class NAFNet(nn.Module): + + def __init__(self, + img_channel=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[]): + super().__init__() + + self.intro = nn.Conv2d( + in_channels=img_channel, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=img_channel, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) + self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2)) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2))) + chan = chan // 2 + self.decoders.append( + nn.Sequential(*[NAFBlock(chan) for _ in range(num)])) + + self.padder_size = 2**len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size + - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size + - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +class PSNRLoss(nn.Module): + + def __init__(self, loss_weight=1.0, reduction='mean', toY=False): + super(PSNRLoss, self).__init__() + assert reduction == 'mean' + self.loss_weight = loss_weight + self.scale = 10 / np.log(10) + self.toY = toY + self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) + self.first = True + + def forward(self, pred, target): + assert len(pred.size()) == 4 + if self.toY: + if self.first: + self.coef = self.coef.to(pred.device) + self.first = False + + pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + + pred, target = pred / 255., target / 255. + pass + assert len(pred.size()) == 4 + + return self.loss_weight * self.scale * torch.log(( + (pred - target)**2).mean(dim=(1, 2, 3)) + 1e-8).mean() diff --git a/modelscope/models/cv/image_denoise/nafnet/__init__.py b/modelscope/models/cv/image_denoise/nafnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_denoise/nafnet/arch_util.py b/modelscope/models/cv/image_denoise/nafnet/arch_util.py new file mode 100644 index 00000000..2d406141 --- /dev/null +++ b/modelscope/models/cv/image_denoise/nafnet/arch_util.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum( + dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None + + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) diff --git a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py new file mode 100644 index 00000000..4e8fc0ed --- /dev/null +++ b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from copy import deepcopy +from typing import Any, Dict, Union + +import torch.cuda +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .nafnet.NAFNet_arch import NAFNet, PSNRLoss + +logger = get_logger() +__all__ = ['NAFNetForImageDenoise'] + + +@MODELS.register_module(Tasks.image_denoising, module_name=Models.nafnet) +class NAFNetForImageDenoise(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image denoise model from the `model_dir` path. + + Args: + model_dir (str): the model path. + + """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + self.model = NAFNet(**self.config.model.network_g) + self.loss = PSNRLoss() + self.model = self._load_pretrained(self.model, model_path) + + def _load_pretrained(self, + net, + load_path, + strict=True, + param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _train_forward(self, input: Tensor, + target: Tensor) -> Dict[str, Tensor]: + preds = self.model(input) + return {'loss': self.loss(preds, target)} + + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + return {'outputs': self.model(input).clamp(0, 1)} + + def _evaluate_postprocess(self, input: Tensor, + target: Tensor) -> Dict[str, list]: + preds = self.model(input) + preds = list(torch.split(preds.clamp(0, 1), 1, 0)) + targets = list(torch.split(target.clamp(0, 1), 1, 0)) + + return {'pred': preds, 'target': targets} + + def forward(self, inputs: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + inputs (Tensor): the preprocessed data + + Returns: + Dict[str, Tensor]: results + """ + if self.training: + return self._train_forward(**inputs) + elif 'target' in inputs: + return self._evaluate_postprocess(**inputs) + else: + return self._inference_forward(**inputs) diff --git a/modelscope/models/cv/image_inpainting/__init__.py b/modelscope/models/cv/image_inpainting/__init__.py new file mode 100644 index 00000000..e7c63cd4 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import FFTInpainting + +else: + _import_structure = { + 'model': ['FFTInpainting'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_inpainting/base.py b/modelscope/models/cv/image_inpainting/base.py new file mode 100644 index 00000000..04e73630 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/base.py @@ -0,0 +1,75 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .modules.adversarial import NonSaturatingWithR1 +from .modules.ffc import FFCResNetGenerator +from .modules.perceptual import ResNetPL +from .modules.pix2pixhd import NLayerDiscriminator + +LOGGER = get_logger() + + +class BaseInpaintingTrainingModule(nn.Module): + + def __init__(self, + model_dir='', + use_ddp=True, + predict_only=False, + visualize_each_iters=100, + average_generator=False, + generator_avg_beta=0.999, + average_generator_start_step=30000, + average_generator_period=10, + store_discr_outputs_for_vis=False, + **kwargs): + super().__init__() + LOGGER.info( + f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}' + ) + + self.generator = FFCResNetGenerator() + self.use_ddp = use_ddp + + if not predict_only: + self.discriminator = NLayerDiscriminator() + self.adversarial_loss = NonSaturatingWithR1( + weight=10, + gp_coef=0.001, + mask_as_fake_target=True, + allow_scale_mask=True) + + self.average_generator = average_generator + self.generator_avg_beta = generator_avg_beta + self.average_generator_start_step = average_generator_start_step + self.average_generator_period = average_generator_period + self.generator_average = None + self.last_generator_averaging_step = -1 + self.store_discr_outputs_for_vis = store_discr_outputs_for_vis + + self.loss_l1 = nn.L1Loss(reduction='none') + + self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir) + + self.visualize_each_iters = visualize_each_iters + LOGGER.info('BaseInpaintingTrainingModule init done') + + def forward(self, batch: Dict[str, + torch.Tensor]) -> Dict[str, torch.Tensor]: + """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" + raise NotImplementedError() + + def generator_loss(self, + batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def discriminator_loss( + self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() diff --git a/modelscope/models/cv/image_inpainting/default.py b/modelscope/models/cv/image_inpainting/default.py new file mode 100644 index 00000000..5f57d63f --- /dev/null +++ b/modelscope/models/cv/image_inpainting/default.py @@ -0,0 +1,210 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import bisect + +import torch +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +from .base import BaseInpaintingTrainingModule +from .modules.feature_matching import feature_matching_loss, masked_l1_loss + +LOGGER = get_logger() + + +def set_requires_grad(module, value): + for param in module.parameters(): + param.requires_grad = value + + +def add_prefix_to_keys(dct, prefix): + return {prefix + k: v for k, v in dct.items()} + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class LadderRamp: + + def __init__(self, start_iters, values): + self.start_iters = start_iters + self.values = values + assert len(values) == len(start_iters) + 1, (len(values), + len(start_iters)) + + def __call__(self, i): + segment_i = bisect.bisect_right(self.start_iters, i) + return self.values[segment_i] + + +def get_ramp(kind='ladder', **kwargs): + if kind == 'linear': + return LinearRamp(**kwargs) + if kind == 'ladder': + return LadderRamp(**kwargs) + raise ValueError(f'Unexpected ramp kind: {kind}') + + +class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): + + def __init__(self, + model_dir='', + predict_only=False, + concat_mask=True, + rescale_scheduler_kwargs=None, + image_to_discriminator='predicted_image', + add_noise_kwargs=None, + noise_fill_hole=False, + const_area_crop_kwargs=None, + distance_weighter_kwargs=None, + distance_weighted_mask_for_discr=False, + fake_fakes_proba=0, + fake_fakes_generator_kwargs=None, + **kwargs): + super().__init__(model_dir=model_dir, predict_only=predict_only) + self.concat_mask = concat_mask + self.rescale_size_getter = get_ramp( + **rescale_scheduler_kwargs + ) if rescale_scheduler_kwargs is not None else None + self.image_to_discriminator = image_to_discriminator + self.add_noise_kwargs = add_noise_kwargs + self.noise_fill_hole = noise_fill_hole + self.const_area_crop_kwargs = const_area_crop_kwargs + self.refine_mask_for_losses = None + self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr + + self.feature_matching_weight = 100 + self.losses_l1_weight_known = 10 + self.losses_l1_weight_missing = 0 + self.fake_fakes_proba = fake_fakes_proba + + def forward(self, batch): + img = batch['image'] + mask = batch['mask'] + + masked_img = img * (1 - mask) + + if self.concat_mask: + masked_img = torch.cat([masked_img, mask], dim=1) + + batch['predicted_image'] = self.generator(masked_img) + batch['inpainted'] = mask * batch['predicted_image'] + ( + 1 - mask) * batch['image'] + + batch['mask_for_losses'] = mask + + return batch + + def generator_loss(self, batch): + img = batch['image'] + predicted_img = batch[self.image_to_discriminator] + original_mask = batch['mask'] + supervised_mask = batch['mask_for_losses'] + + # L1 + l1_value = masked_l1_loss(predicted_img, img, supervised_mask, + self.losses_l1_weight_known, + self.losses_l1_weight_missing) + + total_loss = l1_value + metrics = dict(gen_l1=l1_value) + + # discriminator + # adversarial_loss calls backward by itself + mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask + self.adversarial_loss.pre_generator_step( + real_batch=img, + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(img) + discr_fake_pred, discr_fake_features = self.discriminator( + predicted_img) + adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss( + real_batch=img, + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=mask_for_discr) + total_loss = total_loss + adv_gen_loss + metrics['gen_adv'] = adv_gen_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + # feature matching + if self.feature_matching_weight > 0: + need_mask_in_fm = False + mask_for_fm = supervised_mask if need_mask_in_fm else None + fm_value = feature_matching_loss( + discr_fake_features, discr_real_features, + mask=mask_for_fm) * self.feature_matching_weight + total_loss = total_loss + fm_value + metrics['gen_fm'] = fm_value + + if self.loss_resnet_pl is not None: + resnet_pl_value = self.loss_resnet_pl(predicted_img, img) + total_loss = total_loss + resnet_pl_value + metrics['gen_resnet_pl'] = resnet_pl_value + + return total_loss, metrics + + def discriminator_loss(self, batch): + total_loss = 0 + metrics = {} + + predicted_img = batch[self.image_to_discriminator].detach() + self.adversarial_loss.pre_discriminator_step( + real_batch=batch['image'], + fake_batch=predicted_img, + generator=self.generator, + discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator( + batch['image']) + discr_fake_pred, discr_fake_features = self.discriminator( + predicted_img) + adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss( + real_batch=batch['image'], + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=batch['mask']) + + total_loss = (total_loss + adv_discr_loss) * 0.1 + metrics['discr_adv'] = adv_discr_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + return total_loss, metrics + + def _do_step(self, batch, optimizer_idx=None): + if optimizer_idx == 0: # step for generator + set_requires_grad(self.generator, True) + set_requires_grad(self.discriminator, False) + elif optimizer_idx == 1: # step for discriminator + set_requires_grad(self.generator, False) + set_requires_grad(self.discriminator, True) + + batch = self(batch) + total_loss = 0 + if optimizer_idx is None or optimizer_idx == 0: # step for generator + total_loss, metrics = self.generator_loss(batch) + + elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator + total_loss, metrics = self.discriminator_loss(batch) + + result = dict(loss=total_loss) + return result diff --git a/modelscope/models/cv/image_inpainting/model.py b/modelscope/models/cv/image_inpainting/model.py new file mode 100644 index 00000000..b12f6edd --- /dev/null +++ b/modelscope/models/cv/image_inpainting/model.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +LOGGER = get_logger() + + +@MODELS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class FFTInpainting(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) + + from .default import DefaultInpaintingTrainingModule + pretrained = kwargs.get('pretrained', True) + predict_only = kwargs.get('predict_only', False) + net = DefaultInpaintingTrainingModule( + model_dir=model_dir, predict_only=predict_only) + if pretrained: + path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + LOGGER.info(f'loading pretrained model from {path}') + state = torch.load(path, map_location='cpu') + net.load_state_dict(state, strict=False) + self.model = net + + def forward(self, inputs): + return self.model(inputs) diff --git a/modelscope/models/cv/image_inpainting/modules/__init__.py b/modelscope/models/cv/image_inpainting/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py new file mode 100644 index 00000000..89c3e293 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .base import ModelBuilder diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/base.py b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py new file mode 100644 index 00000000..02bd3cc4 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/base.py @@ -0,0 +1,380 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules import BatchNorm2d + +from . import resnet + +NUM_CLASS = 150 + + +# Model Builder +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) + m.bias.data.fill_(1e-4) + + @staticmethod + def build_encoder(arch='resnet50dilated', + fc_dim=512, + weights='', + model_dir=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50']( + pretrained=pretrained, model_dir=model_dir) + net_encoder = Resnet(orig_resnet) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, + num_class=NUM_CLASS, + weights='', + use_softmax=False, + drop_last_conv=False): + arch = arch.lower() + if arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + elif arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), + strict=False) + return net_decoder + + @staticmethod + def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, + drop_last_conv, *arts, **kwargs): + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') + return ModelBuilder.build_decoder( + arch=arch_decoder, + fc_dim=fc_dim, + weights=path, + use_softmax=True, + drop_last_conv=drop_last_conv) + + @staticmethod + def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, + segmentation, *arts, **kwargs): + if segmentation: + path = os.path.join( + weights_path, 'ade20k', + f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') + else: + path = '' + return ModelBuilder.build_encoder( + arch=arch_encoder, + fc_dim=fc_dim, + weights=path, + model_dir=weights_path) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False), + BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + + def __init__(self, + num_class=NUM_CLASS, + fc_dim=4096, + use_softmax=False, + pool_scales=(1, 2, 3, 6), + drop_last_conv=False): + super().__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.ppm = [] + for scale in pool_scales: + self.ppm.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), nn.ReLU(inplace=True))) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d( + fc_dim + len(pool_scales) * 512, + 512, + kernel_size=3, + padding=1, + bias=False), BatchNorm2d(512), nn.ReLU(inplace=True), + nn.Dropout2d(0.1), nn.Conv2d(512, num_class, kernel_size=1)) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append( + nn.functional.interpolate( + pool_scale(conv5), (input_size[2], input_size[3]), + mode='bilinear', + align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + if self.drop_last_conv: + return ppm_out + else: + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +class Resnet(nn.Module): + + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# Resnet Dilated +class ResnetDilated(nn.Module): + + def __init__(self, orig_resnet, dilate_scale=8): + super().__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate // 2, dilate // 2) + m.padding = (dilate // 2, dilate // 2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + + def __init__(self, + num_class=150, + fc_dim=2048, + use_softmax=False, + drop_last_conv=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + + if self.drop_last_conv: + return x + else: + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py new file mode 100644 index 00000000..7da9ff07 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ade20k/resnet.py @@ -0,0 +1,183 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import math +import os + +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d + +__all__ = ['ResNet', 'resnet50'] + + +def conv3x3(in_planes, out_planes, stride=1): + '3x3 convolution with padding' + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet50(pretrained=False, model_dir='', **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + cached_file = os.path.join(model_dir, 'resnet50-imagenet.pth') + model.load_state_dict( + torch.load(cached_file, map_location='cpu'), strict=False) + return model diff --git a/modelscope/models/cv/image_inpainting/modules/adversarial.py b/modelscope/models/cv/image_inpainting/modules/adversarial.py new file mode 100644 index 00000000..b183876b --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/adversarial.py @@ -0,0 +1,167 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseAdversarialLoss: + + def pre_generator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for generator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + """ + Prepare for discriminator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate generator loss + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total generator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate discriminator loss and call .backward() on it + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total discriminator loss along with some values that might be interesting to log + """ + raise NotImplementedError + + def interpolate_mask(self, mask, shape): + assert mask is not None + assert self.allow_scale_mask or shape == mask.shape[-2:] + if shape != mask.shape[-2:] and self.allow_scale_mask: + if self.mask_scale_mode == 'maxpool': + mask = F.adaptive_max_pool2d(mask, shape) + else: + mask = F.interpolate( + mask, size=shape, mode=self.mask_scale_mode) + return mask + + +def make_r1_gp(discr_real_pred, real_batch): + if torch.is_grad_enabled(): + grad_real = torch.autograd.grad( + outputs=discr_real_pred.sum(), + inputs=real_batch, + create_graph=True)[0] + grad_penalty = (grad_real.view(grad_real.shape[0], + -1).norm(2, dim=1)**2).mean() + else: + grad_penalty = 0 + real_batch.requires_grad = False + + return grad_penalty + + +class NonSaturatingWithR1(BaseAdversarialLoss): + + def __init__(self, + gp_coef=5, + weight=1, + mask_as_fake_target=False, + allow_scale_mask=False, + mask_scale_mode='nearest', + extra_mask_weight_for_gen=0, + use_unmasked_for_gen=True, + use_unmasked_for_discr=True): + self.gp_coef = gp_coef + self.weight = weight + # use for discr => use for gen; + # otherwise we teach only the discr to pay attention to very small difference + assert use_unmasked_for_gen or (not use_unmasked_for_discr) + # mask as target => use unmasked for discr: + # if we don't care about unmasked regions at all + # then it doesn't matter if the value of mask_as_fake_target is true or false + assert use_unmasked_for_discr or (not mask_as_fake_target) + self.use_unmasked_for_gen = use_unmasked_for_gen + self.use_unmasked_for_discr = use_unmasked_for_discr + self.mask_as_fake_target = mask_as_fake_target + self.allow_scale_mask = allow_scale_mask + self.mask_scale_mode = mask_scale_mode + self.extra_mask_weight_for_gen = extra_mask_weight_for_gen + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + fake_loss = F.softplus(-discr_fake_pred) + if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ + not self.use_unmasked_for_gen: # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + if not self.use_unmasked_for_gen: + fake_loss = fake_loss * mask + else: + pixel_weights = 1 + mask * self.extra_mask_weight_for_gen + fake_loss = fake_loss * pixel_weights + + return fake_loss.mean() * self.weight, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, + fake_batch: torch.Tensor, generator: nn.Module, + discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_loss = F.softplus(-discr_real_pred) + grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef + fake_loss = F.softplus(discr_fake_pred) + + if not self.use_unmasked_for_discr or self.mask_as_fake_target: + # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + # use_unmasked_for_discr=False only makes sense for fakes; + # for reals there is no difference beetween two regions + fake_loss = fake_loss * mask + if self.mask_as_fake_target: + fake_loss = fake_loss + (1 + - mask) * F.softplus(-discr_fake_pred) + + sum_discr_loss = real_loss + grad_penalty + fake_loss + metrics = dict( + discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=grad_penalty) + return sum_discr_loss.mean(), metrics diff --git a/modelscope/models/cv/image_inpainting/modules/feature_matching.py b/modelscope/models/cv/image_inpainting/modules/feature_matching.py new file mode 100644 index 00000000..c2effb20 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/feature_matching.py @@ -0,0 +1,45 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +from typing import List + +import torch +import torch.nn.functional as F + + +def masked_l2_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l2 = F.mse_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l2).mean() + + +def masked_l1_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l1 = F.l1_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l1).mean() + + +def feature_matching_loss(fake_features: List[torch.Tensor], + target_features: List[torch.Tensor], + mask=None): + if mask is None: + res = torch.stack([ + F.mse_loss(fake_feat, target_feat) + for fake_feat, target_feat in zip(fake_features, target_features) + ]).mean() + else: + res = 0 + norm = 0 + for fake_feat, target_feat in zip(fake_features, target_features): + cur_mask = F.interpolate( + mask, + size=fake_feat.shape[-2:], + mode='bilinear', + align_corners=False) + error_weights = 1 - cur_mask + cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() + res = res + cur_val + norm += 1 + res = res / norm + return res diff --git a/modelscope/models/cv/image_inpainting/modules/ffc.py b/modelscope/models/cv/image_inpainting/modules/ffc.py new file mode 100644 index 00000000..c74425e3 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/ffc.py @@ -0,0 +1,588 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry.transform import rotate + + +def get_activation(kind='tanh'): + if kind == 'tanh': + return nn.Tanh() + if kind == 'sigmoid': + return nn.Sigmoid() + if kind is False: + return nn.Identity() + raise ValueError(f'Unknown activation kind {kind}') + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FourierUnit(nn.Module): + + def __init__(self, + in_channels, + out_channels, + groups=1, + spatial_scale_factor=None, + spatial_scale_mode='bilinear', + spectral_pos_encoding=False, + use_se=False, + se_kwargs=None, + ffc3d=False, + fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d( + in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + groups=self.groups, + bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate( + x, + scale_factor=self.spatial_scale_factor, + mode=self.spatial_scale_mode, + align_corners=False) + + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, + 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view(( + batch, + -1, + ) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, + height)[None, None, :, None].expand( + batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, + width)[None, None, None, :].expand( + batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view(( + batch, + -1, + 2, + ) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn( + ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate( + output, + size=orig_size, + mode=self.spatial_scale_mode, + align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + + def __init__(self, + in_channels, + out_channels, + stride=1, + groups=1, + enable_lfu=True, + **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=1, + groups=groups, + bias=False), nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True)) + self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, + **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit(out_channels // 2, out_channels // 2, + groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=1, + groups=groups, + bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat( + torch.split(x[:, :c // 4], split_s, dim=-2), + dim=1).contiguous() + xs = torch.cat( + torch.split(xs, split_s, dim=-1), dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class LearnableSpatialTransformWrapper(nn.Module): + + def __init__(self, + impl, + pad_coef=0.5, + angle_init_range=80, + train_angle=True): + super().__init__() + self.impl = impl + self.angle = torch.rand(1) * angle_init_range + if train_angle: + self.angle = nn.Parameter(self.angle, requires_grad=True) + self.pad_coef = pad_coef + + def forward(self, x): + if torch.is_tensor(x): + return self.inverse_transform(self.impl(self.transform(x)), x) + elif isinstance(x, tuple): + x_trans = tuple(self.transform(elem) for elem in x) + y_trans = self.impl(x_trans) + return tuple( + self.inverse_transform(elem, orig_x) + for elem, orig_x in zip(y_trans, x)) + else: + raise ValueError(f'Unexpected input type {type(x)}') + + def transform(self, x): + height, width = x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') + x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) + return x_padded_rotated + + def inverse_transform(self, y_padded_rotated, orig_x): + height, width = orig_x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + + y_padded = rotate( + y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) + y_height, y_width = y_padded.shape[2:] + y = y_padded[:, :, pad_h:y_height - pad_h, pad_w:y_width - pad_w] + return y + + +class FFC(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + enable_lfu=True, + padding_type='reflect', + gated=False, + **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, 'Stride should be 1 or 2.' + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module( + in_cl, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module( + in_cl, + out_cg, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module( + in_cg, + out_cl, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module(in_cg, out_cg, stride, + 1 if groups == 1 else groups // 2, enable_lfu, + **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_layer=nn.BatchNorm2d, + activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, + **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC( + in_channels, + out_channels, + kernel_size, + ratio_gin, + ratio_gout, + stride, + padding, + dilation, + groups, + bias, + enable_lfu, + padding_type=padding_type, + **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + + def __init__(self, + dim, + padding_type, + norm_layer, + activation_layer=nn.ReLU, + dilation=1, + spatial_transform_kwargs=None, + inline=False, + **conv_kwargs): + super().__init__() + self.conv1 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT( + dim, + dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper( + self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper( + self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc. + global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g)) + x_l, x_g = self.conv2((x_l, x_g)) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + + def __init__(self, + input_nc=4, + output_nc=3, + ngf=64, + n_downsampling=3, + n_blocks=18, + norm_layer=nn.BatchNorm2d, + padding_type='reflect', + activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, + up_activation=nn.ReLU(True), + init_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + downsample_conv_kwargs={ + 'ratio_gin': 0, + 'ratio_gout': 0, + 'enable_lfu': False + }, + resnet_conv_kwargs={ + 'ratio_gin': 0.75, + 'ratio_gout': 0.75, + 'enable_lfu': False + }, + spatial_transform_layers=None, + spatial_transform_kwargs={}, + add_out_act='sigmoid', + max_features=1024, + out_ffc=False, + out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [ + nn.ReflectionPad2d(3), + FFC_BN_ACT( + input_nc, + ngf, + kernel_size=7, + padding=0, + norm_layer=norm_layer, + activation_layer=activation_layer, + **init_conv_kwargs) + ] + + # downsample + for i in range(n_downsampling): + mult = 2**i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get( + 'ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [ + FFC_BN_ACT( + min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, + stride=2, + padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs) + ] + + mult = 2**n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + # resnet blocks + for i in range(n_blocks): + cur_resblock = FFCResnetBlock( + feats_num_bottleneck, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper( + cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + # upsample + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d( + min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation + ] + + if out_ffc: + model += [ + FFCResnetBlock( + ngf, + padding_type=padding_type, + activation_layer=activation_layer, + norm_layer=norm_layer, + inline=True, + **out_ffc_kwargs) + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) + ] + if add_out_act: + model.append( + get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) diff --git a/modelscope/models/cv/image_inpainting/modules/inception.py b/modelscope/models/cv/image_inpainting/modules/inception.py new file mode 100644 index 00000000..5070533d --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/inception.py @@ -0,0 +1,324 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +from modelscope.utils.logger import get_logger + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/' \ + 'fid_weights/pt_inception-2015-12-05-6726825d.pth' + +LOGGER = get_logger() + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + LOGGER.info('fid_inception_v3 called') + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + LOGGER.info('models.inception_v3 done') + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + LOGGER.info('fid_inception_v3 patching done') + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + LOGGER.info('fid_inception_v3 weights downloaded') + + inception.load_state_dict(state_dict) + LOGGER.info('fid_inception_v3 weights loaded into model') + + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/modelscope/models/cv/image_inpainting/modules/perceptual.py b/modelscope/models/cv/image_inpainting/modules/perceptual.py new file mode 100644 index 00000000..80fe2b96 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/perceptual.py @@ -0,0 +1,47 @@ +""" +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from .ade20k import ModelBuilder + +IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] +IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] + + +class ResNetPL(nn.Module): + + def __init__(self, + weight=1, + weights_path=None, + arch_encoder='resnet50dilated', + segmentation=True): + super().__init__() + self.impl = ModelBuilder.get_encoder( + weights_path=weights_path, + arch_encoder=arch_encoder, + arch_decoder='ppm_deepsup', + fc_dim=2048, + segmentation=segmentation) + self.impl.eval() + for w in self.impl.parameters(): + w.requires_grad_(False) + + self.weight = weight + + def forward(self, pred, target): + pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) + target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) + + pred_feats = self.impl(pred, return_feature_maps=True) + target_feats = self.impl(target, return_feature_maps=True) + + result = torch.stack([ + F.mse_loss(cur_pred, cur_target) + for cur_pred, cur_target in zip(pred_feats, target_feats) + ]).sum() * self.weight + return result diff --git a/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py new file mode 100644 index 00000000..32e18f3e --- /dev/null +++ b/modelscope/models/cv/image_inpainting/modules/pix2pixhd.py @@ -0,0 +1,75 @@ +""" +The implementation is adopted from +https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py +""" +import collections +import functools +import logging +from collections import defaultdict +from functools import partial + +import numpy as np +import torch.nn as nn + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + + def __init__( + self, + input_nc=3, + ndf=64, + n_layers=4, + norm_layer=nn.BatchNorm2d, + ): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + sequence = [[ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[ + nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) + ]] + + for n in range(len(sequence)): + setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] diff --git a/modelscope/models/cv/image_inpainting/refinement.py b/modelscope/models/cv/image_inpainting/refinement.py new file mode 100644 index 00000000..662d8a05 --- /dev/null +++ b/modelscope/models/cv/image_inpainting/refinement.py @@ -0,0 +1,393 @@ +''' +Part of the implementation is borrowed and modified from LaMa, publicly available at +https://github.com/saic-mdal/lama +''' +import cv2 +import numpy as np +import torch +import torch.nn as nn +from kornia.filters import gaussian_blur2d +from kornia.geometry.transform import resize +from kornia.morphology import erosion +from torch.nn import functional as F +from torch.optim import SGD, Adam +from tqdm import tqdm + +from .modules.ffc import FFCResnetBlock + + +def move_to_device(obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return {name: move_to_device(val, device) for name, val in obj.items()} + raise ValueError(f'Unexpected type {type(obj)}') + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def pad_tensor_to_modulo(img, mod): + batch_size, channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return F.pad( + img, + pad=(0, out_width - width, 0, out_height - height), + mode='reflect') + + +def _pyrdown(im: torch.Tensor, downsize: tuple = None): + """downscale the image""" + if downsize is None: + downsize = (im.shape[2] // 2, im.shape[3] // 2) + assert im.shape[ + 1] == 3, 'Expected shape for the input to be (n,3,height,width)' + im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0)) + im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) + return im + + +def _pyrdown_mask(mask: torch.Tensor, + downsize: tuple = None, + eps: float = 1e-8, + blur_mask: bool = True, + round_up: bool = True): + """downscale the mask tensor + + Parameters + ---------- + mask : torch.Tensor + mask of size (B, 1, H, W) + downsize : tuple, optional + size to downscale to. If None, image is downscaled to half, by default None + eps : float, optional + threshold value for binarizing the mask, by default 1e-8 + blur_mask : bool, optional + if True, apply gaussian filter before downscaling, by default True + round_up : bool, optional + if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True + + Returns + ------- + torch.Tensor + downscaled mask + """ + + if downsize is None: + downsize = (mask.shape[2] // 2, mask.shape[3] // 2) + assert mask.shape[ + 1] == 1, 'Expected shape for the input to be (n,1,height,width)' + if blur_mask is True: + mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0)) + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + else: + mask = F.interpolate( + mask, size=downsize, mode='bilinear', align_corners=False) + if round_up: + mask[mask >= eps] = 1 + mask[mask < eps] = 0 + else: + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _erode_mask(mask: torch.Tensor, + ekernel: torch.Tensor = None, + eps: float = 1e-8): + """erode the mask, and set gray pixels to 0""" + if ekernel is not None: + mask = erosion(mask, ekernel) + mask[mask >= 1.0 - eps] = 1 + mask[mask < 1.0 - eps] = 0 + return mask + + +def _l1_loss(pred: torch.Tensor, + pred_downscaled: torch.Tensor, + ref: torch.Tensor, + mask: torch.Tensor, + mask_downscaled: torch.Tensor, + image: torch.Tensor, + on_pred: bool = True): + """l1 loss on src pixels, and downscaled predictions if on_pred=True""" + loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8])) + if on_pred: + loss += torch.mean( + torch.abs(pred_downscaled[mask_downscaled >= 1e-8] + - ref[mask_downscaled >= 1e-8])) + return loss + + +def _infer(image: torch.Tensor, + mask: torch.Tensor, + forward_front: nn.Module, + forward_rears: nn.Module, + ref_lower_res: torch.Tensor, + orig_shape: tuple, + devices: list, + scale_ind: int, + n_iters: int = 15, + lr: float = 0.002): + """Performs inference with refinement at a given scale. + + Parameters + ---------- + image : torch.Tensor + input image to be inpainted, of size (1,3,H,W) + mask : torch.Tensor + input inpainting mask, of size (1,1,H,W) + forward_front : nn.Module + the front part of the inpainting network + forward_rears : nn.Module + the rear part of the inpainting network + ref_lower_res : torch.Tensor + the inpainting at previous scale, used as reference image + orig_shape : tuple + shape of the original input image before padding + devices : list + list of available devices + scale_ind : int + the scale index + n_iters : int, optional + number of iterations of refinement, by default 15 + lr : float, optional + learning rate, by default 0.002 + + Returns + ------- + torch.Tensor + inpainted image + """ + masked_image = image * (1 - mask) + masked_image = torch.cat([masked_image, mask], dim=1) + + mask = mask.repeat(1, 3, 1, 1) + if ref_lower_res is not None: + ref_lower_res = ref_lower_res.detach() + with torch.no_grad(): + z1, z2 = forward_front(masked_image) + # Inference + mask = mask.to(devices[-1]) + ekernel = torch.from_numpy( + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (15, 15)).astype(bool)).float() + ekernel = ekernel.to(devices[-1]) + image = image.to(devices[-1]) + z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) + z1.requires_grad, z2.requires_grad = True, True + + optimizer = Adam([z1, z2], lr=lr) + + pbar = tqdm(range(n_iters), leave=False) + for idi in pbar: + optimizer.zero_grad() + input_feat = (z1, z2) + for idd, forward_rear in enumerate(forward_rears): + output_feat = forward_rear(input_feat) + if idd < len(devices) - 1: + midz1, midz2 = output_feat + midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to( + devices[idd + 1]) + input_feat = (midz1, midz2) + else: + pred = output_feat + + if ref_lower_res is None: + break + losses = {} + # scaled loss with downsampler + pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]]) + mask_downscaled = _pyrdown_mask( + mask[:, :1, :orig_shape[0], :orig_shape[1]], + blur_mask=False, + round_up=False) + mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) + mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1) + losses['ms_l1'] = _l1_loss( + pred, + pred_downscaled, + ref_lower_res, + mask, + mask_downscaled, + image, + on_pred=True) + + loss = sum(losses.values()) + pbar.set_description( + 'Refining scale {} using scale {} ...current loss: {:.4f}'.format( + scale_ind + 1, scale_ind, loss.item())) + if idi < n_iters - 1: + loss.backward() + optimizer.step() + del pred_downscaled + del loss + del pred + # "pred" is the prediction after Plug-n-Play module + inpainted = mask * pred + (1 - mask) * image + inpainted = inpainted.detach().cpu() + return inpainted + + +def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, + px_budget: int): + """Build the image mask pyramid + + Parameters + ---------- + batch : dict + batch containing image, mask, etc + min_side : int + minimum side length to limit the number of scales of the pyramid + max_scales : int + maximum number of scales allowed + px_budget : int + the product H*W cannot exceed this budget, because of resource constraints + + Returns + ------- + tuple + image-mask pyramid in the form of list of images and list of masks + """ + + assert batch['image'].shape[ + 0] == 1, 'refiner works on only batches of size 1!' + + h, w = batch['unpad_to_size'] + h, w = h[0].item(), w[0].item() + + image = batch['image'][..., :h, :w] + mask = batch['mask'][..., :h, :w] + if h * w > px_budget: + # resize + ratio = np.sqrt(px_budget / float(h * w)) + h_orig, w_orig = h, w + h, w = int(h * ratio), int(w * ratio) + print( + f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...' + ) + image = resize( + image, (h, w), interpolation='bilinear', align_corners=False) + mask = resize( + mask, (h, w), interpolation='bilinear', align_corners=False) + mask[mask > 1e-8] = 1 + breadth = min(h, w) + n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))), + max_scales) + ls_images = [] + ls_masks = [] + + ls_images.append(image) + ls_masks.append(mask) + + for _ in range(n_scales - 1): + image_p = _pyrdown(ls_images[-1]) + mask_p = _pyrdown_mask(ls_masks[-1]) + ls_images.append(image_p) + ls_masks.append(mask_p) + # reverse the lists because we want the lowest resolution image as index 0 + return ls_images[::-1], ls_masks[::-1] + + +def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str, + modulo: int, n_iters: int, lr: float, min_side: int, + max_scales: int, px_budget: int): + """Refines the inpainting of the network + + Parameters + ---------- + batch : dict + image-mask batch, currently we assume the batchsize to be 1 + inpainter : nn.Module + the inpainting neural network + gpu_ids : str + the GPU ids of the machine to use. If only single GPU, use: "0," + modulo : int + pad the image to ensure dimension % modulo == 0 + n_iters : int + number of iterations of refinement for each scale + lr : float + learning rate + min_side : int + all sides of image on all scales should be >= min_side / sqrt(2) + max_scales : int + max number of downscaling scales for the image-mask pyramid + px_budget : int + pixels budget. Any image will be resized to satisfy height*width <= px_budget + + Returns + ------- + torch.Tensor + inpainted image of size (1,3,H,W) + """ + inpainter = inpainter.model + assert not inpainter.training + assert not inpainter.add_noise_kwargs + assert inpainter.concat_mask + + gpu_ids = [ + f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',') + if gpuid.isdigit() + ] + n_resnet_blocks = 0 + first_resblock_ind = 0 + found_first_resblock = False + for idl in range(len(inpainter.generator.model)): + if isinstance(inpainter.generator.model[idl], FFCResnetBlock): + n_resnet_blocks += 1 + found_first_resblock = True + elif not found_first_resblock: + first_resblock_ind += 1 + resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) + + devices = [torch.device(gpu_id) for gpu_id in gpu_ids] + + # split the model into front, and rear parts + forward_front = inpainter.generator.model[0:first_resblock_ind] + forward_front.to(devices[0]) + forward_rears = [] + for idd in range(len(gpu_ids)): + if idd < len(gpu_ids) - 1: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu + * (idd):first_resblock_ind + + resblocks_per_gpu * (idd + 1)]) + else: + forward_rears.append( + inpainter.generator.model[first_resblock_ind + + resblocks_per_gpu * (idd):]) + forward_rears[idd].to(devices[idd]) + + ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, + px_budget) + image_inpainted = None + + for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): + orig_shape = image.shape[2:] + image = pad_tensor_to_modulo(image, modulo) + mask = pad_tensor_to_modulo(mask, modulo) + mask[mask >= 1e-8] = 1.0 + mask[mask < 1e-8] = 0.0 + image, mask = move_to_device(image, devices[0]), move_to_device( + mask, devices[0]) + if image_inpainted is not None: + image_inpainted = move_to_device(image_inpainted, devices[-1]) + image_inpainted = _infer(image, mask, forward_front, forward_rears, + image_inpainted, orig_shape, devices, ids, + n_iters, lr) + image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]] + # detach everything to save resources + image = image.detach().cpu() + mask = mask.detach().cpu() + + return image_inpainted diff --git a/modelscope/models/cv/image_instance_segmentation/__init__.py b/modelscope/models/cv/image_instance_segmentation/__init__.py new file mode 100644 index 00000000..8ccfef4b --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .cascade_mask_rcnn_swin import CascadeMaskRCNNSwin + from .model import CascadeMaskRCNNSwinModel + from .postprocess_utils import get_img_ins_seg_result +else: + _import_structure = { + 'cascade_mask_rcnn_swin': ['CascadeMaskRCNNSwin'], + 'model': ['CascadeMaskRCNNSwinModel'], + 'postprocess_utils': ['get_img_ins_seg_result'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py b/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py new file mode 100644 index 00000000..fec1b627 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/backbones/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .swin_transformer import SwinTransformer + +else: + _import_structure = { + 'swin_transformer': ['SwinTransformer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py b/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py new file mode 100644 index 00000000..2007688d --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/backbones/swin_transformer.py @@ -0,0 +1,694 @@ +# The implementation is adopted from Swin Transformer, made publicly available under the MIT License at +# https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, + Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Inspiration from + https://github.com/SwinTransformer/Swin-Transformer-Object-Detection + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone.""" + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, + 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() diff --git a/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py b/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py new file mode 100644 index 00000000..ff83271e --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/cascade_mask_rcnn_swin.py @@ -0,0 +1,268 @@ +# Part of the implementation is borrowed and modified from MMDetection, publicly available at +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py +import os +from collections import OrderedDict + +import torch +import torch.distributed as dist +import torch.nn as nn + +from modelscope.models.cv.image_instance_segmentation.backbones import \ + SwinTransformer +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def build_backbone(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'SwinTransformer': + return SwinTransformer(**cfg) + else: + raise ValueError(f'backbone \'{type}\' is not supported.') + + +def build_neck(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'FPN': + from mmdet.models import FPN + return FPN(**cfg) + else: + raise ValueError(f'neck \'{type}\' is not supported.') + + +def build_rpn_head(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'RPNHead': + from mmdet.models import RPNHead + return RPNHead(**cfg) + else: + raise ValueError(f'rpn head \'{type}\' is not supported.') + + +def build_roi_head(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'CascadeRoIHead': + from mmdet.models import CascadeRoIHead + return CascadeRoIHead(**cfg) + else: + raise ValueError(f'roi head \'{type}\' is not supported.') + + +class CascadeMaskRCNNSwin(nn.Module): + + def __init__(self, + backbone, + neck, + rpn_head, + roi_head, + pretrained=None, + **kwargs): + """ + Args: + backbone (dict): backbone config. + neck (dict): neck config. + rpn_head (dict): rpn_head config. + roi_head (dict): roi_head config. + pretrained (bool): whether to use pretrained model + """ + super(CascadeMaskRCNNSwin, self).__init__() + + self.backbone = build_backbone(backbone) + self.neck = build_neck(neck) + self.rpn_head = build_rpn_head(rpn_head) + self.roi_head = build_roi_head(roi_head) + + self.classes = kwargs.pop('classes', None) + + if pretrained: + assert 'model_dir' in kwargs, 'pretrained model dir is missing.' + model_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + weight = torch.load(model_path)['state_dict'] + tgt_weight = self.state_dict() + for name in list(weight.keys()): + if name in tgt_weight: + load_size = weight[name].size() + tgt_size = tgt_weight[name].size() + mis_match = False + if len(load_size) != len(tgt_size): + mis_match = True + else: + for n1, n2 in zip(load_size, tgt_size): + if n1 != n2: + mis_match = True + break + if mis_match: + logger.info(f'size mismatch for {name}, skip loading.') + del weight[name] + + self.load_state_dict(weight, strict=False) + logger.info('load model done') + + from mmcv.parallel import DataContainer, scatter + + self.data_container = DataContainer + self.scatter = scatter + + def extract_feat(self, img): + x = self.backbone(img) + x = self.neck(x) + return x + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None, + **kwargs): + """ + Args: + img (Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + + img_metas (list[dict]): list of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmdet/datasets/pipelines/formatting.py:Collect`. + + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + + gt_labels (list[Tensor]): class indices corresponding to each box + + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + gt_masks (None | Tensor) : true segmentation masks for each box + used if the architecture supports a segmentation task. + + proposals : override rpn proposals with custom proposals. Use when + `with_rpn` is False. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self.extract_feat(img) + + losses = dict() + + # RPN forward and loss + proposal_cfg = self.rpn_head.train_cfg.get('rpn_proposal', + self.rpn_head.test_cfg) + rpn_losses, proposal_list = self.rpn_head.forward_train( + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=gt_bboxes_ignore, + proposal_cfg=proposal_cfg, + **kwargs) + losses.update(rpn_losses) + + roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, + gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, + **kwargs) + losses.update(roi_losses) + + return losses + + def forward_test(self, img, img_metas, proposals=None, rescale=True): + + x = self.extract_feat(img) + if proposals is None: + proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) + else: + proposal_list = proposals + + result = self.roi_head.simple_test( + x, proposal_list, img_metas, rescale=rescale) + return dict(eval_result=result, img_metas=img_metas) + + def forward(self, img, img_metas, **kwargs): + + # currently only support cpu or single gpu + if isinstance(img, self.data_container): + img = img.data[0] + if isinstance(img_metas, self.data_container): + img_metas = img_metas.data[0] + for k, w in kwargs.items(): + if isinstance(w, self.data_container): + w = w.data[0] + kwargs[k] = w + + if next(self.parameters()).is_cuda: + device = next(self.parameters()).device + img = self.scatter(img, [device])[0] + img_metas = self.scatter(img_metas, [device])[0] + for k, w in kwargs.items(): + kwargs[k] = self.scatter(w, [device])[0] + + if self.training: + losses = self.forward_train(img, img_metas, **kwargs) + loss, log_vars = self._parse_losses(losses) + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(img_metas)) + return outputs + else: + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): + + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def val_step(self, data, optimizer=None): + + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs diff --git a/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py b/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py new file mode 100644 index 00000000..1b096fb3 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/datasets/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .transforms import build_preprocess_transform diff --git a/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py b/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py new file mode 100644 index 00000000..f0dde759 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/datasets/transforms.py @@ -0,0 +1,114 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import numpy as np + +from modelscope.fileio import File + + +def build_preprocess_transform(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + if type == 'LoadImageFromFile': + return LoadImageFromFile(**cfg) + elif type == 'LoadAnnotations': + from mmdet.datasets.pipelines import LoadAnnotations + return LoadAnnotations(**cfg) + elif type == 'Resize': + if 'img_scale' in cfg: + if isinstance(cfg.img_scale[0], list): + elems = [] + for elem in cfg.img_scale: + elems.append(tuple(elem)) + cfg.img_scale = elems + else: + cfg.img_scale = tuple(cfg.img_scale) + from mmdet.datasets.pipelines import Resize + return Resize(**cfg) + elif type == 'RandomFlip': + from mmdet.datasets.pipelines import RandomFlip + return RandomFlip(**cfg) + elif type == 'Normalize': + from mmdet.datasets.pipelines import Normalize + return Normalize(**cfg) + elif type == 'Pad': + from mmdet.datasets.pipelines import Pad + return Pad(**cfg) + elif type == 'DefaultFormatBundle': + from mmdet.datasets.pipelines import DefaultFormatBundle + return DefaultFormatBundle(**cfg) + elif type == 'ImageToTensor': + from mmdet.datasets.pipelines import ImageToTensor + return ImageToTensor(**cfg) + elif type == 'Collect': + from mmdet.datasets.pipelines import Collect + return Collect(**cfg) + else: + raise ValueError(f'preprocess transform \'{type}\' is not supported.') + + +class LoadImageFromFile: + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename", "ann_file", and "classes"). Added or updated keys are + "filename", "ori_filename", "img", "img_shape", "ori_shape" (same as `img_shape`), + "img_fields", "ann_file" (path to annotation file) and "classes". + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, to_float32=False, mode='rgb'): + self.to_float32 = to_float32 + self.mode = mode + + from mmcv import imfrombytes + + self.imfrombytes = imfrombytes + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`ImageInstanceSegmentationCocoDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + if 'img' in results and isinstance(results['img'], np.ndarray): + img = results['img'] + filename = results['img_info']['filename'] + else: + if results['img_prefix'] is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + + img_bytes = File.read(filename) + + img = self.imfrombytes(img_bytes, 'color', 'bgr', backend='pillow') + + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img'] + results['ann_file'] = results['img_info']['ann_file'] + results['classes'] = results['img_info']['classes'] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32}, ' + f"mode='{self.mode}'") + return repr_str diff --git a/modelscope/models/cv/image_instance_segmentation/model.py b/modelscope/models/cv/image_instance_segmentation/model.py new file mode 100644 index 00000000..a56a1608 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/model.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.image_instance_segmentation import \ + CascadeMaskRCNNSwin +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + + +@MODELS.register_module( + Tasks.image_segmentation, module_name=Models.cascade_mask_rcnn_swin) +class CascadeMaskRCNNSwinModel(TorchModel): + + def __init__(self, model_dir=None, *args, **kwargs): + """ + Args: + model_dir (str): model directory. + + """ + super(CascadeMaskRCNNSwinModel, self).__init__( + model_dir=model_dir, *args, **kwargs) + + if 'backbone' not in kwargs: + config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + model_cfg = cfg.model + kwargs.update(model_cfg) + + self.model = CascadeMaskRCNNSwin(model_dir=model_dir, **kwargs) + + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.model.to(self.device) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = self.model(**input) + return output + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + + return input + + def compute_loss(self, outputs: Dict[str, Any], labels): + pass diff --git a/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py b/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py new file mode 100644 index 00000000..6058cd73 --- /dev/null +++ b/modelscope/models/cv/image_instance_segmentation/postprocess_utils.py @@ -0,0 +1,203 @@ +# Part of the implementation is borrowed and modified from MMDetection, publicly available at +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/visualization/image.py +import itertools + +import cv2 +import numpy as np +import pycocotools.mask as maskUtils +import torch + +from modelscope.outputs import OutputKeys + + +def get_seg_bboxes(bboxes, labels, segms=None, class_names=None, score_thr=0.): + assert bboxes.ndim == 2, \ + f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.' + assert labels.ndim == 1, \ + f' labels ndim should be 1, but its ndim is {labels.ndim}.' + assert bboxes.shape[0] == labels.shape[0], \ + 'bboxes.shape[0] and labels.shape[0] should have the same length.' + assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \ + f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.' + + if score_thr > 0: + assert bboxes.shape[1] == 5 + scores = bboxes[:, -1] + inds = scores > score_thr + bboxes = bboxes[inds, :] + labels = labels[inds] + if segms is not None: + segms = segms[inds, ...] + + bboxes_names = [] + for i, (bbox, label) in enumerate(zip(bboxes, labels)): + label_name = class_names[ + label] if class_names is not None else f'class {label}' + bbox = [0 if b < 0 else b for b in list(bbox)] + bbox.append(label_name) + bbox.append(segms[i].astype(bool)) + bboxes_names.append(bbox) + + return bboxes_names + + +def get_img_seg_results(det_rawdata=None, + class_names=None, + score_thr=0.3, + is_decode=True): + ''' + Get all boxes of one image. + score_thr: Classification probability threshold。 + output format: [ [x1,y1,x2,y2, prob, cls_name, mask], [x1,y1,x2,y2, prob, cls_name, mask], ... ] + ''' + assert det_rawdata is not None, 'det_rawdata should be not None.' + assert class_names is not None, 'class_names should be not None.' + + if isinstance(det_rawdata, tuple): + bbox_result, segm_result = det_rawdata + if isinstance(segm_result, tuple): + segm_result = segm_result[0] # ms rcnn + else: + bbox_result, segm_result = det_rawdata, None + bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + + segms = None + if segm_result is not None and len(labels) > 0: # non empty + segms = list(itertools.chain(*segm_result)) + if is_decode: + segms = maskUtils.decode(segms) + segms = segms.transpose(2, 0, 1) + if isinstance(segms[0], torch.Tensor): + segms = torch.stack(segms, dim=0).detach().cpu().numpy() + else: + segms = np.stack(segms, axis=0) + + bboxes_names = get_seg_bboxes( + bboxes, + labels, + segms=segms, + class_names=class_names, + score_thr=score_thr) + + return bboxes_names + + +def get_img_ins_seg_result(img_seg_result=None, + class_names=None, + score_thr=0.3): + assert img_seg_result is not None, 'img_seg_result should be not None.' + assert class_names is not None, 'class_names should be not None.' + + img_seg_result = get_img_seg_results( + det_rawdata=(img_seg_result[0], img_seg_result[1]), + class_names=class_names, + score_thr=score_thr, + is_decode=False) + + results_dict = { + OutputKeys.BOXES: [], + OutputKeys.MASKS: [], + OutputKeys.LABELS: [], + OutputKeys.SCORES: [] + } + for seg_result in img_seg_result: + + box = [ + np.int(seg_result[0]), + np.int(seg_result[1]), + np.int(seg_result[2]), + np.int(seg_result[3]) + ] + score = np.float(seg_result[4]) + category = seg_result[5] + + mask = np.array(seg_result[6], order='F', dtype='uint8') + mask = mask.astype(np.float) + + results_dict[OutputKeys.BOXES].append(box) + results_dict[OutputKeys.MASKS].append(mask) + results_dict[OutputKeys.SCORES].append(score) + results_dict[OutputKeys.LABELS].append(category) + + return results_dict + + +def show_result( + img, + result, + out_file='result.jpg', + show_box=True, + show_label=True, + show_score=True, + alpha=0.5, + fontScale=0.5, + fontFace=cv2.FONT_HERSHEY_COMPLEX_SMALL, + thickness=1, +): + + assert isinstance(img, (str, np.ndarray)), \ + f'img must be str or np.ndarray, but got {type(img)}.' + + if isinstance(img, str): + img = cv2.imread(img) + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = img.astype(np.float32) + + labels = result[OutputKeys.LABELS] + scores = result[OutputKeys.SCORES] + boxes = result[OutputKeys.BOXES] + masks = result[OutputKeys.MASKS] + + for label, score, box, mask in zip(labels, scores, boxes, masks): + + random_color = np.array([ + np.random.random() * 255.0, + np.random.random() * 255.0, + np.random.random() * 255.0 + ]) + + x1 = int(box[0]) + y1 = int(box[1]) + x2 = int(box[2]) + y2 = int(box[3]) + + if show_box: + cv2.rectangle( + img, (x1, y1), (x2, y2), random_color, thickness=thickness) + if show_label or show_score: + if show_label and show_score: + text = '{}|{}'.format(label, round(float(score), 2)) + elif show_label: + text = '{}'.format(label) + else: + text = '{}'.format(round(float(score), 2)) + + retval, baseLine = cv2.getTextSize( + text, + fontFace=fontFace, + fontScale=fontScale, + thickness=thickness) + cv2.rectangle( + img, (x1, y1 - retval[1] - baseLine), (x1 + retval[0], y1), + thickness=-1, + color=(0, 0, 0)) + cv2.putText( + img, + text, (x1, y1 - baseLine), + fontScale=fontScale, + fontFace=fontFace, + thickness=thickness, + color=random_color) + + idx = np.nonzero(mask) + img[idx[0], idx[1], :] *= 1.0 - alpha + img[idx[0], idx[1], :] += alpha * random_color + + cv2.imwrite(out_file, img) diff --git a/modelscope/models/cv/image_panoptic_segmentation/__init__.py b/modelscope/models/cv/image_panoptic_segmentation/__init__.py new file mode 100644 index 00000000..2b2be4b7 --- /dev/null +++ b/modelscope/models/cv/image_panoptic_segmentation/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .panseg_model import SwinLPanopticSegmentation + +else: + _import_structure = { + 'panseg_model': ['SwinLPanopticSegmentation'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_panoptic_segmentation/panseg_model.py b/modelscope/models/cv/image_panoptic_segmentation/panseg_model.py new file mode 100644 index 00000000..f44c01e8 --- /dev/null +++ b/modelscope/models/cv/image_panoptic_segmentation/panseg_model.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks + + +@MODELS.register_module( + Tasks.image_segmentation, module_name=Models.panoptic_segmentation) +class SwinLPanopticSegmentation(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, **kwargs) + + from mmcv.runner import load_checkpoint + import mmcv + from mmdet.models import build_detector + + config = osp.join(model_dir, 'config.py') + + cfg = mmcv.Config.fromfile(config) + if 'pretrained' in cfg.model: + cfg.model.pretrained = None + elif 'init_cfg' in cfg.model.backbone: + cfg.model.backbone.init_cfg = None + + # build model + cfg.model.train_cfg = None + self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + + # load model + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoint = load_checkpoint( + self.model, model_path, map_location='cpu') + + self.CLASSES = checkpoint['meta']['CLASSES'] + self.num_classes = len(self.CLASSES) + self.cfg = cfg + + def inference(self, data): + """data is dict,contain img and img_metas,follow with mmdet.""" + + with torch.no_grad(): + results = self.model(return_loss=False, rescale=True, **data) + return results + + def forward(self, Inputs): + return self.model(**Inputs) diff --git a/modelscope/models/cv/image_portrait_enhancement/__init__.py b/modelscope/models/cv/image_portrait_enhancement/__init__.py new file mode 100644 index 00000000..4014bb15 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_portrait_enhancement import ImagePortraitEnhancement + +else: + _import_structure = { + 'image_portrait_enhancement': ['ImagePortraitEnhancement'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_portrait_enhancement/align_faces.py b/modelscope/models/cv/image_portrait_enhancement/align_faces.py new file mode 100755 index 00000000..e6852f8c --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/align_faces.py @@ -0,0 +1,254 @@ +# Part of the implementation is borrowed and modified from Face-Alignment, +# publicly available at https://github.com/foamliu/Face-Alignment/blob/master/align_faces.py +import cv2 +import numpy as np +from skimage import transform as trans + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, + 71.73660278], [33.54930115, 92.3655014], + [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim, ), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + ref_5pts = np.array(REFERENCE_FACIAL_POINTS) + ref_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(ref_crop_size) - ref_crop_size + ref_5pts += size_diff / 2 + ref_crop_size += size_diff + + if (output_size and output_size[0] == ref_crop_size[0] + and output_size[1] == ref_crop_size[1]): + return ref_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + logger.info('No paddings to do: return default reference points') + return ref_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format( + ref_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 + or outer_padding[1] > 0) and output_size is None): + output_size = ref_crop_size * (1 + inner_padding_factor * 2).astype( + np.int32) + output_size += np.array(outer_padding) + logger.info('deduced from paddings, output_size = ', output_size) + + if not (outer_padding[0] < output_size[0] + and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = ref_crop_size * inner_padding_factor * 2 + ref_5pts += size_diff / 2 + ref_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * ref_crop_size[1] != size_bf_outer_pad[ + 1] * ref_crop_size[0]: + raise FaceWarpException( + 'Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / ref_crop_size[0] + ref_5pts = ref_5pts * scale_factor + ref_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = ref_5pts + np.array(outer_padding) + ref_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def get_params(reference_pts, facial_pts, align_type): + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + params, scale = _umeyama(src_pts, ref_pts) + tfm = params[:2, :] + + params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) + tfm_inv = params[:2, :] + + return tfm, tfm_inv + + +def warp_and_crop_face(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): # smilarity cv2_affine affine + + reference_pts_112 = get_reference_facial_points((112, 112), 0.25, (0, 0), + True) + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = True # False + inner_padding_factor = 0.25 # 0 + outer_padding = (0, 0) + output_size = crop_size + reference_pts = get_reference_facial_points( + output_size, inner_padding_factor, outer_padding, + default_square) + + tfm, tfm_inv = get_params(reference_pts, facial_pts, align_type) + tfm_112, tfm_inv_112 = get_params(reference_pts_112, facial_pts, + align_type) + + if src_img is not None: + face_img = cv2.warpAffine( + src_img, tfm, (crop_size[0], crop_size[1]), flags=3) + face_img_112 = cv2.warpAffine(src_img, tfm_112, (112, 112), flags=3) + + return face_img, face_img_112, tfm_inv + else: + return tfm, tfm_inv diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/__init__.py b/modelscope/models/cv/image_portrait_enhancement/eqface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py new file mode 100755 index 00000000..51f2206e --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +import cv2 +import numpy as np +import torch + +from .model_resnet import FaceQuality, ResNet + + +class FQA(object): + + def __init__(self, backbone_path, quality_path, device='cuda', size=112): + self.BACKBONE = ResNet(num_layers=100, feature_dim=512) + self.QUALITY = FaceQuality(512 * 7 * 7) + self.size = size + self.device = device + + self.load_model(backbone_path, quality_path) + + def load_model(self, backbone_path, quality_path): + checkpoint = torch.load(backbone_path, map_location='cpu') + self.load_state_dict(self.BACKBONE, checkpoint) + + checkpoint = torch.load(quality_path, map_location='cpu') + self.load_state_dict(self.QUALITY, checkpoint) + + self.BACKBONE.to(self.device) + self.QUALITY.to(self.device) + self.BACKBONE.eval() + self.QUALITY.eval() + + def load_state_dict(self, model, state_dict): + all_keys = {k for k in state_dict.keys()} + for k in all_keys: + if k.startswith('module.'): + state_dict[k[7:]] = state_dict.pop(k) + model_dict = model.state_dict() + pretrained_dict = { + k: v + for k, v in state_dict.items() + if k in model_dict and v.size() == model_dict[k].size() + } + + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + + def get_face_quality(self, img): + img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to( + self.device) + img = (img - 127.5) / 128.0 + + # extract features & predict quality + with torch.no_grad(): + feature, fc = self.BACKBONE(img.to(self.device), True) + s = self.QUALITY(fc)[0] + + return s.cpu().numpy()[0], feature.cpu().numpy()[0] diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py new file mode 100644 index 00000000..e0e8e9d5 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py @@ -0,0 +1,132 @@ +# The implementation is adopted from FaceQuality, made publicly available under the MIT License +# at https://github.com/deepcam-cn/FaceQuality/blob/master/models/model_resnet.py +import torch +from torch import nn + + +class BottleNeck_IR(nn.Module): + + def __init__(self, in_channel, out_channel, stride, dim_match): + super(BottleNeck_IR, self).__init__() + self.res_layer = nn.Sequential( + nn.BatchNorm2d(in_channel), + nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), + nn.BatchNorm2d(out_channel), nn.PReLU(out_channel), + nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), + nn.BatchNorm2d(out_channel)) + if dim_match: + self.shortcut_layer = None + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2d( + in_channel, + out_channel, + kernel_size=(1, 1), + stride=stride, + bias=False), nn.BatchNorm2d(out_channel)) + + def forward(self, x): + shortcut = x + res = self.res_layer(x) + + if self.shortcut_layer is not None: + shortcut = self.shortcut_layer(x) + + return shortcut + res + + +channel_list = [64, 64, 128, 256, 512] + + +def get_layers(num_layers): + if num_layers == 34: + return [3, 4, 6, 3] + if num_layers == 50: + return [3, 4, 14, 3] + elif num_layers == 100: + return [3, 13, 30, 3] + elif num_layers == 152: + return [3, 8, 36, 3] + + +class ResNet(nn.Module): + + def __init__(self, + num_layers=100, + feature_dim=512, + drop_ratio=0.4, + channel_list=channel_list): + super(ResNet, self).__init__() + assert num_layers in [34, 50, 100, 152] + layers = get_layers(num_layers) + block = BottleNeck_IR + + self.input_layer = nn.Sequential( + nn.Conv2d( + 3, channel_list[0], (3, 3), stride=1, padding=1, bias=False), + nn.BatchNorm2d(channel_list[0]), nn.PReLU(channel_list[0])) + self.layer1 = self._make_layer( + block, channel_list[0], channel_list[1], layers[0], stride=2) + self.layer2 = self._make_layer( + block, channel_list[1], channel_list[2], layers[1], stride=2) + self.layer3 = self._make_layer( + block, channel_list[2], channel_list[3], layers[2], stride=2) + self.layer4 = self._make_layer( + block, channel_list[3], channel_list[4], layers[3], stride=2) + + self.output_layer = nn.Sequential( + nn.BatchNorm2d(512), nn.Dropout(drop_ratio), nn.Flatten()) + self.feature_layer = nn.Sequential( + nn.Linear(512 * 7 * 7, feature_dim), nn.BatchNorm1d(feature_dim)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, in_channel, out_channel, blocks, stride): + layers = [] + layers.append(block(in_channel, out_channel, stride, False)) + for i in range(1, blocks): + layers.append(block(out_channel, out_channel, 1, True)) + return nn.Sequential(*layers) + + def forward(self, x, fc=False): + x = self.input_layer(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.output_layer(x) + feature = self.feature_layer(x) + if fc: + return feature, x + return feature + + +class FaceQuality(nn.Module): + + def __init__(self, feature_dim): + super(FaceQuality, self).__init__() + self.qualtiy = nn.Sequential( + nn.Linear(feature_dim, 512, bias=False), nn.BatchNorm1d(512), + nn.ReLU(inplace=True), nn.Linear(512, 2, bias=False), + nn.Softmax(dim=1)) + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.qualtiy(x) + return x[:, 0:1] diff --git a/modelscope/models/cv/image_portrait_enhancement/gpen.py b/modelscope/models/cv/image_portrait_enhancement/gpen.py new file mode 100755 index 00000000..86009a41 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/gpen.py @@ -0,0 +1,815 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_model/gpen_model.py +import functools +import itertools +import math +import operator +import random + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F + +from modelscope.models.cv.face_generation.op import (FusedLeakyReLU, + fused_leaky_relu, + upfirdn2d) + + +class PixelNorm(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt( + torch.mean(input**2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor**2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + bias=True): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + + def __init__(self, + in_dim, + out_dim, + bias=True, + bias_init=0, + lr_mul=1, + activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur( + blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size**2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})') + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view(batch * self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view(batch, self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(batch * in_channel, + self.out_channel, + self.kernel_size, + self.kernel_size) + out = F.conv_transpose2d( + input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + + def __init__(self, isconcat=True): + super().__init__() + + self.isconcat = isconcat + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, channel, height, width = image.shape + noise = image.new_empty(batch, channel, height, width).normal_() + + if self.isconcat: + return torch.cat((image, self.weight * noise), dim=1) + else: + return image + self.weight * noise + + +class ConstantInput(nn.Module): + + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + isconcat=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection(isconcat) + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + feat_multiplier = 2 if isconcat else 1 + self.activate = FusedLeakyReLU(out_channel * feat_multiplier) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + + def __init__(self, + in_channel, + style_dim, + upsample=True, + blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d( + in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + isconcat=True, + narrow=1, + ): + super().__init__() + + self.size = size + self.n_mlp = n_mlp + self.style_dim = style_dim + self.feat_multiplier = 2 if isconcat else 1 + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, + style_dim, + lr_mul=lr_mlp, + activation='fused_lrelu')) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], + self.channels[4], + 3, + style_dim, + blur_kernel=blur_kernel, + isconcat=isconcat) + self.to_rgb1 = ToRGB( + self.channels[4] * self.feat_multiplier, style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + + in_channel = self.channels[4] + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel * self.feat_multiplier, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + isconcat=isconcat, + )) + + self.convs.append( + StyledConv( + out_channel * self.feat_multiplier, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel, + isconcat=isconcat)) + + self.to_rgbs.append( + ToRGB(out_channel * self.feat_multiplier, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + ''' + noise = [None] * (2 * (self.log_size - 2) + 1) + ''' + noise = [] + batch = styles[0].shape[0] + for i in range(self.n_mlp + 1): + size = 2**(i + 2) + noise.append( + torch.randn( + batch, + self.channels[size], + size, + size, + device=styles[0].device)) + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + + truncation * (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat( + 1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + )) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, + out_channel, + 1, + downsample=True, + activate=False, + bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class FullGenerator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + isconcat=True, + narrow=1, + ): + super().__init__() + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.log_size = int(math.log(size, 2)) + self.generator = Generator( + size, + style_dim, + n_mlp, + channel_multiplier=channel_multiplier, + blur_kernel=blur_kernel, + lr_mlp=lr_mlp, + isconcat=isconcat, + narrow=narrow) + + conv = [ConvLayer(3, channels[size], 1)] + self.ecd0 = nn.Sequential(*conv) + in_channel = channels[size] + + self.names = ['ecd%d' % i for i in range(self.log_size - 1)] + for i in range(self.log_size, 2, -1): + out_channel = channels[2**(i - 1)] + # conv = [ResBlock(in_channel, out_channel, blur_kernel)] + conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] + setattr(self, self.names[self.log_size - i + 1], + nn.Sequential(*conv)) + in_channel = out_channel + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) + + def forward( + self, + inputs, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + noise = [] + for i in range(self.log_size - 1): + ecd = getattr(self, self.names[i]) + inputs = ecd(inputs) + noise.append(inputs) + inputs = inputs.view(inputs.shape[0], -1) + outs = self.final_linear(inputs) + noise = list( + itertools.chain.from_iterable( + itertools.repeat(x, 2) for x in noise))[::-1] + outs = self.generator([outs], + return_latents, + inject_index, + truncation, + truncation_latent, + input_is_latent, + noise=noise[1:]) + return outs + + +class Discriminator(nn.Module): + + def __init__(self, + size, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + narrow=1): + super().__init__() + + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2**(i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, + channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + return out diff --git a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py new file mode 100644 index 00000000..26e9e532 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py @@ -0,0 +1,206 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os.path as osp +from copy import deepcopy +from typing import Any, Dict, List, Union + +import torch +import torch.nn.functional as F +from torch import autograd, nn +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .gpen import Discriminator, FullGenerator +from .losses.losses import IDLoss, L1Loss + +logger = get_logger() + +__all__ = ['ImagePortraitEnhancement'] + + +@MODELS.register_module( + Tasks.image_portrait_enhancement, module_name=Models.gpen) +class ImagePortraitEnhancement(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the face enhancement model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + self.size = 256 + self.style_dim = 512 + self.n_mlp = 8 + self.mean_path_length = 0 + self.accum = 0.5**(32 / (10 * 1000)) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + self.l1_loss = L1Loss() + self.id_loss = IDLoss(f'{model_dir}/arcface/model_ir_se50.pth', + self._device) + self.generator = FullGenerator( + self.size, self.style_dim, self.n_mlp, + isconcat=True).to(self._device) + self.g_ema = FullGenerator( + self.size, self.style_dim, self.n_mlp, + isconcat=True).to(self._device) + self.discriminator = Discriminator(self.size).to(self._device) + + if self.size == 512: + self.load_pretrained(model_dir) + + def load_pretrained(self, model_dir): + g_path = f'{model_dir}/{ModelFile.TORCH_MODEL_FILE}' + g_dict = torch.load(g_path, map_location=torch.device('cpu')) + self.generator.load_state_dict(g_dict) + self.g_ema.load_state_dict(g_dict) + + d_path = f'{model_dir}/net_d.pt' + d_dict = torch.load(d_path, map_location=torch.device('cpu')) + self.discriminator.load_state_dict(d_dict) + + logger.info('load model done.') + + def accumulate(self): + par1 = dict(self.g_ema.named_parameters()) + par2 = dict(self.generator.named_parameters()) + + for k in par1.keys(): + par1[k].data.mul_(self.accum).add_(1 - self.accum, par2[k].data) + + def requires_grad(self, model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + def d_logistic_loss(self, real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def d_r1_loss(self, real_pred, real_img): + grad_real, = autograd.grad( + outputs=real_pred.sum(), inputs=real_img, create_graph=True) + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], + -1).sum(1).mean() + + return grad_penalty + + def g_nonsaturating_loss(self, + fake_pred, + fake_img=None, + real_img=None, + input_img=None): + loss = F.softplus(-fake_pred).mean() + loss_l1 = self.l1_loss(fake_img, real_img) + loss_id, __, __ = self.id_loss(fake_img, real_img, input_img) + loss_id = 0 + loss += 1.0 * loss_l1 + 1.0 * loss_id + + return loss + + def g_path_regularize(self, + fake_img, + latents, + mean_path_length, + decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt( + fake_img.shape[2] * fake_img.shape[3]) + grad, = autograd.grad( + outputs=(fake_img * noise).sum(), + inputs=latents, + create_graph=True) + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * ( + path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_mean.detach(), path_lengths + + @torch.no_grad() + def _evaluate_postprocess(self, input: Tensor, + target: Tensor) -> Dict[str, list]: + preds, _ = self.generator(input) + preds = list(torch.split(preds, 1, 0)) + targets = list(torch.split(target, 1, 0)) + + preds = [((pred.data * 0.5 + 0.5) * 255.).squeeze(0).type( + torch.uint8).permute(1, 2, 0).cpu().numpy() for pred in preds] + targets = [((target.data * 0.5 + 0.5) * 255.).squeeze(0).type( + torch.uint8).permute(1, 2, 0).cpu().numpy() for target in targets] + + return {'pred': preds, 'target': targets} + + def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor: + self.requires_grad(self.generator, False) + self.requires_grad(self.discriminator, True) + + preds, _ = self.generator(input) + fake_pred = self.discriminator(preds) + real_pred = self.discriminator(target) + + d_loss = self.d_logistic_loss(real_pred, fake_pred) + + return d_loss + + def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor: + input.requires_grad = True + target.requires_grad = True + real_pred = self.discriminator(target) + r1_loss = self.d_r1_loss(real_pred, target) + + return r1_loss + + def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor: + self.requires_grad(self.generator, True) + self.requires_grad(self.discriminator, False) + + preds, _ = self.generator(input) + fake_pred = self.discriminator(preds) + + g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, input) + + return g_loss + + def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor: + fake_img, latents = self.generator(input, return_latents=True) + + path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( + fake_img, latents, self.mean_path_length) + + return path_loss + + @torch.no_grad() + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + + if 'target' in input: + return self._evaluate_postprocess(**input) + else: + return self._inference_forward(**input) diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/__init__.py b/modelscope/models/cv/image_portrait_enhancement/losses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py new file mode 100644 index 00000000..86f6f227 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py @@ -0,0 +1,131 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py +from collections import namedtuple + +import torch +from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, + Module, PReLU, ReLU, Sequential, Sigmoid) + + +class Flatten(Module): + + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError( + 'Invalid number of layers: {}. Must be one of [50, 100, 152]'. + format(num_layers)) + return blocks + + +class SEModule(Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py new file mode 100644 index 00000000..0f5198c3 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py @@ -0,0 +1,92 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/tree/main/training/loss/id_loss.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model_irse import Backbone + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' + ) + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * F.l1_loss( + pred, target, reduction=self.reduction) + + +class IDLoss(nn.Module): + + def __init__(self, model_path, device='cuda', ckpt_dict=None): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone( + input_size=112, num_layers=50, drop_ratio=0.6, + mode='ir_se').to(device) + if ckpt_dict is None: + self.facenet.load_state_dict( + torch.load(model_path, map_location=torch.device('cpu'))) + else: + self.facenet.load_state_dict(ckpt_dict) + self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def extract_feats(self, x): + _, _, h, w = x.shape + assert h == w + if h != 256: + x = self.pool(x) + x = x[:, :, 35:-33, 32:-36] # crop roi + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + @torch.no_grad() + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + id_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + id_logs.append({ + 'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views) + }) + loss += 1 - diff_target + id_diff = float(diff_target) - float(diff_views) + sim_improvement += id_diff + count += 1 + + return loss / count, sim_improvement / count, id_logs diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py new file mode 100644 index 00000000..00dc7c52 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py @@ -0,0 +1,94 @@ +# The implementation is adopted from InsightFace_Pytorch, +# made publicly available under the MIT License at https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + Module, PReLU, Sequential) + +from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, + l2_norm) + + +class Backbone(Module): + + def __init__(self, + input_size, + num_layers, + mode='ir', + drop_ratio=0.4, + affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], 'input_size should be 112 or 224' + assert num_layers in [50, 100, + 152], 'num_layers should be 50, 100 or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/__init__.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py new file mode 100755 index 00000000..7ad780a8 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py @@ -0,0 +1,219 @@ +# The GPEN implementation is also open-sourced by the authors, +# and available at https://github.com/yangxy/GPEN/blob/main/face_detect/retinaface_detection.py +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn.functional as F + +from .models.retinaface import RetinaFace +from .utils import PriorBox, decode, decode_landm, py_cpu_nms + +cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'pretrain': False, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 +} + + +class RetinaFaceDetection(object): + + def __init__(self, model_path, device='cuda'): + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + self.cfg = cfg_re50 + self.net = RetinaFace(cfg=self.cfg) + self.load_model() + self.net = self.net.to(device) + + self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) + + def check_keys(self, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(self.net.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + assert len( + used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + def remove_prefix(self, state_dict, prefix): + new_state_dict = dict() + # remove unnecessary 'module.' + for k, v in state_dict.items(): + if k.startswith(prefix): + new_state_dict[k[len(prefix):]] = v + else: + new_state_dict[k] = v + return new_state_dict + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + if 'state_dict' in pretrained_dict.keys(): + pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], + 'module.') + else: + pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') + self.check_keys(pretrained_dict) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def detect(self, + img_raw, + resize=1, + confidence_threshold=0.9, + nms_threshold=0.4, + top_k=5000, + keep_top_k=750, + save_image=False): + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + ss = 1.0 + # tricky + if max(im_height, im_width) > 1500: + ss = 1000.0 / max(im_height, im_width) + img = cv2.resize(img, (0, 0), fx=ss, fy=ss) + im_height, im_width = img.shape[:2] + + scale = torch.Tensor( + [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + scale = scale.to(self.device) + + loc, conf, landms = self.net(img) # forward pass + del img + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + im_width, im_height, im_width, im_height, im_width, im_height, + im_width, im_height, im_width, im_height + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss + + def detect_tensor(self, + img, + resize=1, + confidence_threshold=0.9, + nms_threshold=0.4, + top_k=5000, + keep_top_k=750, + save_image=False): + im_height, im_width = img.shape[-2:] + ss = 1000 / max(im_height, im_width) + img = F.interpolate(img, scale_factor=ss) + im_height, im_width = img.shape[-2:] + scale = torch.Tensor([im_width, im_height, im_width, + im_height]).to(self.device) + img -= self.mean + + loc, conf, landms = self.net(img) # forward pass + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2] + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/__init__.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py new file mode 100755 index 00000000..24451e96 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py @@ -0,0 +1,150 @@ +# The implementation is adopted from Pytorch_Retinaface, made pubicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/net.py +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +from torch.autograd import Variable + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + x = x.view(-1, 256) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py new file mode 100755 index 00000000..64d95971 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py @@ -0,0 +1,146 @@ +# The implementation is adopted from Pytorch_Retinaface, made pubicly available under the MIT License +# at https://github.com/biubug6/Pytorch_Retinaface/tree/master/models/retinaface.py +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +import torchvision.models.detection.backbone_utils as backbone_utils + +from .net import FPN, SSH, MobileNetV1 + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d( + inchannels, + self.num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, cfg=None): + """ + :param cfg: Network related settings. + """ + super(RetinaFace, self).__init__() + backbone = None + if cfg['name'] == 'Resnet50': + backbone = models.resnet50(pretrained=cfg['pretrain']) + else: + raise Exception('Invalid name') + + self.body = _utils.IntermediateLayerGetter(backbone, + cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, inchannels=cfg['out_channel']) + + def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead + + def forward(self, inputs): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = torch.cat( + [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], + dim=1) + + output = (bbox_regressions, F.softmax(classifications, + dim=-1), ldm_regressions) + return output diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py new file mode 100755 index 00000000..60c9e2dd --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------- +# Modified from https://github.com/biubug6/Pytorch_Retinaface +# -------------------------------------------------------- + +from itertools import product as product +from math import ceil + +import numpy as np +import torch + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ + ceil(self.image_size[0] / step), + ceil(self.image_size[1] / step) + ] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [ + x * self.steps[k] / self.image_size[1] + for x in [j + 0.5] + ] + dense_cy = [ + y * self.steps[k] / self.image_size[0] + for y in [i + 0.5] + ] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] + b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] + c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] + d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] + e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] + landms = torch.cat((a, b, c, d, e), dim=1) + return landms diff --git a/modelscope/models/cv/image_reid_person/__init__.py b/modelscope/models/cv/image_reid_person/__init__.py new file mode 100644 index 00000000..0fe0bede --- /dev/null +++ b/modelscope/models/cv/image_reid_person/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .pass_model import PASS + +else: + _import_structure = { + 'pass_model': ['PASS'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_reid_person/pass_model.py b/modelscope/models/cv/image_reid_person/pass_model.py new file mode 100644 index 00000000..3b032949 --- /dev/null +++ b/modelscope/models/cv/image_reid_person/pass_model.py @@ -0,0 +1,136 @@ +# The implementation is adopted from PASS-reID, made pubicly available under the Apache-2.0 License at +# https://github.com/CASIA-IVA-Lab/PASS-reID + +import os +from enum import Enum + +import torch +import torch.nn as nn + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .transreid_model import vit_base_patch16_224_TransReID + + +class Fusions(Enum): + CAT = 'cat' + MEAN = 'mean' + + +@MODELS.register_module( + Tasks.image_reid_person, module_name=Models.image_reid_person) +class PASS(TorchModel): + + def __init__(self, cfg: Config, model_dir: str, **kwargs): + super(PASS, self).__init__(model_dir=model_dir) + size_train = cfg.INPUT.SIZE_TRAIN + sie_coe = cfg.MODEL.SIE_COE + stride_size = cfg.MODEL.STRIDE_SIZE + drop_path = cfg.MODEL.DROP_PATH + drop_out = cfg.MODEL.DROP_OUT + att_drop_rate = cfg.MODEL.ATT_DROP_RATE + gem_pooling = cfg.MODEL.GEM_POOLING + stem_conv = cfg.MODEL.STEM_CONV + weight = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + self.neck_feat = cfg.TEST.NECK_FEAT + self.dropout_rate = cfg.MODEL.DROPOUT_RATE + self.num_classes = cfg.DATASETS.NUM_CLASSES + self.multi_neck = cfg.MODEL.MULTI_NECK + self.feat_fusion = cfg.MODEL.FEAT_FUSION + + self.base = vit_base_patch16_224_TransReID( + img_size=size_train, + sie_xishu=sie_coe, + stride_size=stride_size, + drop_path_rate=drop_path, + drop_rate=drop_out, + attn_drop_rate=att_drop_rate, + gem_pool=gem_pooling, + stem_conv=stem_conv) + self.in_planes = self.base.in_planes + + if self.feat_fusion == Fusions.CAT.value: + self.classifier = nn.Linear( + self.in_planes * 2, self.num_classes, bias=False) + elif self.feat_fusion == Fusions.MEAN.value: + self.classifier = nn.Linear( + self.in_planes, self.num_classes, bias=False) + + if self.multi_neck: + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + self.bottleneck_1 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_1.bias.requires_grad_(False) + self.bottleneck_2 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_2.bias.requires_grad_(False) + self.bottleneck_3 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_3.bias.requires_grad_(False) + else: + if self.feat_fusion == Fusions.CAT.value: + self.bottleneck = nn.BatchNorm1d(self.in_planes * 2) + self.bottleneck.bias.requires_grad_(False) + elif self.feat_fusion == Fusions.MEAN.value: + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + + self.dropout = nn.Dropout(self.dropout_rate) + + self.load_param(weight) + + def forward(self, input): + + global_feat, local_feat_1, local_feat_2, local_feat_3 = self.base( + input) + + # single-neck, almost the same performance + if not self.multi_neck: + if self.feat_fusion == Fusions.MEAN.value: + local_feat = local_feat_1 / 3. + local_feat_2 / 3. + local_feat_3 / 3. + final_feat_before = (global_feat + local_feat) / 2 + elif self.feat_fusion == Fusions.CAT.value: + final_feat_before = torch.cat( + (global_feat, local_feat_1 / 3. + local_feat_2 / 3. + + local_feat_3 / 3.), + dim=1) + + final_feat_after = self.bottleneck(final_feat_before) + # multi-neck + else: + feat = self.bottleneck(global_feat) + local_feat_1_bn = self.bottleneck_1(local_feat_1) + local_feat_2_bn = self.bottleneck_2(local_feat_2) + local_feat_3_bn = self.bottleneck_3(local_feat_3) + + if self.feat_fusion == Fusions.MEAN.value: + final_feat_before = ((global_feat + local_feat_1 / 3 + + local_feat_2 / 3 + local_feat_3 / 3) + / 2.) + final_feat_after = (feat + local_feat_1_bn / 3 + + local_feat_2_bn / 3 + + local_feat_3_bn / 3) / 2. + elif self.feat_fusion == Fusions.CAT.value: + final_feat_before = torch.cat( + (global_feat, local_feat_1 / 3. + local_feat_2 / 3. + + local_feat_3 / 3.), + dim=1) + final_feat_after = torch.cat( + (feat, local_feat_1_bn / 3 + local_feat_2_bn / 3 + + local_feat_3_bn / 3), + dim=1) + + if self.neck_feat == 'after': + return final_feat_after + else: + return final_feat_before + + def load_param(self, trained_path): + param_dict = torch.load(trained_path, map_location='cpu') + for i in param_dict: + try: + self.state_dict()[i.replace('module.', + '')].copy_(param_dict[i]) + except Exception: + continue diff --git a/modelscope/models/cv/image_reid_person/transreid_model.py b/modelscope/models/cv/image_reid_person/transreid_model.py new file mode 100644 index 00000000..5bceb468 --- /dev/null +++ b/modelscope/models/cv/image_reid_person/transreid_model.py @@ -0,0 +1,418 @@ +# The implementation is adopted from PASS-reID, made pubicly available under the Apache-2.0 License at +# https://github.com/CASIA-IVA-Lab/PASS-reID + +import collections.abc as container_abcs +from functools import partial +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def vit_base_patch16_224_TransReID( + img_size=(256, 128), + stride_size=16, + drop_path_rate=0.1, + camera=0, + view=0, + local_feature=False, + sie_xishu=1.5, + **kwargs): + model = TransReID( + img_size=img_size, + patch_size=16, + stride_size=stride_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + camera=camera, + view=view, + drop_path_rate=drop_path_rate, + sie_xishu=sie_xishu, + local_feature=local_feature, + **kwargs) + return model + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class TransReID(nn.Module): + """Transformer-based Object Re-Identification + """ + + def __init__(self, + img_size=224, + patch_size=16, + stride_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + camera=0, + view=0, + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + local_feature=False, + sie_xishu=1.0, + hw_ratio=1, + gem_pool=False, + stem_conv=False): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.local_feature = local_feature + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + stride_size=stride_size, + in_chans=in_chans, + embed_dim=embed_dim, + stem_conv=stem_conv) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token1 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token2 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token3 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self.cls_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part1_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part2_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part3_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.cam_num = camera + self.view_num = view + self.sie_xishu = sie_xishu + self.in_planes = 768 + self.gem_pool = gem_pool + + # Initialize SIE Embedding + if camera > 1 and view > 1: + self.sie_embed = nn.Parameter( + torch.zeros(camera * view, 1, embed_dim)) + elif camera > 1: + self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim)) + elif view > 1: + self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim)) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) for i in range(depth) + ]) + + self.norm = norm_layer(embed_dim) + + # Classifier head + self.fc = nn.Linear(embed_dim, + num_classes) if num_classes > 0 else nn.Identity() + + self.gem = GeneralizedMeanPooling() + + def forward_features(self, x, camera_id, view_id): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + part_tokens1 = self.part_token1.expand(B, -1, -1) + part_tokens2 = self.part_token2.expand(B, -1, -1) + part_tokens3 = self.part_token3.expand(B, -1, -1) + x = torch.cat( + (cls_tokens, part_tokens1, part_tokens2, part_tokens3, x), dim=1) + + if self.cam_num > 0 and self.view_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[ + camera_id * self.view_num + view_id] + elif self.cam_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id] + elif self.view_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[view_id] + else: + x = x + torch.cat((self.cls_pos, self.part1_pos, self.part2_pos, + self.part3_pos, self.pos_embed), + dim=1) + + x = self.pos_drop(x) + + if self.local_feature: + for blk in self.blocks[:-1]: + x = blk(x) + return x + else: + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + if self.gem_pool: + gf = self.gem(x[:, 1:].permute(0, 2, 1)).squeeze() + return x[:, 0] + gf + return x[:, 0], x[:, 1], x[:, 2], x[:, 3] + + def forward(self, x, cam_label=None, view_label=None): + global_feat, local_feat_1, local_feat_2, local_feat_3 = self.forward_features( + x, cam_label, view_label) + return global_feat, local_feat_1, local_feat_2, local_feat_3 + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding with overlapping patches + """ + + def __init__(self, + img_size=224, + patch_size=16, + stride_size=16, + in_chans=3, + embed_dim=768, + stem_conv=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride_size_tuple = to_2tuple(stride_size) + self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 + self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 + self.num_patches = self.num_x * self.num_y + self.img_size = img_size + self.patch_size = patch_size + + self.stem_conv = stem_conv + if self.stem_conv: + hidden_dim = 64 + stem_stride = 2 + stride_size = patch_size = patch_size[0] // stem_stride + self.conv = nn.Sequential( + nn.Conv2d( + in_chans, + hidden_dim, + kernel_size=7, + stride=stem_stride, + padding=3, + bias=False), + IBN(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_dim, + hidden_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False), + IBN(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_dim, + hidden_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + in_chans = hidden_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) + + def forward(self, x): + if self.stem_conv: + x = self.conv(x) + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) # [64, 8, 768] + + return x + + +class GeneralizedMeanPooling(nn.Module): + """Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + """ + + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return F.adaptive_avg_pool1d(x, self.output_size).pow(1. / self.p) + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/modelscope/models/cv/image_semantic_segmentation/__init__.py b/modelscope/models/cv/image_semantic_segmentation/__init__.py new file mode 100644 index 00000000..df56c5b8 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .semantic_seg_model import SemanticSegmentation + from .segformer import Segformer + +else: + _import_structure = { + 'semantic_seg_model': ['SemanticSegmentation'], + 'segformer': ['Segformer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py new file mode 100644 index 00000000..6a31a308 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .maskformer_semantic_head import MaskFormerSemanticHead diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py new file mode 100644 index 00000000..05e68d89 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/base_panoptic_fusion_head.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from mmcv.runner import BaseModule +from mmdet.models.builder import build_loss + + +class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta): + """Base class for panoptic heads.""" + + def __init__(self, + num_things_classes=80, + num_stuff_classes=53, + test_cfg=None, + loss_panoptic=None, + init_cfg=None, + **kwargs): + super(BasePanopticFusionHead, self).__init__(init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = num_things_classes + num_stuff_classes + self.test_cfg = test_cfg + + if loss_panoptic: + self.loss_panoptic = build_loss(loss_panoptic) + else: + self.loss_panoptic = None + + @property + def with_loss(self): + """bool: whether the panoptic head contains loss function.""" + return self.loss_panoptic is not None + + @abstractmethod + def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs): + """Forward function during training.""" + + @abstractmethod + def simple_test(self, + img_metas, + det_labels, + mask_preds, + seg_preds, + det_bboxes, + cfg=None, + **kwargs): + """Test without augmentation.""" diff --git a/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py b/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py new file mode 100644 index 00000000..2f3364d0 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/pan_merge/maskformer_semantic_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn.functional as F +from mmdet.models.builder import HEADS + +from .base_panoptic_fusion_head import BasePanopticFusionHead + + +@HEADS.register_module() +class MaskFormerSemanticHead(BasePanopticFusionHead): + + def __init__(self, + num_things_classes=80, + num_stuff_classes=53, + test_cfg=None, + loss_panoptic=None, + init_cfg=None, + **kwargs): + super().__init__(num_things_classes, num_stuff_classes, test_cfg, + loss_panoptic, init_cfg, **kwargs) + + def forward_train(self, **kwargs): + """MaskFormerFusionHead has no training loss.""" + return dict() + + def simple_test(self, + mask_cls_results, + mask_pred_results, + img_metas, + rescale=False, + **kwargs): + results = [] + for mask_cls_result, mask_pred_result, meta in zip( + mask_cls_results, mask_pred_results, img_metas): + # remove padding + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_result[:, :img_height, :img_width] + + if rescale: + # return result in original resolution + ori_height, ori_width = meta['ori_shape'][:2] + mask_pred_result = F.interpolate( + mask_pred_result[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + + # semantic inference + cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1] + mask_pred = mask_pred_result.sigmoid() + seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred) + # still need softmax and argmax + seg_logit = F.softmax(seg_mask, dim=0) + seg_pred = seg_logit.argmax(dim=0) + seg_pred = seg_pred.cpu().numpy() + results.append(seg_pred) + + return results diff --git a/modelscope/models/cv/image_semantic_segmentation/segformer.py b/modelscope/models/cv/image_semantic_segmentation/segformer.py new file mode 100644 index 00000000..46303526 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/segformer.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.segmentation import EncoderDecoder + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.image_segmentation, module_name=Models.segformer) +class Segformer(EasyCVBaseModel, EncoderDecoder): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + EncoderDecoder.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py b/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py new file mode 100644 index 00000000..2b38ebad --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/semantic_seg_model.py @@ -0,0 +1,77 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.image_semantic_segmentation import (pan_merge, + vit_adapter) +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks + + +@MODELS.register_module( + Tasks.image_segmentation, module_name=Models.swinL_semantic_segmentation) +@MODELS.register_module( + Tasks.image_segmentation, + module_name=Models.vitadapter_semantic_segmentation) +class SemanticSegmentation(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, **kwargs) + + from mmcv.runner import load_checkpoint + import mmcv + from mmdet.models import build_detector + + config = osp.join(model_dir, 'mmcv_config.py') + cfg = mmcv.Config.fromfile(config) + if 'pretrained' in cfg.model: + cfg.model.pretrained = None + elif 'init_cfg' in cfg.model.backbone: + cfg.model.backbone.init_cfg = None + + # build model + cfg.model.train_cfg = None + self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + + # load model + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + _ = load_checkpoint(self.model, model_path, map_location='cpu') + + self.CLASSES = cfg['CLASSES'] # list + self.PALETTE = cfg['PALETTE'] # list + + self.num_classes = len(self.CLASSES) + self.cfg = cfg + + def forward(self, Inputs): + return self.model(**Inputs) + + def postprocess(self, Inputs): + semantic_result = Inputs[0] + + ids = np.unique(semantic_result)[::-1] + legal_indices = ids != self.model.num_classes # for VOID label + ids = ids[legal_indices] + + segms = (semantic_result[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.CLASSES)[ids].tolist() + + results = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return results + + def inference(self, data): + with torch.no_grad(): + results = self.model(return_loss=False, rescale=True, **data) + + return results diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py new file mode 100644 index 00000000..3b9a301c --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/__init__.py @@ -0,0 +1,5 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .models import backbone, decode_heads, segmentors +from .utils import (ResizeToMultiple, add_prefix, build_pixel_sampler, + seg_resize) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py new file mode 100644 index 00000000..791dd26f --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/__init__.py @@ -0,0 +1,5 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .backbone import BASEBEiT, BEiTAdapter +from .decode_heads import Mask2FormerHeadFromMMSeg +from .segmentors import EncoderDecoderMask2Former diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py new file mode 100644 index 00000000..7abd0ef1 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/__init__.py @@ -0,0 +1,6 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .base import BASEBEiT +from .beit_adapter import BEiTAdapter + +__all__ = ['BEiTAdapter', 'BASEBEiT'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py new file mode 100644 index 00000000..cf30cca0 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/adapter_modules.py @@ -0,0 +1,522 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git + +import logging +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmdet.models.utils.transformer import MultiScaleDeformableAttention +from timm.models.layers import DropPath + +_logger = logging.getLogger(__name__) + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor([(h // 8, w // 8), (h // 16, w // 16), + (h // 32, w // 32)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 16, w // 16)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // 16, w // 16)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), + (h // 16, w // 16), + (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, + W * 2).contiguous() + x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, + W).contiguous() + x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, + W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0., + drop_path=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MultiScaleDeformableAttention( + embed_dims=dim, + num_heads=num_heads, + num_levels=n_levels, + num_points=n_points, + batch_first=True) + + # modify to fit the deform_ratio + value_proj_in_features = self.attn.value_proj.weight.shape[0] + value_proj_out_features = int(value_proj_in_features * deform_ratio) + self.attn.value_proj = nn.Linear(value_proj_in_features, + value_proj_out_features) + self.attn.output_proj = nn.Linear(value_proj_out_features, + value_proj_in_features) + + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN( + in_features=dim, + hidden_features=int(dim * cffn_ratio), + drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index, H, W): + + def _inner_forward(query, feat): + attn = self.attn( + query=self.query_norm(query), + key=None, + value=self.feat_norm(feat), + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index) + + query = query + attn + + if self.with_cffn: + query = query + self.drop_path( + self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0., + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MultiScaleDeformableAttention( + embed_dims=dim, + num_heads=num_heads, + num_levels=n_levels, + num_points=n_points, + batch_first=True) + + # modify to fit the deform_ratio + value_proj_in_features = self.attn.value_proj.weight.shape[0] + value_proj_out_features = int(value_proj_in_features * deform_ratio) + self.attn.value_proj = nn.Linear(value_proj_in_features, + value_proj_out_features) + self.attn.output_proj = nn.Linear(value_proj_out_features, + value_proj_in_features) + + self.gamma = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index): + + def _inner_forward(query, feat): + input_query = self.query_norm(query) + input_value = self.feat_norm(feat) + attn = self.attn( + query=input_query, + key=None, + value=input_value, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False, + with_cp=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + return x, c + + +class InteractionBlockWithCls(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False, + with_cp=False): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + cls, x = x[:, :1, ], x[:, 1:, ] + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential(*[ + nn.Conv2d( + 3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.BatchNorm2d(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d( + inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.BatchNorm2d(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ]) + self.conv2 = nn.Sequential(*[ + nn.Conv2d( + inplanes, + 2 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.BatchNorm2d(2 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv3 = nn.Sequential(*[ + nn.Conv2d( + 2 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.BatchNorm2d(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv4 = nn.Sequential(*[ + nn.Conv2d( + 4 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.BatchNorm2d(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.fc1 = nn.Conv2d( + inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d( + 2 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc3 = nn.Conv2d( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc4 = nn.Conv2d( + 4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def forward(self, x): + + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py new file mode 100644 index 00000000..5b33031f --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/__init__.py @@ -0,0 +1,5 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .beit import BASEBEiT + +__all__ = ['BASEBEiT'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py new file mode 100644 index 00000000..62f873ec --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/base/beit.py @@ -0,0 +1,474 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import _load_checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # commit dropout for the original BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] + - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, + coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum( + -1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + self.register_buffer('relative_position_index', + relative_position_index) + + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, H, W, rel_pos_bias=None): + + def _inner_forward(x): + if self.gamma_1 is None: + x = x + self.drop_path( + self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn( + self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, Hp, Wp + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] + - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, + 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +@BACKBONES.register_module() +class BASEBEiT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=512, + patch_size=16, + in_chans=3, + num_classes=80, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=None, + init_values=None, + use_checkpoint=False, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, + pretrained=None, + with_cp=False): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.norm_layer = norm_layer + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.drop_path_rate = drop_path_rate + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + with_cp=with_cp, + init_values=init_values, + window_size=self.patch_embed.patch_shape + if use_rel_pos_bias else None) for i in range(depth) + ]) + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + checkpoint = _load_checkpoint( + init_cfg['checkpoint'], logger=logger, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + + def fix_init_weight(self): + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py new file mode 100644 index 00000000..182fc0c1 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/backbone/beit_adapter.py @@ -0,0 +1,168 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models.builder import BACKBONES +from mmdet.models.utils.transformer import MultiScaleDeformableAttention +from timm.models.layers import DropPath, trunc_normal_ +from torch.nn.init import normal_ + +from .adapter_modules import InteractionBlockWithCls as InteractionBlock +from .adapter_modules import SpatialPriorModule, deform_inputs +from .base.beit import BASEBEiT + +_logger = logging.getLogger(__name__) + + +@BACKBONES.register_module() +class BEiTAdapter(BASEBEiT): + + def __init__(self, + pretrain_size=224, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0., + cffn_ratio=0.25, + deform_ratio=1.0, + with_cffn=True, + interaction_indexes=None, + add_vit_feature=True, + with_cp=False, + *args, + **kwargs): + + super().__init__( + init_values=init_values, with_cp=with_cp, *args, **kwargs) + + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.flags = [ + i for i in range(-1, self.num_block, self.num_block // 4) + ][1:] + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule( + inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential(*[ + InteractionBlock( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=True if i == len(interaction_indexes) + - 1 else False, + with_cp=with_cp) for i in range(len(interaction_indexes)) + ]) + + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.BatchNorm2d(embed_dim) + self.norm2 = nn.BatchNorm2d(embed_dim) + self.norm3 = nn.BatchNorm2d(embed_dim) + self.norm4 = nn.BatchNorm2d(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, + self.pretrain_size[1] // 16, + -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \ + reshape(1, -1, H * W).permute(0, 2, 1) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + x, H, W = self.patch_embed(x) + bs, n, dim = x.shape + cls = self.cls_token.expand( + bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + + if self.pos_embed is not None: + pos_embed = self._get_pos_embed(self.pos_embed, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + # Interaction + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + x, c, cls = layer(x, c, cls, + self.blocks[indexes[0]:indexes[-1] + 1], + deform_inputs1, deform_inputs2, H, W) + outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) + + # Split & Reshape + c2 = c[:, 0:c2.size(1), :] + c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1):, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + x1 = F.interpolate( + x1, scale_factor=4, mode='bilinear', align_corners=False) + x2 = F.interpolate( + x2, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.interpolate( + x4, scale_factor=0.5, mode='bilinear', align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py new file mode 100644 index 00000000..12bf2a21 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/__init__.py @@ -0,0 +1,5 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .mask2former_head_from_mmseg import Mask2FormerHeadFromMMSeg + +__all__ = ['Mask2FormerHeadFromMMSeg'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py new file mode 100644 index 00000000..ae7a0416 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/base_decode_head.py @@ -0,0 +1,266 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 +from mmdet.models.builder import build_loss +from mmdet.models.losses import accuracy + +from ...utils import build_pixel_sampler, seg_resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.num_classes = num_classes + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.ignore_index = ignore_index + self.align_corners = align_corners + + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + seg_resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.losses(seg_logits, gt_semantic_seg) + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + return self.forward(inputs) + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + @force_fp32(apply_to=('seg_logit', )) + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + seg_logit = seg_resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logit, seg_label, ignore_index=self.ignore_index) + return loss diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py new file mode 100644 index 00000000..c0681d2b --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/decode_heads/mask2former_head_from_mmseg.py @@ -0,0 +1,580 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.utils import get_uncertain_point_coords_with_randomness + +from .base_decode_head import BaseDecodeHead + + +@HEADS.register_module() +class Mask2FormerHeadFromMMSeg(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs): + super(Mask2FormerHeadFromMMSeg, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform='multiple_select', + **kwargs) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers. \ + attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding( + positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, + gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, + mask_preds_list, gt_labels_list, + gt_masks_list, img_metas) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, + img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, + gt_labels, gt_points_masks, + img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, + gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, + num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, + gt_labels_list, gt_masks_list, + img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, + all_gt_labels_list, all_gt_masks_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, + gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, + img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]['ori_shape'] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_mask diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py new file mode 100644 index 00000000..18bbce0d --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/__init__.py @@ -0,0 +1,5 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .encoder_decoder_mask2former import EncoderDecoderMask2Former + +__all__ = ['EncoderDecoderMask2Former'] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py new file mode 100644 index 00000000..311352c2 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/base_segmentor.py @@ -0,0 +1,313 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import mmcv +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseSegmentor(BaseModule, metaclass=ABCMeta): + """Base class for segmentors.""" + + def __init__(self, init_cfg=None): + super(BaseSegmentor, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got ' + f'{type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) != ' + f'num of image meta ({len(img_metas)})') + + # all images in the same aug batch all of the same ori_shape and pad + # shape + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + for img_meta in img_metas: + ori_shapes = [_['ori_shape'] for _ in img_meta] + if isinstance(ori_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(ori_shapes[0]) + for shape in ori_shapes) + else: + assert all(shape == ori_shapes[0] for shape in ori_shapes) + + img_shapes = [_['img_shape'] for _ in img_meta] + if isinstance(img_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(img_shapes[0]) + for shape in img_shapes) + else: + assert all(shape == img_shapes[0] for shape in img_shapes) + + pad_shapes = [_['pad_shape'] for _ in img_meta] + if isinstance(pad_shapes[0], torch.Tensor): + assert all( + tensor_to_tuple(shape) == tensor_to_tuple(pad_shapes[0]) + for shape in pad_shapes) + else: + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data_batch['img_metas'])) + + return outputs + + def val_step(self, data_batch, optimizer=None, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + log_vars_ = dict() + for loss_name, loss_value in log_vars.items(): + k = loss_name + '_val' + log_vars_[k] = loss_value + + outputs = dict( + loss=loss, + log_vars=log_vars_, + num_samples=len(data_batch['img_metas'])) + + return outputs + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + # If the loss_vars has different length, raise assertion error + # to prevent GPUs from infinite waiting. + if dist.is_available() and dist.is_initialized(): + log_var_length = torch.tensor(len(log_vars), device=loss.device) + dist.all_reduce(log_var_length) + message = (f'rank {dist.get_rank()}' + + f' len(log_vars): {len(log_vars)}' + ' keys: ' + + ','.join(log_vars.keys()) + '\n') + assert log_var_length == len(log_vars) * dist.get_world_size(), \ + 'loss log variables are different across GPUs!\n' + message + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def show_result(self, + img, + result, + palette=None, + win_name='', + show=False, + wait_time=0, + out_file=None, + opacity=0.5): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (Tensor): The semantic segmentation results to draw over + `img`. + palette (list[list[int]]] | np.ndarray | None): The palette of + segmentation map. If None is given, random palette will be + generated. Default: None + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = mmcv.imread(img) + img = img.copy() + seg = result[0] + if palette is None: + if self.PALETTE is None: + # Get random state before set seed, + # and restore random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + palette = np.random.randint( + 0, 255, size=(len(self.CLASSES), 3)) + np.random.set_state(state) + else: + palette = self.PALETTE + palette = np.array(palette) + assert palette.shape[0] == len(self.CLASSES) + assert palette.shape[1] == 3 + assert len(palette.shape) == 2 + assert 0 < opacity <= 1.0 + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * (1 - opacity) + color_seg * opacity + img = img.astype(np.uint8) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 00000000..50492374 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,302 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models import builder +from mmdet.models.builder import DETECTORS + +from ...utils import add_prefix, seg_resize +from .base_segmentor import BaseSegmentor + + +@DETECTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = seg_resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, + **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg, + **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + if rescale: + preds = seg_resize( + preds, + size=tensor_to_tuple(img_meta[0]['ori_shape'])[:2] + if isinstance(img_meta[0]['ori_shape'], torch.Tensor) else + img_meta[0]['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] + seg_logit = seg_resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + + def tensor_to_tuple(input_tensor): + return tuple(input_tensor.cpu().numpy()) + + if isinstance(ori_shape, torch.Tensor): + assert all( + tensor_to_tuple(_['ori_shape']) == tensor_to_tuple(ori_shape) + for _ in img_meta) + else: + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3, )) + elif flip_direction == 'vertical': + output = output.flip(dims=(2, )) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py new file mode 100644 index 00000000..9c4d5c4c --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/__init__.py @@ -0,0 +1,9 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from .builder import build_pixel_sampler +from .data_process_func import ResizeToMultiple +from .seg_func import add_prefix, seg_resize + +__all__ = [ + 'seg_resize', 'add_prefix', 'build_pixel_sampler', 'ResizeToMultiple' +] diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py new file mode 100644 index 00000000..0603ef94 --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/builder.py @@ -0,0 +1,10 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git +from mmcv.utils import Registry, build_from_cfg + +PIXEL_SAMPLERS = Registry('pixel sampler') + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py new file mode 100644 index 00000000..194361af --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/data_process_func.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +from mmdet.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class ResizeToMultiple(object): + """Resize images & seg to multiple of divisor. + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def __call__(self, results): + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str diff --git a/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py new file mode 100644 index 00000000..db564cca --- /dev/null +++ b/modelscope/models/cv/image_semantic_segmentation/vit_adapter/utils/seg_func.py @@ -0,0 +1,47 @@ +# The implementation is adopted from VitAdapter, +# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git + +import warnings + +import torch.nn.functional as F + + +def seg_resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > input_w: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/modelscope/models/cv/image_to_image_generation/__init__.py b/modelscope/models/cv/image_to_image_generation/__init__.py new file mode 100644 index 00000000..1af3e55f --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from . import data, models, ops diff --git a/modelscope/models/cv/image_to_image_generation/data/__init__.py b/modelscope/models/cv/image_to_image_generation/data/__init__.py new file mode 100644 index 00000000..22b9d22c --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/data/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .transforms import PadToSquare + +else: + _import_structure = { + 'transforms': ['PadToSquare'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) + +# from .transforms import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_generation/data/transforms.py b/modelscope/models/cv/image_to_image_generation/data/transforms.py new file mode 100644 index 00000000..29a25b4b --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/data/transforms.py @@ -0,0 +1,122 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math +import random + +import torchvision.transforms.functional as TF +from PIL import Image, ImageFilter + +__all__ = [ + 'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate', + 'RandomGaussianBlur', 'RandomCrop' +] + + +class Identity(object): + + def __call__(self, *args): + if len(args) == 0: + return None + elif len(args) == 1: + return args[0] + else: + return args + + +class PadToSquare(object): + + def __init__(self, fill=(255, 255, 255)): + self.fill = fill + + def __call__(self, img): + w, h = img.size + if w != h: + if w > h: + t = (w - h) // 2 + b = w - h - t + padding = (0, t, 0, b) + else: + left = (h - w) // 2 + right = h - w - l + padding = (left, 0, right, 0) + img = TF.pad(img, padding, fill=self.fill) + return img + + +class RandomScale(object): + + def __init__(self, + min_scale=0.5, + max_scale=2.0, + min_ratio=0.8, + max_ratio=1.25): + self.min_scale = min_scale + self.max_scale = max_scale + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def __call__(self, img): + w, h = img.size + scale = 2**random.uniform( + math.log2(self.min_scale), math.log2(self.max_scale)) + ratio = 2**random.uniform( + math.log2(self.min_ratio), math.log2(self.max_ratio)) + ow = int(w * scale * math.sqrt(ratio)) + oh = int(h * scale / math.sqrt(ratio)) + img = img.resize((ow, oh), Image.BILINEAR) + return img + + +class RandomRotate(object): + + def __init__(self, + min_angle=-10.0, + max_angle=10.0, + padding=(255, 255, 255), + p=0.5): + self.min_angle = min_angle + self.max_angle = max_angle + self.padding = padding + self.p = p + + def __call__(self, img): + if random.random() < self.p: + angle = random.uniform(self.min_angle, self.max_angle) + img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) + return img + + +class RandomGaussianBlur(object): + + def __init__(self, radius=5, p=0.5): + self.radius = radius + self.p = p + + def __call__(self, img): + if random.random() < self.p: + img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) + return img + + +class RandomCrop(object): + + def __init__(self, size, padding=(255, 255, 255)): + self.size = size + self.padding = padding + + def __call__(self, img): + # pad + w, h = img.size + pad_w = max(0, self.size - w) + pad_h = max(0, self.size - h) + if pad_w > 0 or pad_h > 0: + half_w = pad_w // 2 + half_h = pad_h // 2 + pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) + img = TF.pad(img, pad, fill=self.padding) + + # crop + w, h = img.size + x1 = random.randint(0, w - self.size) + y1 = random.randint(0, h - self.size) + img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) + return img diff --git a/modelscope/models/cv/image_to_image_generation/model.py b/modelscope/models/cv/image_to_image_generation/model.py new file mode 100644 index 00000000..94e5dd7b --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/model.py @@ -0,0 +1,323 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['UNet'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, scale_factor=1.0): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + if self.scale_factor == 2.0: + x = F.interpolate(x, scale_factor=2, mode='nearest') + elif self.scale_factor == 0.5: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.embedding = nn.Sequential(nn.SiLU(), + nn.Linear(embed_dim, out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, y): + identity = x + x = self.layer1(x) + x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class MultiHeadAttention(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): + assert dim % num_heads == 0 + assert context_dim is None or context_dim % num_heads == 0 + context_dim = context_dim or dim + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.q = nn.Linear(dim, dim, bias=False) + self.k = nn.Linear(context_dim, dim, bias=False) + self.v = nn.Linear(context_dim, dim, bias=False) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None): + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, -1, n * c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class GLU(nn.Module): + + def __init__(self, in_dim, out_dim): + super(GLU, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.proj = nn.Linear(in_dim, out_dim * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class TransformerBlock(nn.Module): + + def __init__(self, dim, context_dim, num_heads, dropout=0.0): + super(TransformerBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + # input + self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(dim, dim, 1) + + # self attention + self.norm2 = nn.LayerNorm(dim) + self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) + + # cross attention + self.norm3 = nn.LayerNorm(dim) + self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, + dropout) + + # ffn + self.norm4 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) + + # output + self.conv2 = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.conv2.weight) + + def forward(self, x, context): + b, c, h, w = x.size() + identity = x + + # input + x = self.norm1(x) + x = self.conv1(x).view(b, c, -1).transpose(1, 2) + + # attention + x = x + self.self_attn(self.norm2(x)) + x = x + self.cross_attn(self.norm3(x), context) + x = x + self.ffn(self.norm4(x)) + + # output + x = x.transpose(1, 2).view(b, c, h, w) + x = self.conv2(x) + return x + identity + + +class UNet(nn.Module): + + def __init__(self, + resolution=64, + in_dim=3, + dim=192, + label_dim=512, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 5], + num_heads=1, + head_dim=None, + num_res_blocks=2, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + dropout=0.0): + embed_dim = dim * 4 + super(UNet, self).__init__() + self.resolution = resolution + self.in_dim = in_dim + self.dim = dim + self.context_dim = context_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.clip_embedding = nn.Sequential( + nn.Linear(label_dim, context_dim), nn.SiLU(), + nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + self.encoder.append( + nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) + shortcut_dims.append(out_dim) + scale /= 2.0 + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, dropout), + TransformerBlock(out_dim, context_dim, num_heads), + ResidualBlock(out_dim, embed_dim, out_dim, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, dropout) + ]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads, + dropout)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + block.append( + nn.Sequential( + Resample(scale_factor=2.0), + nn.Conv2d(out_dim, out_dim, 3, padding=1))) + scale *= 2.0 + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y): + # embeddings + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) + y = self.clip_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, y) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, y) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, y) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, y): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, TransformerBlock): + x = module(x, y) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, y) + else: + x = module(x) + return x diff --git a/modelscope/models/cv/image_to_image_generation/models/__init__.py b/modelscope/models/cv/image_to_image_generation/models/__init__.py new file mode 100644 index 00000000..e98421f2 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .autoencoder import VQAutoencoder + from .clip import VisionTransformer + +else: + _import_structure = { + 'autoencoder': ['VQAutoencoder'], + 'clip': ['VisionTransformer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_generation/models/autoencoder.py b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py new file mode 100644 index 00000000..dce256f6 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py @@ -0,0 +1,413 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] + + +def group_norm(dim): + return nn.GroupNorm(32, dim, eps=1e-6, affine=True) + + +class Resample(nn.Module): + + def __init__(self, dim, scale_factor): + super(Resample, self).__init__() + self.dim = dim + self.scale_factor = scale_factor + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(dim, dim, 3, padding=1)) + elif scale_factor == 0.5: + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=2, padding=0)) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + group_norm(in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), + nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Conv2d(in_dim, out_dim, + 1) if in_dim != out_dim else nn.Identity() + + # zero out the last layer params + nn.init.zeros_(self.residual[-1].weight) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class AttentionBlock(nn.Module): + + def __init__(self, dim): + super(AttentionBlock, self).__init__() + self.dim = dim + self.scale = math.pow(dim, -0.25) + + # layers + self.norm = group_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, h, w = x.size() + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) + + # compute attention + attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bij,bcj->bci', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Encoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Encoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + downsamples.append(Resample(out_dim, scale_factor=0.5)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.downsamples(x) + x = self.middle(x) + x = self.head(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Decoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + upsamples.append(Resample(out_dim, scale_factor=2.0)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, 3, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.middle(x) + x = self.upsamples(x) + x = self.head(x) + return x + + +class VectorQuantizer(nn.Module): + + def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size + self.z_dim = z_dim + self.beta = beta + + # init codebook + eps = math.sqrt(1.0 / codebook_size) + self.codebook = nn.Parameter( + torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) + + def forward(self, z): + # preprocess + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + + # quantization + with torch.no_grad(): + tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) + quantized = F.embedding(tokens, + self.codebook).view(b, h, w, + c).permute(0, 3, 1, 2) + + # compute loss + codebook_loss = F.mse_loss(quantized, z.detach()) + commitment_loss = F.mse_loss(quantized.detach(), z) + loss = codebook_loss + self.beta * commitment_loss + + # perplexity + counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) + # dist.all_reduce(counts) + p = counts / counts.sum() + perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) + + # postprocess + tokens = tokens.view(b, h, w) + quantized = z + (quantized - z).detach() + return quantized, tokens, loss, perplexity + + +class VQAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0, + codebook_size=8192, + beta=0.25): + super(VQAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.codebook_size = codebook_size + self.beta = beta + + # blocks + self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim, z_dim, 1) + self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + z = self.encoder(x) + z = self.conv1(z) + z, tokens, loss, perplexity = self.quantizer(z) + z = self.conv2(z) + x = self.decoder(z) + return x, tokens, loss, perplexity + + def encode(self, imgs): + z = self.encoder(imgs) + z = self.conv1(z) + return z + + def decode(self, z): + r"""Absort the quantizer in the decoder. + """ + z = self.quantizer(z)[0] + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + @torch.no_grad() + def encode_to_tokens(self, imgs): + # preprocess + z = self.encoder(imgs) + z = self.conv1(z) + + # quantization + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) + return tokens.view(b, -1) + + @torch.no_grad() + def decode_from_tokens(self, tokens): + # dequantization + z = F.embedding(tokens, self.quantizer.codebook) + + # postprocess + b, l, c = z.size() + h = w = int(math.sqrt(l)) + z = z.view(b, h, w, c).permute(0, 3, 1, 2) + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + +class KLAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(KLAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # blocks + self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x = self.decode(z) + return x, mu, log_var + + def encode(self, x): + x = self.encoder(x) + mu, log_var = self.conv1(x).chunk(2, dim=1) + return mu, log_var + + def decode(self, z): + x = self.conv2(z) + x = self.decoder(x) + return x + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + +class PatchDiscriminator(nn.Module): + + def __init__(self, in_dim=3, dim=64, num_layers=3): + super(PatchDiscriminator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.num_layers = num_layers + + # params + dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] + + # layers + layers = [ + nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), + nn.LeakyReLU(0.2) + ] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stride = 1 if i == num_layers - 1 else 2 + layers += [ + nn.Conv2d( + in_dim, out_dim, 4, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_dim), + nn.LeakyReLU(0.2) + ] + layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] + self.layers = nn.Sequential(*layers) + + # initialize weights + self.apply(self.init_weights) + + def forward(self, x): + return self.layers(x) + + def init_weights(self, m): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0.0, 0.02) + elif isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) diff --git a/modelscope/models/cv/image_to_image_generation/models/clip.py b/modelscope/models/cv/image_to_image_generation/models/clip.py new file mode 100644 index 00000000..d3dd22b4 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/clip.py @@ -0,0 +1,420 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly avaialbe at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather + +__all__ = [ + 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', + 'clip_vit_l_14_336px', 'clip_vit_h_16' +] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. + """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. + """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # gather features from all ranks + full_xi = ops.diff_all_gather(xi) + full_xt = ops.diff_all_gather(xt) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, full_xt.t()) + logits_t2i = scale * torch.mm(xt, full_xi.t()) + + # labels + labels = torch.arange( + len(xi) * ops.get_rank(), + len(xi) * (ops.get_rank() + 1), + dtype=torch.long, + device=xi.device) + return logits_i2t, logits_t2i, labels + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + def fp16(self): + return self.apply(to_fp16) + + +def clip_vit_b_32(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_b_16(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_l_14(**kwargs): + return CLIP( + embed_dim=768, + image_size=224, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_l_14_336px(**kwargs): + return CLIP( + embed_dim=768, + image_size=336, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_h_16(**kwargs): + return CLIP( + embed_dim=1024, + image_size=256, + patch_size=16, + vision_dim=1280, + vision_heads=16, + vision_layers=32, + text_dim=1024, + text_heads=16, + text_layers=24, + **kwargs) diff --git a/modelscope/models/cv/image_to_image_generation/ops/__init__.py b/modelscope/models/cv/image_to_image_generation/ops/__init__.py new file mode 100644 index 00000000..e3dac584 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .diffusion import GaussianDiffusion, beta_schedule + +else: + _import_structure = { + 'diffusion': ['GaussianDiffusion', 'beta_schedule'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_generation/ops/diffusion.py b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py new file mode 100644 index 00000000..b8ffbbbb --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py @@ -0,0 +1,599 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch + +from .losses import discretized_gaussian_log_likelihood, kl_divergence + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + + # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + # no noise when t == 0 + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). + """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + out = torch.cat( + [ + u_out[:, :3] + guide_scale * # noqa W504 + (y_out[:, :3] - u_out[:, :3]), + y_out[:, 3:] + ], + dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 + (1 - alphas) * # noqa W504 + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/cv/image_to_image_generation/ops/losses.py b/modelscope/models/cv/image_to_image_generation/ops/losses.py new file mode 100644 index 00000000..46b9540a --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/losses.py @@ -0,0 +1,36 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 + ((mu1 - mu2)**2) * torch.exp(-logvar2)) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. + """ + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/modelscope/models/cv/image_to_image_translation/__init__.py b/modelscope/models/cv/image_to_image_translation/__init__.py new file mode 100644 index 00000000..35aab6be --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model_translation import UNet + +else: + _import_structure = { + 'image_to_image_translation_model': ['UNet'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_translation/data/__init__.py b/modelscope/models/cv/image_to_image_translation/data/__init__.py new file mode 100644 index 00000000..724bca04 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .transforms import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/data/transforms.py b/modelscope/models/cv/image_to_image_translation/data/transforms.py new file mode 100644 index 00000000..29a25b4b --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/data/transforms.py @@ -0,0 +1,122 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math +import random + +import torchvision.transforms.functional as TF +from PIL import Image, ImageFilter + +__all__ = [ + 'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate', + 'RandomGaussianBlur', 'RandomCrop' +] + + +class Identity(object): + + def __call__(self, *args): + if len(args) == 0: + return None + elif len(args) == 1: + return args[0] + else: + return args + + +class PadToSquare(object): + + def __init__(self, fill=(255, 255, 255)): + self.fill = fill + + def __call__(self, img): + w, h = img.size + if w != h: + if w > h: + t = (w - h) // 2 + b = w - h - t + padding = (0, t, 0, b) + else: + left = (h - w) // 2 + right = h - w - l + padding = (left, 0, right, 0) + img = TF.pad(img, padding, fill=self.fill) + return img + + +class RandomScale(object): + + def __init__(self, + min_scale=0.5, + max_scale=2.0, + min_ratio=0.8, + max_ratio=1.25): + self.min_scale = min_scale + self.max_scale = max_scale + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + def __call__(self, img): + w, h = img.size + scale = 2**random.uniform( + math.log2(self.min_scale), math.log2(self.max_scale)) + ratio = 2**random.uniform( + math.log2(self.min_ratio), math.log2(self.max_ratio)) + ow = int(w * scale * math.sqrt(ratio)) + oh = int(h * scale / math.sqrt(ratio)) + img = img.resize((ow, oh), Image.BILINEAR) + return img + + +class RandomRotate(object): + + def __init__(self, + min_angle=-10.0, + max_angle=10.0, + padding=(255, 255, 255), + p=0.5): + self.min_angle = min_angle + self.max_angle = max_angle + self.padding = padding + self.p = p + + def __call__(self, img): + if random.random() < self.p: + angle = random.uniform(self.min_angle, self.max_angle) + img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) + return img + + +class RandomGaussianBlur(object): + + def __init__(self, radius=5, p=0.5): + self.radius = radius + self.p = p + + def __call__(self, img): + if random.random() < self.p: + img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) + return img + + +class RandomCrop(object): + + def __init__(self, size, padding=(255, 255, 255)): + self.size = size + self.padding = padding + + def __call__(self, img): + # pad + w, h = img.size + pad_w = max(0, self.size - w) + pad_h = max(0, self.size - h) + if pad_w > 0 or pad_h > 0: + half_w = pad_w // 2 + half_h = pad_h // 2 + pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) + img = TF.pad(img, pad, fill=self.padding) + + # crop + w, h = img.size + x1 = random.randint(0, w - self.size) + y1 = random.randint(0, h - self.size) + img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) + return img diff --git a/modelscope/models/cv/image_to_image_translation/model_translation.py b/modelscope/models/cv/image_to_image_translation/model_translation.py new file mode 100644 index 00000000..f2a9e7db --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/model_translation.py @@ -0,0 +1,324 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['UNet'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, scale_factor=1.0): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + if self.scale_factor == 2.0: + x = F.interpolate(x, scale_factor=2, mode='nearest') + elif self.scale_factor == 0.5: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.embedding = nn.Sequential(nn.SiLU(), + nn.Linear(embed_dim, out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, y): + identity = x + x = self.layer1(x) + x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class MultiHeadAttention(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): + assert dim % num_heads == 0 + assert context_dim is None or context_dim % num_heads == 0 + context_dim = context_dim or dim + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.q = nn.Linear(dim, dim, bias=False) + self.k = nn.Linear(context_dim, dim, bias=False) + self.v = nn.Linear(context_dim, dim, bias=False) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None): + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, -1, n * c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class GLU(nn.Module): + + def __init__(self, in_dim, out_dim): + super(GLU, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.proj = nn.Linear(in_dim, out_dim * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class TransformerBlock(nn.Module): + + def __init__(self, dim, context_dim, num_heads, dropout=0.0): + super(TransformerBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + # input + self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(dim, dim, 1) + + # self attention + self.norm2 = nn.LayerNorm(dim) + self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) + + # cross attention + self.norm3 = nn.LayerNorm(dim) + self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, + dropout) + + # ffn + self.norm4 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) + + # output + self.conv2 = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.conv2.weight) + + def forward(self, x, context): + b, c, h, w = x.size() + identity = x + + # input + x = self.norm1(x) + x = self.conv1(x).view(b, c, -1).transpose(1, 2) + + # attention + x = x + self.self_attn(self.norm2(x)) + x = x + self.cross_attn(self.norm3(x), context) + x = x + self.ffn(self.norm4(x)) + + # output + x = x.transpose(1, 2).view(b, c, h, w) + x = self.conv2(x) + return x + identity + + +class UNet(nn.Module): + + def __init__(self, + resolution=64, + in_dim=3, + dim=192, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 5], + num_heads=1, + head_dim=None, + num_res_blocks=2, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + num_classes=1001, + dropout=0.0): + embed_dim = dim * 4 + super(UNet, self).__init__() + self.resolution = resolution + self.in_dim = in_dim + self.dim = dim + self.context_dim = context_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.num_classes = num_classes + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.label_embedding = nn.Embedding(num_classes, context_dim) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + self.encoder.append( + nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) + shortcut_dims.append(out_dim) + scale /= 2.0 + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, dropout), + TransformerBlock(out_dim, context_dim, num_heads), + ResidualBlock(out_dim, embed_dim, out_dim, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, dropout) + ]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads, + dropout)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + block.append( + nn.Sequential( + Resample(scale_factor=2.0), + nn.Conv2d(out_dim, out_dim, 3, padding=1))) + scale *= 2.0 + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat=None): + # embeddings + if concat is not None: + x = torch.cat([x, concat], dim=1) + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) + y = self.label_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, y) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, y) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, y) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, y): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, TransformerBlock): + x = module(x, y) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, y) + else: + x = module(x) + return x diff --git a/modelscope/models/cv/image_to_image_translation/models/__init__.py b/modelscope/models/cv/image_to_image_translation/models/__init__.py new file mode 100644 index 00000000..7fdd8189 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .autoencoder import * # noqa F403 +from .clip import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/models/autoencoder.py b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py new file mode 100644 index 00000000..dce256f6 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/autoencoder.py @@ -0,0 +1,413 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] + + +def group_norm(dim): + return nn.GroupNorm(32, dim, eps=1e-6, affine=True) + + +class Resample(nn.Module): + + def __init__(self, dim, scale_factor): + super(Resample, self).__init__() + self.dim = dim + self.scale_factor = scale_factor + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(dim, dim, 3, padding=1)) + elif scale_factor == 0.5: + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=2, padding=0)) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + group_norm(in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), + nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Conv2d(in_dim, out_dim, + 1) if in_dim != out_dim else nn.Identity() + + # zero out the last layer params + nn.init.zeros_(self.residual[-1].weight) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class AttentionBlock(nn.Module): + + def __init__(self, dim): + super(AttentionBlock, self).__init__() + self.dim = dim + self.scale = math.pow(dim, -0.25) + + # layers + self.norm = group_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, h, w = x.size() + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) + + # compute attention + attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bij,bcj->bci', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Encoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Encoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + downsamples.append(Resample(out_dim, scale_factor=0.5)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.downsamples(x) + x = self.middle(x) + x = self.head(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Decoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + upsamples.append(Resample(out_dim, scale_factor=2.0)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, 3, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.middle(x) + x = self.upsamples(x) + x = self.head(x) + return x + + +class VectorQuantizer(nn.Module): + + def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size + self.z_dim = z_dim + self.beta = beta + + # init codebook + eps = math.sqrt(1.0 / codebook_size) + self.codebook = nn.Parameter( + torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) + + def forward(self, z): + # preprocess + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + + # quantization + with torch.no_grad(): + tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) + quantized = F.embedding(tokens, + self.codebook).view(b, h, w, + c).permute(0, 3, 1, 2) + + # compute loss + codebook_loss = F.mse_loss(quantized, z.detach()) + commitment_loss = F.mse_loss(quantized.detach(), z) + loss = codebook_loss + self.beta * commitment_loss + + # perplexity + counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) + # dist.all_reduce(counts) + p = counts / counts.sum() + perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) + + # postprocess + tokens = tokens.view(b, h, w) + quantized = z + (quantized - z).detach() + return quantized, tokens, loss, perplexity + + +class VQAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0, + codebook_size=8192, + beta=0.25): + super(VQAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.codebook_size = codebook_size + self.beta = beta + + # blocks + self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim, z_dim, 1) + self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + z = self.encoder(x) + z = self.conv1(z) + z, tokens, loss, perplexity = self.quantizer(z) + z = self.conv2(z) + x = self.decoder(z) + return x, tokens, loss, perplexity + + def encode(self, imgs): + z = self.encoder(imgs) + z = self.conv1(z) + return z + + def decode(self, z): + r"""Absort the quantizer in the decoder. + """ + z = self.quantizer(z)[0] + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + @torch.no_grad() + def encode_to_tokens(self, imgs): + # preprocess + z = self.encoder(imgs) + z = self.conv1(z) + + # quantization + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) + return tokens.view(b, -1) + + @torch.no_grad() + def decode_from_tokens(self, tokens): + # dequantization + z = F.embedding(tokens, self.quantizer.codebook) + + # postprocess + b, l, c = z.size() + h = w = int(math.sqrt(l)) + z = z.view(b, h, w, c).permute(0, 3, 1, 2) + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + +class KLAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(KLAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # blocks + self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x = self.decode(z) + return x, mu, log_var + + def encode(self, x): + x = self.encoder(x) + mu, log_var = self.conv1(x).chunk(2, dim=1) + return mu, log_var + + def decode(self, z): + x = self.conv2(z) + x = self.decoder(x) + return x + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + +class PatchDiscriminator(nn.Module): + + def __init__(self, in_dim=3, dim=64, num_layers=3): + super(PatchDiscriminator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.num_layers = num_layers + + # params + dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] + + # layers + layers = [ + nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), + nn.LeakyReLU(0.2) + ] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stride = 1 if i == num_layers - 1 else 2 + layers += [ + nn.Conv2d( + in_dim, out_dim, 4, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_dim), + nn.LeakyReLU(0.2) + ] + layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] + self.layers = nn.Sequential(*layers) + + # initialize weights + self.apply(self.init_weights) + + def forward(self, x): + return self.layers(x) + + def init_weights(self, m): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0.0, 0.02) + elif isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) diff --git a/modelscope/models/cv/image_to_image_translation/models/clip.py b/modelscope/models/cv/image_to_image_translation/models/clip.py new file mode 100644 index 00000000..d3dd22b4 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/models/clip.py @@ -0,0 +1,420 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly avaialbe at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather + +__all__ = [ + 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', + 'clip_vit_l_14_336px', 'clip_vit_h_16' +] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. + """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. + """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # gather features from all ranks + full_xi = ops.diff_all_gather(xi) + full_xt = ops.diff_all_gather(xt) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, full_xt.t()) + logits_t2i = scale * torch.mm(xt, full_xi.t()) + + # labels + labels = torch.arange( + len(xi) * ops.get_rank(), + len(xi) * (ops.get_rank() + 1), + dtype=torch.long, + device=xi.device) + return logits_i2t, logits_t2i, labels + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + def fp16(self): + return self.apply(to_fp16) + + +def clip_vit_b_32(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_b_16(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_l_14(**kwargs): + return CLIP( + embed_dim=768, + image_size=224, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_l_14_336px(**kwargs): + return CLIP( + embed_dim=768, + image_size=336, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_h_16(**kwargs): + return CLIP( + embed_dim=1024, + image_size=256, + patch_size=16, + vision_dim=1280, + vision_heads=16, + vision_layers=32, + text_dim=1024, + text_heads=16, + text_layers=24, + **kwargs) diff --git a/modelscope/models/cv/image_to_image_translation/ops/__init__.py b/modelscope/models/cv/image_to_image_translation/ops/__init__.py new file mode 100644 index 00000000..474c811b --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .degradation import * # noqa F403 +from .diffusion import * # noqa F403 +from .losses import * # noqa F403 +from .metrics import * # noqa F403 +from .random_color import * # noqa F403 +from .random_mask import * # noqa F403 +from .svd import * # noqa F403 +from .utils import * # noqa F403 diff --git a/modelscope/models/cv/image_to_image_translation/ops/apps.py b/modelscope/models/cv/image_to_image_translation/ops/apps.py new file mode 100644 index 00000000..39d2e015 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/apps.py @@ -0,0 +1,664 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# APPs that facilitate the use of pretrained neural networks. + +import os.path as osp + +import artist.data as data +import artist.models as models +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn.functional as F +import torchvision.transforms as T +from artist import DOWNLOAD_TO_CACHE +from PIL import Image +from torch.utils.data import DataLoader, Dataset + +from .utils import parallel, read_image + +__all__ = [ + 'FeatureExtractor', 'Classifier', 'Text2Image', 'Sole2Shoe', 'ImageParser', + 'TextImageMatch', 'taobao_feature_extractor', 'singleton_classifier', + 'orientation_classifier', 'fashion_text2image', 'mindalle_text2image', + 'sole2shoe', 'sole_parser', 'sod_foreground_parser', + 'fashion_text_image_match' +] + + +class ImageFolder(Dataset): + + def __init__(self, paths, transforms=None): + self.paths = paths + self.transforms = transforms + + def __getitem__(self, index): + img = read_image(self.paths[index]) + if img.mode != 'RGB': + img = img.convert('RGB') + if self.transforms is not None: + img = self.transforms(img) + return img + + def __len__(self): + return len(self.paths) + + +class FeatureExtractor(object): + + def __init__( + self, + model='InceptionV1', + checkpoint='models/inception-v1/1218shoes.v9_7.140.0.1520000', + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.resolution = resolution + self.batch_size = batch_size + self.device = device + + # init model + self.net = getattr( + models, + model)(num_classes=None).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + r"""imgs: Either a PIL.Image or a list of PIL.Image instances. + """ + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + feats = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + feats.append(self.net(batch)) + return torch.cat(feats, dim=0) + + def batch_process(self, paths): + # init dataloader + dataloader = DataLoader( + dataset=ImageFolder(paths, self.transforms), + batch_size=self.batch_size, + shuffle=False, + drop_last=False, + pin_memory=True, + num_workers=8, + prefetch_factor=2) + + # forward + feats = [] + for step, batch in enumerate(dataloader, 1): + print(f'Step: {step}/{len(dataloader)}', flush=True) + batch = batch.to(self.device, non_blocking=True) + feats.append(self.net(batch)) + return torch.cat(feats) + + +class Classifier(object): + + def __init__( + self, + model='InceptionV1', + checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.num_classes = num_classes + self.resolution = resolution + self.batch_size = batch_size + self.device = device + + # init model + self.net = getattr(models, model)( + num_classes=num_classes).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + r"""imgs: Either a PIL.Image or a list of PIL.Image instances. + """ + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + scores = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + logits = self.net(batch) + scores.append(logits.sigmoid() if self.num_classes == # noqa W504 + 1 else logits.softmax(dim=1)) + return torch.cat(scores, dim=0) + + +class Text2Image(object): + + def __init__( + self, + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + gpt_txt_vocab_size=21128, + gpt_txt_seq_len=64, + gpt_img_seq_len=1024, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), + batch_size=16, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.tokenizer = tokenizer + self.batch_size = batch_size + self.device = device + + # init VQGAN model + self.vqgan = models.VQGAN( + dim=vqgan_dim, + z_dim=vqgan_z_dim, + dim_mult=vqgan_dim_mult, + num_res_blocks=vqgan_num_res_blocks, + attn_scales=vqgan_attn_scales, + codebook_size=vqgan_codebook_size, + beta=vqgan_beta).eval().requires_grad_(False).to(device) + self.vqgan.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) + + # init GPT model + self.gpt = models.Seq2SeqGPT( + src_vocab_size=gpt_txt_vocab_size, + tar_vocab_size=vqgan_codebook_size, + src_seq_len=gpt_txt_seq_len, + tar_seq_len=gpt_img_seq_len, + dim=gpt_dim, + num_heads=gpt_num_heads, + num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) + self.gpt.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) + + def __call__(self, + txts, + top_k=64, + top_p=None, + temperature=0.6, + use_fp16=True): + # preprocess + if isinstance(txts, str): + txts = [txts] + assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) + txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) + + # forward + out_imgs = [] + for batch in txt_tokens.split(self.batch_size, dim=0): + # sample + batch = batch.to(self.device, non_blocking=True) + with amp.autocast(enabled=use_fp16): + img_tokens = self.gpt.sample(batch, top_k, top_p, temperature) + + # decode + imgs = self.vqgan.decode_from_tokens(img_tokens) + imgs = self._whiten_borders(imgs) + imgs = imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( + 0, 2, 3, 1).cpu().numpy().astype(np.uint8) + imgs = [Image.fromarray(u) for u in imgs] + + # append + out_imgs += imgs + return out_imgs + + def _whiten_borders(self, imgs): + r"""Remove border artifacts. + """ + imgs[:, :, :18, :] = 1 + imgs[:, :, :, :18] = 1 + imgs[:, :, -18:, :] = 1 + imgs[:, :, :, -18:] = 1 + return imgs + + +class Sole2Shoe(object): + + def __init__( + self, + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + src_resolution=256, + tar_resolution=512, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', + batch_size=12, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.batch_size = batch_size + self.device = device + src_seq_len = (src_resolution // 16)**2 + tar_seq_len = (tar_resolution // 16)**2 + + # init VQGAN model + self.vqgan = models.VQGAN( + dim=vqgan_dim, + z_dim=vqgan_z_dim, + dim_mult=vqgan_dim_mult, + num_res_blocks=vqgan_num_res_blocks, + attn_scales=vqgan_attn_scales, + codebook_size=vqgan_codebook_size, + beta=vqgan_beta).eval().requires_grad_(False).to(device) + self.vqgan.load_state_dict( + torch.load( + DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device)) + + # init GPT model + self.gpt = models.Seq2SeqGPT( + src_vocab_size=vqgan_codebook_size, + tar_vocab_size=vqgan_codebook_size, + src_seq_len=src_seq_len, + tar_seq_len=tar_seq_len, + dim=gpt_dim, + num_heads=gpt_num_heads, + num_layers=gpt_num_layers).eval().requires_grad_(False).to(device) + self.gpt.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device)) + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(src_resolution), + T.ToTensor(), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ]) + + def __call__(self, + sole_imgs, + top_k=64, + top_p=None, + temperature=0.6, + use_fp16=True, + num_workers=0): + # preprocess + if isinstance(sole_imgs, Image.Image): + sole_imgs = [sole_imgs] + assert isinstance(sole_imgs, (tuple, list)) and isinstance( + sole_imgs[0], Image.Image) + sole_imgs = torch.stack( + parallel(self.transforms, sole_imgs, num_workers), dim=0) + + # forward + out_imgs = [] + for batch in sole_imgs.split(self.batch_size, dim=0): + # sample + batch = batch.to(self.device) + with amp.autocast(enabled=use_fp16): + sole_tokens = self.vqgan.encode_to_tokens(batch) + shoe_tokens = self.gpt.sample(sole_tokens, top_k, top_p, + temperature) + + # decode + shoe_imgs = self.vqgan.decode_from_tokens(shoe_tokens) + shoe_imgs = self._whiten_borders(shoe_imgs) + shoe_imgs = shoe_imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute( + 0, 2, 3, 1).cpu().numpy().astype(np.uint8) + shoe_imgs = [Image.fromarray(u) for u in shoe_imgs] + + # append + out_imgs += shoe_imgs + return out_imgs + + def _whiten_borders(self, imgs): + r"""Remove border artifacts. + """ + imgs[:, :, :18, :] = 1 + imgs[:, :, :, :18] = 1 + imgs[:, :, -18:, :] = 1 + imgs[:, :, :, -18:] = 1 + return imgs + + +class ImageParser(object): + + def __init__( + self, + model='SPNet', + num_classes=2, + resolution=800, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + model_with_softmax=False, + checkpoint='models/spnet/sole_segmentation_211219.pth', + batch_size=16, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.batch_size = batch_size + self.device = device + + # init model + if checkpoint.endswith('.pt'): + self.net = torch.jit.load( + DOWNLOAD_TO_CACHE(checkpoint)).eval().to(device) + [p.requires_grad_(False) for p in self.net.parameters()] + else: + self.net = getattr(models, model)( + num_classes=num_classes, + pretrained=False).eval().requires_grad_(False).to(device) + self.net.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + self.softmax = (lambda x, dim: x) if model_with_softmax else F.softmax + + # data transforms + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(resolution), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, num_workers=0): + # preprocess + if isinstance(imgs, Image.Image): + imgs = [imgs] + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + sizes = [u.size for u in imgs] + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + masks = [] + for batch in imgs.split(self.batch_size, dim=0): + batch = batch.to(self.device, non_blocking=True) + masks.append(self.softmax(self.net(batch), dim=1)) + + # postprocess + masks = torch.cat(masks, dim=0).unsqueeze(1) + masks = [ + F.interpolate(u, v, mode='bilinear', align_corners=False) + for u, v in zip(masks, sizes) + ] + return masks + + +class TextImageMatch(object): + + def __init__( + self, + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=21128, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), + batch_size=64, + device=torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125 + self.tokenizer = tokenizer + self.batch_size = batch_size + self.device = device + + # init model + self.clip = models.CLIP( + embed_dim=embed_dim, + image_size=image_size, + patch_size=patch_size, + vision_dim=vision_dim, + vision_heads=vision_heads, + vision_layers=vision_layers, + vocab_size=vocab_size, + text_len=text_len, + text_dim=text_dim, + text_heads=text_heads, + text_layers=text_layers).eval().requires_grad_(False).to(device) + self.clip.load_state_dict( + torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device)) + + # transforms + scale_size = int(image_size * 8 / 7) + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize(scale_size), + T.CenterCrop(image_size), + T.ToTensor(), + T.Normalize(mean, std) + ]) + + def __call__(self, imgs, txts, num_workers=0): + # preprocess + assert isinstance(imgs, + (tuple, list)) and isinstance(imgs[0], Image.Image) + assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str) + txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts]) + imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0) + + # forward + scores = [] + for img_batch, txt_batch in zip( + imgs.split(self.batch_size, dim=0), + txt_tokens.split(self.batch_size, dim=0)): + img_batch = img_batch.to(self.device) + txt_batch = txt_batch.to(self.device) + xi = F.normalize(self.clip.visual(img_batch), p=2, dim=1) + xt = F.normalize(self.clip.textual(txt_batch), p=2, dim=1) + scores.append((xi * xt).sum(dim=1)) + return torch.cat(scores, dim=0) + + +def taobao_feature_extractor(category='shoes', **kwargs): + r"""Pretrained taobao-search feature extractors. + """ + assert category in ['softall', 'shoes', 'bag'] + checkpoint = osp.join( + 'models/inception-v1', { + 'softall': '1214softall_10.10.0.5000', + 'shoes': '1218shoes.v9_7.140.0.1520000', + 'bag': '0926bag.v9_6.29.0.140000' + }[category]) + app = FeatureExtractor( + model='InceptionV1', + checkpoint=checkpoint, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def singleton_classifier(**kwargs): + r"""Pretrained classifier that finds single-object images. + Supports shoes, apparel, and bag images. + """ + app = Classifier( + model='InceptionV1', + checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def orientation_classifier(**kwargs): + r"""Shoes orientation classifier. + """ + app = Classifier( + model='InceptionV1', + checkpoint='models/classifier/shoes-oriendetect-20211026.pth', + num_classes=1, + resolution=224, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + **kwargs) + return app + + +def fashion_text2image(**kwargs): + r"""Fashion text-to-image generator. + Supports shoe and apparel image generation. + """ + app = Text2Image( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + gpt_txt_vocab_size=21128, + gpt_txt_seq_len=64, + gpt_img_seq_len=1024, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint= # noqa E251 + 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint= # noqa E251 + 'models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64), + **kwargs) + return app + + +def mindalle_text2image(**kwargs): + r"""Pretrained text2image generator with weights copied from minDALL-E. + """ + app = Text2Image( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=16384, + vqgan_beta=0.25, + gpt_txt_vocab_size=16384, + gpt_txt_seq_len=64, + gpt_img_seq_len=256, + gpt_dim=1536, + gpt_num_heads=24, + gpt_num_layers=42, + vqgan_checkpoint='models/minDALLE/1.3B_vqgan.pth', + gpt_checkpoint='models/minDALLE/1.3B_gpt.pth', + tokenizer=data.BPETokenizer(length=64), + **kwargs) + return app + + +def sole2shoe(**kwargs): + app = Sole2Shoe( + vqgan_dim=128, + vqgan_z_dim=256, + vqgan_dim_mult=[1, 1, 2, 2, 4], + vqgan_num_res_blocks=2, + vqgan_attn_scales=[1.0 / 16], + vqgan_codebook_size=975, + vqgan_beta=0.25, + src_resolution=256, + tar_resolution=512, + gpt_dim=1024, + gpt_num_heads=16, + gpt_num_layers=24, + vqgan_checkpoint= # noqa E251 + 'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth', + gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth', + **kwargs) + return app + + +def sole_parser(**kwargs): + app = ImageParser( + model='SPNet', + num_classes=2, + resolution=800, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + model_with_softmax=False, + checkpoint='models/spnet/sole_segmentation_211219.pth', + **kwargs) + return app + + +def sod_foreground_parser(**kwargs): + app = ImageParser( + model=None, + num_classes=None, + resolution=448, + mean=[0.488431, 0.466275, 0.403686], + std=[0.222627, 0.21949, 0.22549], + model_with_softmax=True, + checkpoint='models/semseg/sod_model_20201228.pt', + **kwargs) + return app + + +def fashion_text_image_match(**kwargs): + app = TextImageMatch( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=21128, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth', + tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77), + **kwargs) + return app diff --git a/modelscope/models/cv/image_to_image_translation/ops/degradation.py b/modelscope/models/cv/image_to_image_translation/ops/degradation.py new file mode 100644 index 00000000..9061e7be --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/degradation.py @@ -0,0 +1,1075 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math +import os +import random + +import cv2 +import numpy as np +import scipy +import scipy.stats as stats +import torch +from scipy import ndimage +from scipy.interpolate import interp2d +from scipy.linalg import orth + +os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' + +__all__ = ['degradation_bsrgan_light', 'degradation_bsrgan'] + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + return np.float32(img / 255.) + + +def single2uint(img): + return np.uint8((img.clip(0, 1) * 255.).round()) + + +def uint162single(img): + return np.float32(img / 65535.) + + +def single2uint16(img): + return np.uint16((img.clip(0, 1) * 65535.).round()) + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, + [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa E126 + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, + [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + # img1 = img1.squeeze() + # img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * # noqa W504 + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * # noqa W504 + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + + 2) * (((absx > 1) * # noqa W504 + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, + kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( + 0, P - 1, P).view(1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W + * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( + 0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W + * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, + j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, + j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot( + np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = stats.multivariate_normal.pdf([cx, cy], + mean=mean, + cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d( + x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel( + k_size=np.array([15, 15]), + scale_factor=np.array([4, 4]), + min_var=0.6, + max_var=10., + noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1 + ) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid( + np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve( + x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur_1(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2 / 4 + wd = wd / 4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2) + else: + k = fspecial('gaussian', + random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize( + img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( + np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, + (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs( + L**2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, + img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, + (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal( + [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10**(2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype( + np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode( + '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, + rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan_light(image, sf=4, isp_model=None): + """ + This is the variant of the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = uint2single(image) + _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + # sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + # hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[ + idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur_1(image, sf=sf) + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, (int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum( + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + elif i == 3: + # downsample3 + image = cv2.resize( + image, (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = single2uint(image) + return image + + +def add_blur_2(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian( + ksize=2 * random.randint(2, 11) + 3, + theta=random.random() * np.pi, + l1=l1, + l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, + wd * random.random()) + img = ndimage.filters.convolve( + img, np.expand_dims(k, axis=2), mode='mirror') + return img + + +def degradation_bsrgan(image, sf=4, isp_model=None): + """ + This is the variant of the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = uint2single(image) + _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + # sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + # hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[ + idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur_2(image, sf=sf) + elif i == 1: + image = add_blur_2(image, sf=sf) + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, (int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum( + ) # blur with shifted kernel + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + elif i == 3: + # downsample3 + image = cv2.resize( + image, (int(1 / sf * a), int(1 / sf * b)), + interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = single2uint(image) + return image diff --git a/modelscope/models/cv/image_to_image_translation/ops/diffusion.py b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py new file mode 100644 index 00000000..5ff37dc3 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/diffusion.py @@ -0,0 +1,601 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch + +from .losses import discretized_gaussian_log_likelihood, kl_divergence + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + + # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + # no noise when t == 0 + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). + """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + out = torch.cat( + [ + u_out[:, :3] + guide_scale * # noqa W504 + (y_out[:, :3] - u_out[:, :3]), + y_out[:, 3:] + ], + dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 + (1 - alphas) * # noqa W504 + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/cv/image_to_image_translation/ops/losses.py b/modelscope/models/cv/image_to_image_translation/ops/losses.py new file mode 100644 index 00000000..46b9540a --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/losses.py @@ -0,0 +1,36 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 + ((mu1 - mu2)**2) * torch.exp(-logvar2)) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. + """ + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/modelscope/models/cv/image_to_image_translation/ops/metrics.py b/modelscope/models/cv/image_to_image_translation/ops/metrics.py new file mode 100644 index 00000000..c1023fa0 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/metrics.py @@ -0,0 +1,127 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import numpy as np +import scipy.linalg as linalg +import torch + +__all__ = [ + 'get_fid_net', 'get_is_net', 'compute_fid', 'compute_prdc', 'compute_is' +] + + +def get_fid_net(resize_input=True, normalize_input=True): + r"""InceptionV3 network for the evaluation of Fréchet Inception Distance (FID). + + Args: + resize_input: whether or not to resize the input to (299, 299). + normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). + """ + from artist.models import InceptionV3 + return InceptionV3( + output_blocks=(3, ), + resize_input=resize_input, + normalize_input=normalize_input, + requires_grad=False, + use_fid_inception=True).eval().requires_grad_(False) + + +def get_is_net(resize_input=True, normalize_input=True): + r"""InceptionV3 network for the evaluation of Inception Score (IS). + + Args: + resize_input: whether or not to resize the input to (299, 299). + normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1). + """ + from artist.models import InceptionV3 + return InceptionV3( + output_blocks=(4, ), + resize_input=resize_input, + normalize_input=normalize_input, + requires_grad=False, + use_fid_inception=False).eval().requires_grad_(False) + + +@torch.no_grad() +def compute_fid(real_feats, fake_feats, eps=1e-6): + r"""Compute Fréchet Inception Distance (FID). + + Args: + real_feats: [N, C]. + fake_feats: [N, C]. + """ + # check inputs + if isinstance(real_feats, torch.Tensor): + real_feats = real_feats.cpu().numpy().astype(np.float_) + if isinstance(fake_feats, torch.Tensor): + fake_feats = fake_feats.cpu().numpy().astype(np.float_) + + # real statistics + mu1 = np.mean(real_feats, axis=0) + sigma1 = np.cov(real_feats, rowvar=False) + + # fake statistics + mu2 = np.mean(fake_feats, axis=0) + sigma2 = np.cov(fake_feats, rowvar=False) + + # compute covmean + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + print( + f'FID calculation produces singular product; adding {eps} to diagonal of cov', + flush=True) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + # compute Fréchet distance + diff = mu1 - mu2 + fid = diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * np.trace(covmean) + return fid.item() + + +@torch.no_grad() +def compute_prdc(real_feats, fake_feats, knn=5): + r"""Compute precision, recall, density, and coverage given two manifolds. + + Args: + real_feats: [N, C]. + fake_feats: [N, C]. + knn: the number of nearest neighbors to consider. + """ + # distances + real_kth = -(-torch.cdist(real_feats, real_feats)).topk( + k=knn, dim=1)[0][:, -1] + fake_kth = -(-torch.cdist(fake_feats, fake_feats)).topk( + k=knn, dim=1)[0][:, -1] + dists = torch.cdist(real_feats, fake_feats) + + # metrics + precision = (dists < real_kth.unsqueeze(1)).any( + dim=0).float().mean().item() + recall = (dists < fake_kth.unsqueeze(0)).any(dim=1).float().mean().item() + density = (dists < real_kth.unsqueeze(1)).float().sum( + dim=0).mean().item() / knn + coverage = (dists.min(dim=1)[0] < real_kth).float().mean().item() + return precision, recall, density, coverage + + +@torch.no_grad() +def compute_is(logits, num_splits=10): + preds = logits.softmax(dim=1).cpu().numpy() + split_scores = [] + for k in range(num_splits): + part = preds[k * (len(logits) // num_splits):(k + 1) + * (len(logits) // num_splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + return np.mean(split_scores), np.std(split_scores) diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_color.py b/modelscope/models/cv/image_to_image_translation/ops/random_color.py new file mode 100644 index 00000000..75692836 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/random_color.py @@ -0,0 +1,221 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import colorsys +import random + +__all__ = ['RandomColor', 'rand_color'] + +COLORMAP = { + 'blue': { + 'hue_range': [179, 257], + 'lower_bounds': [[20, 100], [30, 86], [40, 80], [50, 74], [60, 60], + [70, 52], [80, 44], [90, 39], [100, 35]] + }, + 'green': { + 'hue_range': [63, 178], + 'lower_bounds': [[30, 100], [40, 90], [50, 85], [60, 81], [70, 74], + [80, 64], [90, 50], [100, 40]] + }, + 'monochrome': { + 'hue_range': [0, 0], + 'lower_bounds': [[0, 0], [100, 0]] + }, + 'orange': { + 'hue_range': [19, 46], + 'lower_bounds': [[20, 100], [30, 93], [40, 88], [50, 86], [60, 85], + [70, 70], [100, 70]] + }, + 'pink': { + 'hue_range': [283, 334], + 'lower_bounds': [[20, 100], [30, 90], [40, 86], [60, 84], [80, 80], + [90, 75], [100, 73]] + }, + 'purple': { + 'hue_range': [258, 282], + 'lower_bounds': [[20, 100], [30, 87], [40, 79], [50, 70], [60, 65], + [70, 59], [80, 52], [90, 45], [100, 42]] + }, + 'red': { + 'hue_range': [-26, 18], + 'lower_bounds': [[20, 100], [30, 92], [40, 89], [50, 85], [60, 78], + [70, 70], [80, 60], [90, 55], [100, 50]] + }, + 'yellow': { + 'hue_range': [47, 62], + 'lower_bounds': [[25, 100], [40, 94], [50, 89], [60, 86], [70, 84], + [80, 82], [90, 80], [100, 75]] + } +} + + +class RandomColor(object): + + def __init__(self, seed=None): + self.colormap = COLORMAP + self.random = random.Random(seed) + + for color_name, color_attrs in self.colormap.items(): + lower_bounds = color_attrs['lower_bounds'] + s_min = lower_bounds[0][0] + s_max = lower_bounds[len(lower_bounds) - 1][0] + + b_min = lower_bounds[len(lower_bounds) - 1][1] + b_max = lower_bounds[0][1] + + self.colormap[color_name]['saturation_range'] = [s_min, s_max] + self.colormap[color_name]['brightness_range'] = [b_min, b_max] + + def generate(self, hue=None, luminosity=None, count=1, format_='hex'): + colors = [] + for _ in range(count): + # First we pick a hue (H) + H = self.pick_hue(hue) + + # Then use H to determine saturation (S) + S = self.pick_saturation(H, hue, luminosity) + + # Then use S and H to determine brightness (B). + B = self.pick_brightness(H, S, luminosity) + + # Then we return the HSB color in the desired format + colors.append(self.set_format([H, S, B], format_)) + + return colors + + def pick_hue(self, hue): + hue_range = self.get_hue_range(hue) + hue = self.random_within(hue_range) + + # Instead of storing red as two seperate ranges, + # we group them, using negative numbers + if (hue < 0): + hue += 360 + + return hue + + def pick_saturation(self, hue, hue_name, luminosity): + + if luminosity == 'random': + return self.random_within([0, 100]) + + if hue_name == 'monochrome': + return 0 + + saturation_range = self.get_saturation_range(hue) + + s_min = saturation_range[0] + s_max = saturation_range[1] + + if luminosity == 'bright': + s_min = 55 + elif luminosity == 'dark': + s_min = s_max - 10 + elif luminosity == 'light': + s_max = 55 + + return self.random_within([s_min, s_max]) + + def pick_brightness(self, H, S, luminosity): + b_min = self.get_minimum_brightness(H, S) + b_max = 100 + + if luminosity == 'dark': + b_max = b_min + 20 + elif luminosity == 'light': + b_min = (b_max + b_min) / 2 + elif luminosity == 'random': + b_min = 0 + b_max = 100 + + return self.random_within([b_min, b_max]) + + def set_format(self, hsv, format_): + if 'hsv' in format_: + color = hsv + elif 'rgb' in format_: + color = self.hsv_to_rgb(hsv) + elif 'hex' in format_: + r, g, b = self.hsv_to_rgb(hsv) + return '#%02x%02x%02x' % (r, g, b) + else: + return 'unrecognized format' + + if 'Array' in format_ or format_ == 'hex': + return color + else: + prefix = format_[:3] + color_values = [str(x) for x in color] + return '%s(%s)' % (prefix, ', '.join(color_values)) + + def get_minimum_brightness(self, H, S): + lower_bounds = self.get_color_info(H)['lower_bounds'] + + for i in range(len(lower_bounds) - 1): + s1 = lower_bounds[i][0] + v1 = lower_bounds[i][1] + + s2 = lower_bounds[i + 1][0] + v2 = lower_bounds[i + 1][1] + + if s1 <= S <= s2: + m = (v2 - v1) / (s2 - s1) + b = v1 - m * s1 + + return m * S + b + + return 0 + + def get_hue_range(self, color_input): + if color_input and color_input.isdigit(): + number = int(color_input) + + if 0 < number < 360: + return [number, number] + + elif color_input and color_input in self.colormap: + color = self.colormap[color_input] + if 'hue_range' in color: + return color['hue_range'] + + else: + return [0, 360] + + def get_saturation_range(self, hue): + return self.get_color_info(hue)['saturation_range'] + + def get_color_info(self, hue): + # Maps red colors to make picking hue easier + if 334 <= hue <= 360: + hue -= 360 + + for color_name, color in self.colormap.items(): + if color['hue_range'] and color['hue_range'][0] <= hue <= color[ + 'hue_range'][1]: + return self.colormap[color_name] + + # this should probably raise an exception + return 'Color not found' + + def random_within(self, r): + return self.random.randint(int(r[0]), int(r[1])) + + @classmethod + def hsv_to_rgb(cls, hsv): + h, s, v = hsv + h = 1 if h == 0 else h + h = 359 if h == 360 else h + + h = float(h) / 360 + s = float(s) / 100 + v = float(v) / 100 + + rgb = colorsys.hsv_to_rgb(h, s, v) + return [int(c * 255) for c in rgb] + + +def rand_color(): + generator = RandomColor() + hue = random.choice(list(COLORMAP.keys())) + color = generator.generate(hue=hue, count=1, format_='rgb')[0] + color = color[color.find('(') + 1:color.find(')')] + color = tuple([int(u) for u in color.split(',')]) + return color diff --git a/modelscope/models/cv/image_to_image_translation/ops/random_mask.py b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py new file mode 100644 index 00000000..bda1ec11 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/random_mask.py @@ -0,0 +1,80 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import cv2 +import numpy as np + +__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop'] + + +def make_irregular_mask(w, + h, + max_angle=4, + max_length=200, + max_width=100, + min_strokes=1, + max_strokes=5, + mode='line'): + # initialize mask + assert mode in ['line', 'circle', 'square'] + mask = np.zeros((h, w), np.float32) + + # draw strokes + num_strokes = np.random.randint(min_strokes, max_strokes + 1) + for i in range(num_strokes): + x1 = np.random.randint(w) + y1 = np.random.randint(h) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_length) + radius = 5 + np.random.randint(max_width) + x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w) + y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h) + if mode == 'line': + cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius) + elif mode == 'circle': + cv2.circle( + mask, (x1, y1), radius=radius, color=1.0, thickness=-1) + elif mode == 'square': + radius = radius // 2 + mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1 + x1, y1 = x2, y2 + return mask + + +def make_rectangle_mask(w, + h, + margin=10, + min_size=30, + max_size=150, + min_strokes=1, + max_strokes=4): + # initialize mask + mask = np.zeros((h, w), np.float32) + + # draw rectangles + num_strokes = np.random.randint(min_strokes, max_strokes + 1) + for i in range(num_strokes): + box_w = np.random.randint(min_size, max_size) + box_h = np.random.randint(min_size, max_size) + x1 = np.random.randint(margin, w - margin - box_w + 1) + y1 = np.random.randint(margin, h - margin - box_h + 1) + mask[y1:y1 + box_h, x1:x1 + box_w] = 1 + return mask + + +def make_uncrop(w, h): + # initialize mask + mask = np.zeros((h, w), np.float32) + + # randomly halve the image + side = np.random.choice([0, 1, 2, 3]) + if side == 0: + mask[:h // 2, :] = 1 + elif side == 1: + mask[h // 2:, :] = 1 + elif side == 2: + mask[:, :w // 2] = 1 + elif side == 2: + mask[:, w // 2:] = 1 + return mask diff --git a/modelscope/models/cv/image_to_image_translation/ops/svd.py b/modelscope/models/cv/image_to_image_translation/ops/svd.py new file mode 100644 index 00000000..96f7e825 --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/svd.py @@ -0,0 +1,153 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +r"""SVD of linear degradation matrices described in the paper + ``Denoising Diffusion Restoration Models.'' + @article{kawar2022denoising, + title={Denoising Diffusion Restoration Models}, + author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song}, + year={2022}, + journal={arXiv preprint arXiv:2201.11793}, + } +""" +import torch + +__all__ = ['SVD', 'IdentitySVD', 'DenoiseSVD', 'ColorizationSVD'] + + +class SVD(object): + r"""SVD decomposition of a matrix, i.e., H = UDV^T. + NOTE: assume that all inputs (i.e., h, x) are of shape [B, CHW]. + """ + + def __init__(self, h): + self.u, self.d, self.v = torch.svd(h, some=False) + self.ut = self.u.t() + self.vt = self.v.t() + self.d[self.d < 1e-3] = 0 + + def U(self, x): + return torch.matmul(self.u, x) + + def Ut(self, x): + return torch.matmul(self.ut, x) + + def V(self, x): + return torch.matmul(self.v, x) + + def Vt(self, x): + return torch.matmul(self.vt, x) + + @property + def D(self): + return self.d + + def H(self, x): + return self.U(self.D * self.Vt(x)[:, :self.D.size(0)]) + + def Ht(self, x): + return self.V(self._pad(self.D * self.Ut(x)[:, :self.D.size(0)])) + + def Hinv(self, x): + r"""Multiplies x by the pseudo inverse of H. + """ + x = self.Ut(x) + x[:, :self.D.size(0)] = x[:, :self.D.size(0)] / self.D + return self.V(self._pad(x)) + + def _pad(self, x): + o = x.new_zeros(x.size(0), self.v.size(0)) + o[:, :self.u.size(0)] = x.view(x.size(0), -1) + return o + + def to(self, *args, **kwargs): + r"""Update the data type and device of UDV matrices. + """ + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + setattr(self, k, v.to(*args, **kwargs)) + return self + + +class IdentitySVD(SVD): + + def __init__(self, c, h, w): + self.d = torch.ones(c * h * w) + + def U(self, x): + return x.clone() + + def Ut(self, x): + return x.clone() + + def V(self, x): + return x.clone() + + def Vt(self, x): + return x.clone() + + def H(self, x): + return x.clone() + + def Ht(self, x): + return x.clone() + + def Hinv(self, x): + return x.clone() + + def _pad(self, x): + return x.clone() + + +class DenoiseSVD(SVD): + + def __init__(self, c, h, w): + self.num_entries = c * h * w + self.d = torch.ones(self.num_entries) + + def U(self, x): + return x.clone() + + def Ut(self, x): + return x.clone() + + def V(self, x): + return x.clone() + + def Vt(self, x): + return x.clone() + + def _pad(self, x): + return x.clone() + + +class ColorizationSVD(SVD): + + def __init__(self, c, h, w): + self.color_dim = c + self.num_pixels = h * w + self.u, self.d, self.v = torch.svd(torch.ones(1, c) / c, some=False) + self.vt = self.v.t() + + def U(self, x): + return self.u[0, 0] * x + + def Ut(self, x): + return self.u[0, 0] * x + + def V(self, x): + return torch.einsum('ij,bjn->bin', self.v, + x.view(x.size(0), self.color_dim, + self.num_pixels)).flatten(1) + + def Vt(self, x): + return torch.einsum('ij,bjn->bin', self.vt, + x.view(x.size(0), self.color_dim, + self.num_pixels)).flatten(1) + + @property + def D(self): + return self.d.repeat(self.num_pixels) + + def _pad(self, x): + o = x.new_zeros(x.size(0), self.color_dim * self.num_pixels) + o[:, :self.num_pixels] = x + return o diff --git a/modelscope/models/cv/image_to_image_translation/ops/utils.py b/modelscope/models/cv/image_to_image_translation/ops/utils.py new file mode 100644 index 00000000..c2aacedc --- /dev/null +++ b/modelscope/models/cv/image_to_image_translation/ops/utils.py @@ -0,0 +1,225 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import base64 +import binascii +import hashlib +import math +import os +import os.path as osp +import zipfile +from io import BytesIO +from multiprocessing.pool import ThreadPool as Pool + +import cv2 +import json +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +from .random_color import rand_color + +__all__ = [ + 'ceil_divide', 'to_device', 'rand_name', 'ema', 'parallel', 'unzip', + 'load_state_dict', 'inverse_indices', 'detect_duplicates', 'md5', 'rope', + 'format_state', 'breakup_grid', 'viz_anno_geometry', 'image_to_base64' +] + +TFS_CLIENT = None + + +def ceil_divide(a, b): + return int(math.ceil(a / b)) + + +def to_device(batch, device, non_blocking=False): + if isinstance(batch, (list, tuple)): + return type(batch)([to_device(u, device, non_blocking) for u in batch]) + elif isinstance(batch, dict): + return type(batch)([(k, to_device(v, device, non_blocking)) + for k, v in batch.items()]) + elif isinstance(batch, torch.Tensor): + return batch.to(device, non_blocking=non_blocking) + return batch + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + +@torch.no_grad() +def ema(net_ema, net, beta, copy_buffer=False): + assert 0.0 <= beta <= 1.0 + for p_ema, p in zip(net_ema.parameters(), net.parameters()): + p_ema.copy_(p.lerp(p_ema, beta)) + if copy_buffer: + for b_ema, b in zip(net_ema.buffers(), net.buffers()): + b_ema.copy_(b) + + +def parallel(func, args_list, num_workers=32, timeout=None): + assert isinstance(args_list, list) + if not isinstance(args_list[0], tuple): + args_list = [(args, ) for args in args_list] + if num_workers == 0: + return [func(*args) for args in args_list] + with Pool(processes=num_workers) as pool: + results = [pool.apply_async(func, args) for args in args_list] + results = [res.get(timeout=timeout) for res in results] + return results + + +def unzip(filename, dst_dir=None): + if dst_dir is None: + dst_dir = osp.dirname(filename) + with zipfile.ZipFile(filename, 'r') as zip_ref: + zip_ref.extractall(dst_dir) + + +def load_state_dict(module, state_dict, drop_prefix=''): + # find incompatible key-vals + src, dst = state_dict, module.state_dict() + if drop_prefix: + src = type(src)([ + (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v) + for k, v in src.items() + ]) + missing = [k for k in dst if k not in src] + unexpected = [k for k in src if k not in dst] + unmatched = [ + k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape + ] + + # keep only compatible key-vals + incompatible = set(unexpected + unmatched) + src = type(src)([(k, v) for k, v in src.items() if k not in incompatible]) + module.load_state_dict(src, strict=False) + + # report incompatible key-vals + if len(missing) != 0: + print(' Missing: ' + ', '.join(missing), flush=True) + if len(unexpected) != 0: + print(' Unexpected: ' + ', '.join(unexpected), flush=True) + if len(unmatched) != 0: + print(' Shape unmatched: ' + ', '.join(unmatched), flush=True) + + +def inverse_indices(indices): + r"""Inverse map of indices. + E.g., if A[indices] == B, then B[inv_indices] == A. + """ + inv_indices = torch.empty_like(indices) + inv_indices[indices] = torch.arange(len(indices)).to(indices) + return inv_indices + + +def detect_duplicates(feats, thr=0.9): + assert feats.ndim == 2 + + # compute simmat + feats = F.normalize(feats, p=2, dim=1) + simmat = torch.mm(feats, feats.T) + simmat.triu_(1) + torch.cuda.synchronize() + + # detect duplicates + mask = ~simmat.gt(thr).any(dim=0) + return torch.where(mask)[0] + + +def md5(filename): + with open(filename, 'rb') as f: + return hashlib.md5(f.read()).hexdigest() + + +def rope(x): + r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C]. + """ + # reshape + shape = x.shape + x = x.view(x.size(0), -1, x.size(-1)) + l, c = x.shape[-2:] + assert c % 2 == 0 + half = c // 2 + + # apply rotary position embedding on x + sinusoid = torch.outer( + torch.arange(l).to(x), + torch.pow(10000, -torch.arange(half).to(x).div(half))) + sin, cos = torch.sin(sinusoid), torch.cos(sinusoid) + x1, x2 = x.chunk(2, dim=-1) + x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + # reshape back + return x.view(shape) + + +def format_state(state, filename=None): + r"""For comparing/aligning state_dict. + """ + content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()]) + if filename: + with open(filename, 'w') as f: + f.write(content) + + +def breakup_grid(img, grid_size): + r"""The inverse operator of ``torchvision.utils.make_grid``. + """ + # params + nrow = img.height // grid_size + ncol = img.width // grid_size + wrow = wcol = 2 # NOTE: use default values here + + # collect grids + grids = [] + for i in range(nrow): + for j in range(ncol): + x1 = j * grid_size + (j + 1) * wcol + y1 = i * grid_size + (i + 1) * wrow + grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size))) + return grids + + +def viz_anno_geometry(item): + r"""Visualize an annotation item from SmartLabel. + """ + if isinstance(item, str): + item = json.loads(item) + assert isinstance(item, dict) + + # read image + orig_img = read_image(item['image_url'], retry=100) + img = cv2.cvtColor(np.asarray(orig_img), cv2.COLOR_BGR2RGB) + + # loop over geometries + for geometry in item['sd_result']['items']: + # params + poly_img = img.copy() + color = rand_color() + points = np.array(geometry['meta']['geometry']).round().astype(int) + line_color = tuple([int(u * 0.55) for u in color]) + + # draw polygons + poly_img = cv2.fillPoly(poly_img, pts=[points], color=color) + poly_img = cv2.polylines( + poly_img, + pts=[points], + isClosed=True, + color=line_color, + thickness=2) + + # mixing + img = np.clip(0.25 * img + 0.75 * poly_img, 0, 255).astype(np.uint8) + return orig_img, Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + +def image_to_base64(img, format='JPEG'): + buffer = BytesIO() + img.save(buffer, format=format) + code = base64.b64encode(buffer.getvalue()).decode('utf-8') + return code diff --git a/modelscope/models/cv/movie_scene_segmentation/__init__.py b/modelscope/models/cv/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..25dcda96 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model import MovieSceneSegmentationModel + from .datasets import MovieSceneSegmentationDataset + +else: + _import_structure = { + 'model': ['MovieSceneSegmentationModel'], + 'datasets': ['MovieSceneSegmentationDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/movie_scene_segmentation/get_model.py b/modelscope/models/cv/movie_scene_segmentation/get_model.py new file mode 100644 index 00000000..5c66fc02 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/get_model.py @@ -0,0 +1,45 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +from .utils.shot_encoder import resnet50 +from .utils.trn import TransformerCRN + + +def get_shot_encoder(cfg): + name = cfg['model']['shot_encoder']['name'] + shot_encoder_args = cfg['model']['shot_encoder'][name] + if name == 'resnet': + depth = shot_encoder_args['depth'] + if depth == 50: + shot_encoder = resnet50(**shot_encoder_args['params'], ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + return shot_encoder + + +def get_contextual_relation_network(cfg): + crn = None + + if cfg['model']['contextual_relation_network']['enabled']: + name = cfg['model']['contextual_relation_network']['name'] + crn_args = cfg['model']['contextual_relation_network']['params'][name] + if name == 'trn': + sampling_name = cfg['model']['loss']['sampling_method']['name'] + crn_args['neighbor_size'] = ( + 2 * cfg['model']['loss']['sampling_method']['params'] + [sampling_name]['neighbor_size']) + crn = TransformerCRN(crn_args) + else: + raise NotImplementedError + + return crn + + +__all__ = ['get_shot_encoder', 'get_contextual_relation_network'] diff --git a/modelscope/models/cv/movie_scene_segmentation/model.py b/modelscope/models/cv/movie_scene_segmentation/model.py new file mode 100644 index 00000000..8117961a --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/model.py @@ -0,0 +1,198 @@ +# The implementation here is modified based on BaSSL, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/kakaobrain/bassl + +import os +import os.path as osp +from typing import Any, Dict + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as TF +from PIL import Image +from shotdetect_scenedetect_lgss import shot_detect + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .get_model import get_contextual_relation_network, get_shot_encoder +from .utils.save_op import get_pred_boundary, pred2scene, scene2video + +logger = get_logger() + + +@MODELS.register_module( + Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) +class MovieSceneSegmentationModel(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params = torch.load(model_path, map_location='cpu') + + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + + def load_param_with_prefix(prefix, model, src_params): + own_state = model.state_dict() + for name, param in own_state.items(): + src_name = prefix + '.' + name + own_state[name] = src_params[src_name] + + model.load_state_dict(own_state) + + self.shot_encoder = get_shot_encoder(self.cfg) + load_param_with_prefix('shot_encoder', self.shot_encoder, params) + self.crn = get_contextual_relation_network(self.cfg) + load_param_with_prefix('crn', self.crn, params) + + crn_name = self.cfg.model.contextual_relation_network.name + hdim = self.cfg.model.contextual_relation_network.params[crn_name][ + 'hidden_size'] + self.head_sbd = nn.Linear(hdim, 2) + load_param_with_prefix('head_sbd', self.head_sbd, params) + + self.test_transform = TF.Compose([ + TF.Resize(size=256, interpolation=Image.BICUBIC), + TF.CenterCrop(224), + TF.ToTensor(), + TF.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + sampling_method = self.cfg.dataset.sampling_method.name + self.neighbor_size = self.cfg.dataset.sampling_method.params[ + sampling_method].neighbor_size + + self.eps = 1e-5 + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]: + data = inputs['video'] + labels = inputs['label'] + outputs = self.shared_step(data) + + loss = F.cross_entropy( + outputs.squeeze(), labels.squeeze(), reduction='none') + lpos = labels == 1 + lneg = labels == 0 + + pp, nn = 1, 1 + wp = (pp / float(pp + nn)) * lpos / (lpos.sum() + self.eps) + wn = (nn / float(pp + nn)) * lneg / (lneg.sum() + self.eps) + w = wp + wn + loss = (w * loss).sum() + + probs = torch.argmax(outputs, dim=1) + + re = dict(pred=probs, loss=loss) + return re + + def inference(self, batch): + logger.info('Begin scene detect ......') + bs = self.cfg.pipeline.batch_size_per_gpu + sids = batch['sid'] + inputs = batch['shot_feat'] + + shot_num = len(sids) + cnt = shot_num // bs + 1 + + infer_sid, infer_pred = [], [] + infer_result = {} + for i in range(cnt): + start = i * bs + end = (i + 1) * bs if (i + 1) * bs < shot_num else shot_num + input_ = inputs[start:end] + sid_ = sids[start:end] + input_ = torch.stack(input_) + outputs = self.shared_step(input_) # shape [b,2] + prob = F.softmax(outputs, dim=1) + infer_sid.extend(sid_.cpu().detach().numpy()) + infer_pred.extend(prob[:, 1].cpu().detach().numpy()) + infer_result.update({'pred': np.stack(infer_pred)}) + infer_result.update({'sid': infer_sid}) + + assert len(infer_result['sid']) == len(sids) + assert len(infer_result['pred']) == len(inputs) + return infer_result + + def shared_step(self, inputs): + with torch.no_grad(): + # infer shot encoder + shot_repr = self.extract_shot_representation(inputs) + assert len(shot_repr.shape) == 3 + + # infer CRN + _, pooled = self.crn(shot_repr, mask=None) + # infer boundary score + pred = self.head_sbd(pooled) + return pred + + def save_shot_feat(self, _repr): + feat = _repr.float().cpu().numpy() + pth = self.cfg.dataset.img_path + '/features' + os.makedirs(pth) + + for idx in range(_repr.shape[0]): + name = f'shot_{str(idx).zfill(4)}.npy' + name = osp.join(pth, name) + np.save(name, feat[idx]) + + def extract_shot_representation(self, + inputs: torch.Tensor) -> torch.Tensor: + """ inputs [b s k c h w] -> output [b d] """ + assert len(inputs.shape) == 6 # (B Shot Keyframe C H W) + b, s, k, c, h, w = inputs.shape + inputs = einops.rearrange(inputs, 'b s k c h w -> (b s) k c h w', s=s) + keyframe_repr = [self.shot_encoder(inputs[:, _k]) for _k in range(k)] + # [k (b s) d] -> [(b s) d] + shot_repr = torch.stack(keyframe_repr).mean(dim=0) + + shot_repr = einops.rearrange(shot_repr, '(b s) d -> b s d', s=s) + return shot_repr + + def postprocess(self, inputs: Dict[str, Any], **kwargs): + logger.info('Generate scene .......') + + pred_dict = inputs['feat'] + thres = self.cfg.pipeline.save_threshold + + anno_dict = get_pred_boundary(pred_dict, thres) + scene_dict_lst, scene_list, shot_num, shot_dict_lst = pred2scene( + self.shot2keyf, anno_dict) + if self.cfg.pipeline.save_split_scene: + re_dir = scene2video(inputs['input_video_pth'], scene_list, thres) + print(f'Split scene video saved to {re_dir}') + return len(scene_list), scene_dict_lst, shot_num, shot_dict_lst + + def preprocess(self, inputs): + logger.info('Begin shot detect......') + shot_keyf_lst, anno, shot2keyf = shot_detect( + inputs, **self.cfg.preprocessor.shot_detect) + logger.info('Shot detect done!') + + single_shot_feat, sid = [], [] + for idx, one_shot in enumerate(shot_keyf_lst): + one_shot = [ + self.test_transform(one_frame) for one_frame in one_shot + ] + one_shot = torch.stack(one_shot, dim=0) + single_shot_feat.append(one_shot) + sid.append(idx) + single_shot_feat = torch.stack(single_shot_feat, dim=0) + shot_feat = [] + for idx, one_shot in enumerate(anno): + shot_idx = int(one_shot['shot_id']) + np.arange( + -self.neighbor_size, self.neighbor_size + 1) + shot_idx = np.clip(shot_idx, 0, one_shot['num_shot']) + _one_shot = single_shot_feat[shot_idx] + shot_feat.append(_one_shot) + self.shot2keyf = shot2keyf + self.anno = anno + return shot_feat, sid diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py b/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py new file mode 100644 index 00000000..e5a929aa --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .save_op import get_pred_boundary, pred2scene, scene2video +from .shot_encoder import resnet50 +from .trn import TransformerCRN diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/head.py b/modelscope/models/cv/movie_scene_segmentation/utils/head.py new file mode 100644 index 00000000..d6468c53 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/head.py @@ -0,0 +1,25 @@ +# The implementation here is modified based on BaSSL, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/kakaobrain/bassl + +import torch.nn as nn +import torch.nn.functional as F + + +class MlpHead(nn.Module): + + def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): + super().__init__() + self.output_dim = output_dim + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.model = nn.Sequential( + nn.Linear(self.input_dim, self.hidden_dim, bias=True), + nn.ReLU(), + nn.Linear(self.hidden_dim, self.output_dim, bias=True), + ) + + def forward(self, x): + # x shape: [b t d] where t means the number of views + x = self.model(x) + return F.normalize(x, dim=-1) diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py new file mode 100644 index 00000000..3339e1a3 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/save_op.py @@ -0,0 +1,127 @@ +# The implementation here is modified based on SceneSeg, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/AnyiRao/SceneSeg +import os +import os.path as osp +import subprocess + +import cv2 +import numpy as np +from tqdm import tqdm + + +def get_pred_boundary(pred_dict, threshold=0.5): + pred = pred_dict['pred'] + tmp = (pred > threshold).astype(np.int32) + anno_dict = {} + for idx in range(len(tmp)): + anno_dict.update({str(pred_dict['sid'][idx]).zfill(4): int(tmp[idx])}) + return anno_dict + + +def pred2scene(shot2keyf, anno_dict): + scene_list, pair_list = get_demo_scene_list(shot2keyf, anno_dict) + + scene_dict_lst = [] + shot_num = len(shot2keyf) + shot_dict_lst = [] + for item in shot2keyf: + tmp = item.split(' ') + shot_dict_lst.append({ + 'frame': [tmp[0], tmp[1]], + 'timestamps': [tmp[-2], tmp[-1]] + }) + assert len(scene_list) == len(pair_list) + for scene_ind, scene_item in enumerate(scene_list): + scene_dict_lst.append({ + 'shot': pair_list[scene_ind], + 'frame': scene_item[0], + 'timestamps': scene_item[1] + }) + + return scene_dict_lst, scene_list, shot_num, shot_dict_lst + + +def scene2video(source_movie_fn, scene_list, thres): + + vcap = cv2.VideoCapture(source_movie_fn) + fps = vcap.get(cv2.CAP_PROP_FPS) # video.fps + out_video_dir_fn = os.path.join(os.getcwd(), + f'pred_result/scene_video_{thres}') + os.makedirs(out_video_dir_fn, exist_ok=True) + + for scene_ind, scene_item in tqdm(enumerate(scene_list)): + scene = str(scene_ind).zfill(4) + start_frame = int(scene_item[0][0]) + end_frame = int(scene_item[0][1]) + start_time, end_time = start_frame / fps, end_frame / fps + duration_time = end_time - start_time + out_video_fn = os.path.join(out_video_dir_fn, + 'scene_{}.mp4'.format(scene)) + if os.path.exists(out_video_fn): + continue + call_list = ['ffmpeg'] + call_list += ['-v', 'quiet'] + call_list += [ + '-y', '-ss', + str(start_time), '-t', + str(duration_time), '-i', source_movie_fn + ] + call_list += ['-map_chapters', '-1'] + call_list += [out_video_fn] + subprocess.call(call_list) + return osp.join(os.getcwd(), 'pred_result') + + +def get_demo_scene_list(shot2keyf, anno_dict): + pair_list = get_pair_list(anno_dict) + + scene_list = [] + for pair in pair_list: + start_shot, end_shot = int(pair[0]), int(pair[-1]) + start_frame = shot2keyf[start_shot].split(' ')[0] + end_frame = shot2keyf[end_shot].split(' ')[1] + start_timestamp = shot2keyf[start_shot].split(' ')[-2] + end_timestamp = shot2keyf[end_shot].split(' ')[-1] + scene_list.append([[start_frame, end_frame], + [start_timestamp, end_timestamp]]) + return scene_list, pair_list + + +def get_pair_list(anno_dict): + sort_anno_dict_key = sorted(anno_dict.keys()) + tmp = 0 + tmp_list = [] + tmp_label_list = [] + anno_list = [] + anno_label_list = [] + for key in sort_anno_dict_key: + value = anno_dict.get(key) + tmp += value + tmp_list.append(key) + tmp_label_list.append(value) + if tmp == 1: + anno_list.append(tmp_list) + anno_label_list.append(tmp_label_list) + tmp = 0 + tmp_list = [] + tmp_label_list = [] + continue + if key == sort_anno_dict_key[-1]: + if len(tmp_list) > 0: + anno_list.append(tmp_list) + anno_label_list.append(tmp_label_list) + if len(anno_list) == 0: + return None + while [] in anno_list: + anno_list.remove([]) + tmp_anno_list = [anno_list[0]] + pair_list = [] + for ind in range(len(anno_list) - 1): + cont_count = int(anno_list[ind + 1][0]) - int(anno_list[ind][-1]) + if cont_count > 1: + pair_list.extend(tmp_anno_list) + tmp_anno_list = [anno_list[ind + 1]] + continue + tmp_anno_list.append(anno_list[ind + 1]) + pair_list.extend(tmp_anno_list) + return pair_list diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py b/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py new file mode 100644 index 00000000..11d20b13 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/shot_encoder.py @@ -0,0 +1,329 @@ +# The implementation is adopted from torchvision + +from typing import Any, Callable, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor + + +def conv3x3(in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + in_channel_dim: int = 3, + zero_init_residual: bool = False, + use_last_block_grid: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.use_last_block_grid = use_last_block_grid + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + in_channel_dim, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, + 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, + 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + )) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + )) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor, grid: bool, level: List, both: bool, + grid_only: bool) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + if grid: + x_grid = [] + + if 3 in level: + x_grid.append(x.detach().clone()) + if not both and len(level) == 1: + return x_grid + + x = self.layer4(x) + + if 4 in level: + x_grid.append(x.detach().clone()) + if not both and len(level) == 1: + return x_grid + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + if not grid or len(level) == 0: + return x + + if grid_only: + return x_grid + + if both: + return x, x_grid + + return x + + def forward( + self, + x: Tensor, + grid: bool = False, + level: List = [], + both: bool = False, + grid_only: bool = False, + ) -> Tensor: + return self._forward_impl(x, grid, level, both, grid_only) + + +def resnet50(**kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + """ + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) diff --git a/modelscope/models/cv/movie_scene_segmentation/utils/trn.py b/modelscope/models/cv/movie_scene_segmentation/utils/trn.py new file mode 100644 index 00000000..769e9ee4 --- /dev/null +++ b/modelscope/models/cv/movie_scene_segmentation/utils/trn.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +import torch +import torch.nn as nn +from transformers.models.bert.modeling_bert import BertEncoder + + +class ShotEmbedding(nn.Module): + + def __init__(self, cfg): + super().__init__() + + nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls + self.shot_embedding = nn.Linear(cfg.input_dim, cfg.hidden_size) + self.position_embedding = nn.Embedding(nn_size, cfg.hidden_size) + self.mask_embedding = nn.Embedding(2, cfg.input_dim, padding_idx=0) + + # tf naming convention for layer norm + self.LayerNorm = nn.LayerNorm(cfg.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(cfg.hidden_dropout_prob) + + self.register_buffer('pos_ids', + torch.arange(nn_size, dtype=torch.long)) + + def forward( + self, + shot_emb: torch.Tensor, + mask: torch.Tensor = None, + pos_ids: torch.Tensor = None, + ) -> torch.Tensor: + + assert len(shot_emb.size()) == 3 + + if pos_ids is None: + pos_ids = self.pos_ids + + # this for mask embedding (un-masked ones remain unchanged) + if mask is not None: + self.mask_embedding.weight.data[0, :].fill_(0) + mask_emb = self.mask_embedding(mask.long()) + shot_emb = (shot_emb * (1 - mask).float()[:, :, None]) + mask_emb + + # we set [CLS] token to averaged feature + cls_emb = shot_emb.mean(dim=1) + + # embedding shots + shot_emb = torch.cat([cls_emb[:, None, :], shot_emb], dim=1) + shot_emb = self.shot_embedding(shot_emb) + pos_emb = self.position_embedding(pos_ids) + embeddings = shot_emb + pos_emb[None, :] + embeddings = self.dropout(self.LayerNorm(embeddings)) + return embeddings + + +class TransformerCRN(nn.Module): + + def __init__(self, cfg): + super().__init__() + + self.pooling_method = cfg.pooling_method + self.shot_embedding = ShotEmbedding(cfg) + self.encoder = BertEncoder(cfg) + + nn_size = cfg.neighbor_size + 2 # +1 for center shot, +1 for cls + self.register_buffer( + 'attention_mask', + self._get_extended_attention_mask( + torch.ones((1, nn_size)).float()), + ) + + def forward( + self, + shot: torch.Tensor, + mask: torch.Tensor = None, + pos_ids: torch.Tensor = None, + pooling_method: str = None, + ): + if self.attention_mask.shape[1] != (shot.shape[1] + 1): + n_shot = shot.shape[1] + 1 # +1 for CLS token + attention_mask = self._get_extended_attention_mask( + torch.ones((1, n_shot), dtype=torch.float, device=shot.device)) + else: + attention_mask = self.attention_mask + + shot_emb = self.shot_embedding(shot, mask=mask, pos_ids=pos_ids) + encoded_emb = self.encoder( + shot_emb, attention_mask=attention_mask).last_hidden_state + + return encoded_emb, self.pooler( + encoded_emb, pooling_method=pooling_method) + + def pooler(self, sequence_output, pooling_method=None): + if pooling_method is None: + pooling_method = self.pooling_method + + if pooling_method == 'cls': + return sequence_output[:, 0, :] + elif pooling_method == 'avg': + return sequence_output[:, 1:].mean(dim=1) + elif pooling_method == 'max': + return sequence_output[:, 1:].max(dim=1)[0] + elif pooling_method == 'center': + cidx = sequence_output.shape[1] // 2 + return sequence_output[:, cidx, :] + else: + raise ValueError + + def _get_extended_attention_mask(self, attention_mask): + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f'Wrong shape for attention_mask (shape {attention_mask.shape})' + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask diff --git a/modelscope/models/cv/object_detection/__init__.py b/modelscope/models/cv/object_detection/__init__.py new file mode 100644 index 00000000..0c782d7b --- /dev/null +++ b/modelscope/models/cv/object_detection/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .mmdet_model import DetectionModel + from .yolox_pai import YOLOX + +else: + _import_structure = { + 'mmdet_model': ['DetectionModel'], + 'yolox_pai': ['YOLOX'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/object_detection/mmdet_model.py b/modelscope/models/cv/object_detection/mmdet_model.py new file mode 100644 index 00000000..485d440a --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_model.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .mmdet_ms.backbones import ViT +from .mmdet_ms.dense_heads import RPNNHead +from .mmdet_ms.necks import FPNF +from .mmdet_ms.roi_heads import FCNMaskNHead, Shared4Conv1FCBBoxNHead + + +@MODELS.register_module(Tasks.human_detection, module_name=Models.detection) +@MODELS.register_module( + Tasks.image_object_detection, module_name=Models.detection) +class DetectionModel(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + from mmcv.runner import load_checkpoint + from mmdet.datasets import replace_ImageToTensor + from mmdet.datasets.pipelines import Compose + from mmdet.models import build_detector + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + config_path = osp.join(model_dir, 'mmcv_config.py') + config = Config.from_file(config_path) + config.model.pretrained = None + self.model = build_detector(config.model) + + checkpoint = load_checkpoint( + self.model, model_path, map_location='cpu') + self.class_names = checkpoint['meta']['CLASSES'] + config.test_pipeline[0].type = 'LoadImageFromWebcam' + self.transform_input = Compose( + replace_ImageToTensor(config.test_pipeline)) + self.model.cfg = config + self.model.eval() + self.score_thr = config.score_thr + + def inference(self, data): + """data is dict,contain img and img_metas,follow with mmdet.""" + + with torch.no_grad(): + results = self.model(return_loss=False, rescale=True, **data) + return results + + def preprocess(self, image): + """image is numpy return is dict contain img and img_metas,follow with mmdet.""" + + from mmcv.parallel import collate, scatter + data = dict(img=image) + data = self.transform_input(data) + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + data['img'] = [img.data[0] for img in data['img']] + + if next(self.model.parameters()).is_cuda: + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def postprocess(self, inputs): + + if isinstance(inputs[0], tuple): + bbox_result, _ = inputs[0] + else: + bbox_result, _ = inputs[0], None + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) + + bbox_result = np.vstack(bbox_result) + scores = bbox_result[:, -1] + inds = scores > self.score_thr + if np.sum(np.array(inds).astype('int')) == 0: + return None, None, None + bboxes = bbox_result[inds, :] + labels = labels[inds] + scores = np.around(bboxes[:, 4], 6) + bboxes = (bboxes[:, 0:4]).astype(int) + labels = [self.class_names[i_label] for i_label in labels] + return bboxes, scores, labels diff --git a/modelscope/models/cv/object_detection/mmdet_ms/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/__init__.py new file mode 100644 index 00000000..3a1fdd0b --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/__init__.py @@ -0,0 +1,6 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .backbones import ViT +from .dense_heads import AnchorNHead, RPNNHead +from .necks import FPNF +from .utils import ConvModule_Norm, load_checkpoint diff --git a/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py new file mode 100644 index 00000000..c0697d48 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/backbones/__init__.py @@ -0,0 +1,5 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .vit import ViT + +__all__ = ['ViT'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py b/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py new file mode 100644 index 00000000..53bda358 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/backbones/vit.py @@ -0,0 +1,626 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm, mmseg, setr, xcit and swin code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/fudan-zvg/SETR +# https://github.com/facebookresearch/xcit/ +# https://github.com/microsoft/Swin-Transformer +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +from ..utils import load_checkpoint + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + self.window_size = window_size + q_size = window_size[0] + rel_sp_dim = 2 * q_size - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W, rel_pos_bias=None): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = calc_rel_pos_spatial(attn, q, self.window_size, + self.window_size, self.rel_pos_h, + self.rel_pos_w) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def calc_rel_pos_spatial( + attn, + q, + q_shape, + k_shape, + rel_pos_h, + rel_pos_w, +): + """ + Spatial Relative Positional Embeddings. + """ + sp_idx = 0 + q_h, q_w = q_shape + k_h, k_w = k_shape + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = ( + torch.arange(q_h)[:, None] * q_h_ratio + - torch.arange(k_h)[None, :] * k_h_ratio) + dist_h += (k_h - 1) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = ( + torch.arange(q_w)[:, None] * q_w_ratio + - torch.arange(k_w)[None, :] * k_w_ratio) + dist_w += (k_w - 1) * k_w_ratio + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + B, n_head, q_N, dim = q.shape + r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) + rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) + rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) + attn[:, :, sp_idx:, sp_idx:] = ( + attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view( + B, -1, q_h * q_w, k_h * k_w) + + return attn + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + attn_head_dim=None): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + q_size = window_size[0] + rel_sp_dim = 2 * q_size - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + x = x.reshape(B_, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.window_size[1] + - W % self.window_size[1]) % self.window_size[1] + pad_b = (self.window_size[0] + - H % self.window_size[0]) % self.window_size[0] + + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition( + x, self.window_size[0]) # nW*B, window_size, window_size, C + x = x.view(-1, self.window_size[1] * self.window_size[0], + C) # nW*B, window_size*window_size, C + B_w = x.shape[0] + N_w = x.shape[1] + qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = calc_rel_pos_spatial(attn, q, self.window_size, + self.window_size, self.rel_pos_h, + self.rel_pos_w) + + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x.view(-1, self.window_size[1], self.window_size[0], C) + x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B_, H * W, C) + + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + window=False): + super().__init__() + self.norm1 = norm_layer(dim) + if not window: + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + else: + self.attn = WindowAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, H, W): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * ( + img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, + backbone, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone( + torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class Norm2d(nn.Module): + + def __init__(self, embed_dim): + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=1e-6) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + + +@BACKBONES.register_module() +class ViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=80, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + hybrid_backbone=None, + norm_layer=None, + init_values=None, + use_checkpoint=False, + use_abs_pos_emb=False, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + out_indices=[11], + interval=3, + pretrained=None): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.out_indices = out_indices + + if use_abs_pos_emb: + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim)) + else: + self.pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=(14, 14) if + ((i + 1) % interval != 0) else self.patch_embed.patch_shape, + window=((i + 1) % interval != 0)) for i in range(depth) + ]) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self.norm = norm_layer(embed_dim) + + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + Norm2d(embed_dim), + nn.GELU(), + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, + stride=2), ) + + self.fpn3 = nn.Identity() + + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.apply(self._init_weights) + self.fix_init_weight() + self.pretrained = pretrained + + def fix_init_weight(self): + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + pretrained = pretrained or self.pretrained + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + print(f'load from {pretrained}') + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + for i, blk in enumerate(self.blocks): + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, Hp, Wp) + + x = self.norm(x) + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp) + + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(ops)): + features.append(ops[i](xp)) + + return tuple(features) + + def forward(self, x): + + x = self.forward_features(x) + + return x diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py new file mode 100644 index 00000000..0d34e996 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/__init__.py @@ -0,0 +1,6 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .anchor_head import AnchorNHead +from .rpn_head import RPNNHead + +__all__ = ['AnchorNHead', 'RPNNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py new file mode 100644 index 00000000..d4ea5282 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/anchor_head.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from mmdet.models.builder import HEADS +from mmdet.models.dense_heads import AnchorHead + + +@HEADS.register_module() +class AnchorNHead(AnchorHead): + """Anchor-based head (RPN, RetinaNet, SSD, etc.). + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels. Used in child classes. + anchor_generator (dict): Config dict for anchor generator + bbox_coder (dict): Config of bounding box coder. + reg_decoded_bbox (bool): If true, the regression loss would be + applied directly on decoded bounding boxes, converting both + the predicted boxes and regression targets to absolute + coordinates format. Default False. It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + loss_cls (dict): Config of classification loss. + loss_bbox (dict): Config of localization loss. + train_cfg (dict): Training config of anchor head. + test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ # noqa: W605 + + def __init__(self, + num_classes, + in_channels, + feat_channels, + anchor_generator=None, + bbox_coder=None, + reg_decoded_bbox=False, + loss_cls=None, + loss_bbox=None, + train_cfg=None, + test_cfg=None, + norm_cfg=None, + init_cfg=None): + self.norm_cfg = norm_cfg + super(AnchorNHead, + self).__init__(num_classes, in_channels, feat_channels, + anchor_generator, bbox_coder, reg_decoded_bbox, + loss_cls, loss_bbox, train_cfg, test_cfg, + init_cfg) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py new file mode 100644 index 00000000..8e934a5c --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/dense_heads/rpn_head.py @@ -0,0 +1,269 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import batched_nms +from mmdet.models.builder import HEADS + +from ..utils import ConvModule_Norm +from .anchor_head import AnchorNHead + + +@HEADS.register_module() +class RPNNHead(AnchorNHead): + """RPN head. + + Args: + in_channels (int): Number of channels in the input feature map. + init_cfg (dict or list[dict], optional): Initialization config dict. + num_convs (int): Number of convolution layers in the head. Default 1. + """ # noqa: W605 + + def __init__(self, + in_channels, + init_cfg=dict(type='Normal', layer='Conv2d', std=0.01), + num_convs=1, + **kwargs): + self.num_convs = num_convs + super(RPNNHead, self).__init__( + 1, in_channels, init_cfg=init_cfg, **kwargs) + + def _init_layers(self): + """Initialize layers of the head.""" + if self.num_convs > 1: + rpn_convs = [] + for i in range(self.num_convs): + if i == 0: + in_channels = self.in_channels + else: + in_channels = self.feat_channels + # use ``inplace=False`` to avoid error: one of the variables + # needed for gradient computation has been modified by an + # inplace operation. + rpn_convs.append( + ConvModule_Norm( + in_channels, + self.feat_channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + inplace=False)) + self.rpn_conv = nn.Sequential(*rpn_convs) + else: + self.rpn_conv = nn.Conv2d( + self.in_channels, self.feat_channels, 3, padding=1) + self.rpn_cls = nn.Conv2d(self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 1) + self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4, + 1) + + def forward_single(self, x): + """Forward feature map of a single scale level.""" + x = self.rpn_conv(x) + x = F.relu(x, inplace=True) + rpn_cls_score = self.rpn_cls(x) + rpn_bbox_pred = self.rpn_reg(x) + return rpn_cls_score, rpn_bbox_pred + + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + losses = super(RPNNHead, self).loss( + cls_scores, + bbox_preds, + gt_bboxes, + None, + img_metas, + gt_bboxes_ignore=gt_bboxes_ignore) + return dict( + loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox']) + + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + mlvl_anchors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_anchors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has + shape (num_anchors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RPN head does not need this value. + mlvl_anchors (list[Tensor]): Anchors of all scale level + each item has shape (num_anchors, 4). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + + # bboxes from different level should be independent during NMS, + # level_ids are used as labels for batched NMS to separate them + level_ids = [] + mlvl_scores = [] + mlvl_bbox_preds = [] + mlvl_valid_anchors = [] + nms_pre = cfg.get('nms_pre', -1) + for level_idx in range(len(cls_score_list)): + rpn_cls_score = cls_score_list[level_idx] + rpn_bbox_pred = bbox_pred_list[level_idx] + assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] + rpn_cls_score = rpn_cls_score.permute(1, 2, 0) + if self.use_sigmoid_cls: + rpn_cls_score = rpn_cls_score.reshape(-1) + scores = rpn_cls_score.sigmoid() + else: + rpn_cls_score = rpn_cls_score.reshape(-1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. + scores = rpn_cls_score.softmax(dim=1)[:, 0] + rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) + + anchors = mlvl_anchors[level_idx] + if 0 < nms_pre < scores.shape[0]: + # sort is faster than topk + # _, topk_inds = scores.topk(cfg.nms_pre) + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] + rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] + anchors = anchors[topk_inds, :] + + mlvl_scores.append(scores) + mlvl_bbox_preds.append(rpn_bbox_pred) + mlvl_valid_anchors.append(anchors) + level_ids.append( + scores.new_full((scores.size(0), ), + level_idx, + dtype=torch.long)) + + return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds, + mlvl_valid_anchors, level_ids, cfg, + img_shape) + + def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors, + level_ids, cfg, img_shape, **kwargs): + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. + + Args: + mlvl_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_bboxes, num_class). + mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale + levels of a single image, each item has shape (num_bboxes, 4). + mlvl_valid_anchors (list[Tensor]): Anchors of all scale level + each item has shape (num_bboxes, 4). + level_ids (list[Tensor]): Indexes from all scale levels of a + single image, each item has shape (num_bboxes, ). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + img_shape (tuple(int)): Shape of current image. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ + scores = torch.cat(mlvl_scores) + anchors = torch.cat(mlvl_valid_anchors) + rpn_bbox_pred = torch.cat(mlvl_bboxes) + proposals = self.bbox_coder.decode( + anchors, rpn_bbox_pred, max_shape=img_shape) + ids = torch.cat(level_ids) + + if cfg.min_bbox_size >= 0: + w = proposals[:, 2] - proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + proposals = proposals[valid_mask] + scores = scores[valid_mask] + ids = ids[valid_mask] + + if proposals.numel() > 0: + dets, _ = batched_nms(proposals, scores, ids, cfg.nms) + else: + return proposals.new_zeros(0, 5) + + return dets[:cfg.max_per_img] + + def onnx_export(self, x, img_metas): + """Test without augmentation. + + Args: + x (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + img_metas (list[dict]): Meta info of each image. + Returns: + Tensor: dets of shape [N, num_det, 5]. + """ + cls_scores, bbox_preds = self(x) + + assert len(cls_scores) == len(bbox_preds) + + batch_bboxes, batch_scores = super(RPNNHead, self).onnx_export( + cls_scores, bbox_preds, img_metas=img_metas, with_nms=False) + # Use ONNX::NonMaxSuppression in deployment + from mmdet.core.export import add_dummy_nms_for_onnx + cfg = copy.deepcopy(self.test_cfg) + score_threshold = cfg.nms.get('score_thr', 0.0) + nms_pre = cfg.get('deploy_nms_pre', -1) + # Different from the normal forward doing NMS level by level, + # we do NMS across all levels when exporting ONNX. + dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores, + cfg.max_per_img, + cfg.nms.iou_threshold, + score_threshold, nms_pre, + cfg.max_per_img) + return dets diff --git a/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py new file mode 100644 index 00000000..d164987e --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/necks/__init__.py @@ -0,0 +1,5 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .fpn import FPNF + +__all__ = ['FPNF'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py b/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py new file mode 100644 index 00000000..5f8648ce --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/necks/fpn.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +import torch.nn as nn +import torch.nn.functional as F +from mmcv.runner import BaseModule, auto_fp16 +from mmdet.models.builder import NECKS + +from ..utils import ConvModule_Norm + + +@NECKS.register_module() +class FPNF(BaseModule): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + use_residual=True, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(FPNF, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + self.use_residual = use_residual + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule_Norm( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule_Norm( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule_Norm( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + @auto_fp16() + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + if self.use_residual: + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + + return tuple(outs) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py new file mode 100644 index 00000000..658280df --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/__init__.py @@ -0,0 +1,10 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .bbox_heads import (ConvFCBBoxNHead, Shared2FCBBoxNHead, + Shared4Conv1FCBBoxNHead) +from .mask_heads import FCNMaskNHead + +__all__ = [ + 'ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead', + 'FCNMaskNHead' +] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py new file mode 100644 index 00000000..61d93503 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/__init__.py @@ -0,0 +1,6 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .convfc_bbox_head import (ConvFCBBoxNHead, Shared2FCBBoxNHead, + Shared4Conv1FCBBoxNHead) + +__all__ = ['ConvFCBBoxNHead', 'Shared2FCBBoxNHead', 'Shared4Conv1FCBBoxNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py new file mode 100644 index 00000000..726329a1 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/bbox_heads/convfc_bbox_head.py @@ -0,0 +1,230 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +import torch.nn as nn +from mmdet.models.builder import HEADS +from mmdet.models.roi_heads.bbox_heads.bbox_head import BBoxHead +from mmdet.models.utils import build_linear_layer + +from ...utils import ConvModule_Norm + + +@HEADS.register_module() +class ConvFCBBoxNHead(BBoxHead): + r"""More general bbox head, with shared conv and fc layers and two optional + separated branches. + + .. code-block:: none + + /-> cls convs -> cls fcs -> cls + shared convs -> shared fcs + \-> reg convs -> reg fcs -> reg + """ # noqa: W605 + + def __init__(self, + num_shared_convs=0, + num_shared_fcs=0, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + conv_out_channels=256, + fc_out_channels=1024, + conv_cfg=None, + norm_cfg=None, + init_cfg=None, + *args, + **kwargs): + super(ConvFCBBoxNHead, self).__init__( + *args, init_cfg=init_cfg, **kwargs) + assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs + + num_reg_convs + num_reg_fcs > 0) + if num_cls_convs > 0 or num_reg_convs > 0: + assert num_shared_fcs == 0 + if not self.with_cls: + assert num_cls_convs == 0 and num_cls_fcs == 0 + if not self.with_reg: + assert num_reg_convs == 0 and num_reg_fcs == 0 + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_cls_convs = num_cls_convs + self.num_cls_fcs = num_cls_fcs + self.num_reg_convs = num_reg_convs + self.num_reg_fcs = num_reg_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # add shared convs and fcs + self.shared_convs, self.shared_fcs, last_layer_dim = \ + self._add_conv_fc_branch( + self.num_shared_convs, self.num_shared_fcs, self.in_channels, + True) + self.shared_out_channels = last_layer_dim + + # add cls specific branch + self.cls_convs, self.cls_fcs, self.cls_last_dim = \ + self._add_conv_fc_branch( + self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels) + + # add reg specific branch + self.reg_convs, self.reg_fcs, self.reg_last_dim = \ + self._add_conv_fc_branch( + self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels) + + if self.num_shared_fcs == 0 and not self.with_avg_pool: + if self.num_cls_fcs == 0: + self.cls_last_dim *= self.roi_feat_area + if self.num_reg_fcs == 0: + self.reg_last_dim *= self.roi_feat_area + + self.relu = nn.ReLU(inplace=True) + # reconstruct fc_cls and fc_reg since input channels are changed + if self.with_cls: + if self.custom_cls_channels: + cls_channels = self.loss_cls.get_cls_channels(self.num_classes) + else: + cls_channels = self.num_classes + 1 + self.fc_cls = build_linear_layer( + self.cls_predictor_cfg, + in_features=self.cls_last_dim, + out_features=cls_channels) + if self.with_reg: + out_dim_reg = (4 if self.reg_class_agnostic else 4 + * self.num_classes) + self.fc_reg = build_linear_layer( + self.reg_predictor_cfg, + in_features=self.reg_last_dim, + out_features=out_dim_reg) + + if init_cfg is None: + # when init_cfg is None, + # It has been set to + # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))], + # [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))] + # after `super(ConvFCBBoxHead, self).__init__()` + # we only need to append additional configuration + # for `shared_fcs`, `cls_fcs` and `reg_fcs` + self.init_cfg += [ + dict( + type='Xavier', + override=[ + dict(name='shared_fcs'), + dict(name='cls_fcs'), + dict(name='reg_fcs') + ]) + ] + + def _add_conv_fc_branch(self, + num_branch_convs, + num_branch_fcs, + in_channels, + is_shared=False): + """Add shared or separable branch. + + convs -> avg pool (optional) -> fcs + """ + last_layer_dim = in_channels + # add branch specific conv layers + branch_convs = nn.ModuleList() + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_channels = ( + last_layer_dim if i == 0 else self.conv_out_channels) + branch_convs.append( + ConvModule_Norm( + conv_in_channels, + self.conv_out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + last_layer_dim = self.conv_out_channels + # add branch specific fc layers + branch_fcs = nn.ModuleList() + if num_branch_fcs > 0: + # for shared branch, only consider self.with_avg_pool + # for separated branches, also consider self.num_shared_fcs + if (is_shared + or self.num_shared_fcs == 0) and not self.with_avg_pool: + last_layer_dim *= self.roi_feat_area + for i in range(num_branch_fcs): + fc_in_channels = ( + last_layer_dim if i == 0 else self.fc_out_channels) + branch_fcs.append( + nn.Linear(fc_in_channels, self.fc_out_channels)) + last_layer_dim = self.fc_out_channels + return branch_convs, branch_fcs, last_layer_dim + + def forward(self, x): + # shared part + if self.num_shared_convs > 0: + for conv in self.shared_convs: + x = conv(x) + + if self.num_shared_fcs > 0: + if self.with_avg_pool: + x = self.avg_pool(x) + + x = x.flatten(1) + + for fc in self.shared_fcs: + x = self.relu(fc(x)) + # separate branches + x_cls = x + x_reg = x + + for conv in self.cls_convs: + x_cls = conv(x_cls) + if x_cls.dim() > 2: + if self.with_avg_pool: + x_cls = self.avg_pool(x_cls) + x_cls = x_cls.flatten(1) + for fc in self.cls_fcs: + x_cls = self.relu(fc(x_cls)) + + for conv in self.reg_convs: + x_reg = conv(x_reg) + if x_reg.dim() > 2: + if self.with_avg_pool: + x_reg = self.avg_pool(x_reg) + x_reg = x_reg.flatten(1) + for fc in self.reg_fcs: + x_reg = self.relu(fc(x_reg)) + + cls_score = self.fc_cls(x_cls) if self.with_cls else None + bbox_pred = self.fc_reg(x_reg) if self.with_reg else None + return cls_score, bbox_pred + + +@HEADS.register_module() +class Shared2FCBBoxNHead(ConvFCBBoxNHead): + + def __init__(self, fc_out_channels=1024, *args, **kwargs): + super(Shared2FCBBoxNHead, self).__init__( + num_shared_convs=0, + num_shared_fcs=2, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) + + +@HEADS.register_module() +class Shared4Conv1FCBBoxNHead(ConvFCBBoxNHead): + + def __init__(self, fc_out_channels=1024, *args, **kwargs): + super(Shared4Conv1FCBBoxNHead, self).__init__( + num_shared_convs=4, + num_shared_fcs=1, + num_cls_convs=0, + num_cls_fcs=0, + num_reg_convs=0, + num_reg_fcs=0, + fc_out_channels=fc_out_channels, + *args, + **kwargs) diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py new file mode 100644 index 00000000..043e62a0 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/__init__.py @@ -0,0 +1,5 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .fcn_mask_head import FCNMaskNHead + +__all__ = ['FCNMaskNHead'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py new file mode 100644 index 00000000..335f6b8f --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/roi_heads/mask_heads/fcn_mask_head.py @@ -0,0 +1,415 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer +from mmcv.ops.carafe import CARAFEPack +from mmcv.runner import BaseModule, ModuleList, auto_fp16, force_fp32 +from mmdet.core import mask_target +from mmdet.models.builder import HEADS, build_loss +from torch.nn.modules.utils import _pair + +from ...utils import ConvModule_Norm + +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +@HEADS.register_module() +class FCNMaskNHead(BaseModule): + + def __init__(self, + num_convs=4, + roi_feat_size=14, + in_channels=256, + conv_kernel_size=3, + conv_out_channels=256, + num_classes=80, + class_agnostic=False, + upsample_cfg=dict(type='deconv', scale_factor=2), + conv_cfg=None, + norm_cfg=None, + predictor_cfg=dict(type='Conv'), + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + init_cfg=None): + assert init_cfg is None, 'To prevent abnormal initialization ' \ + 'behavior, init_cfg is not allowed to be set' + super(FCNMaskNHead, self).__init__(init_cfg) + self.upsample_cfg = upsample_cfg.copy() + if self.upsample_cfg['type'] not in [ + None, 'deconv', 'nearest', 'bilinear', 'carafe' + ]: + raise ValueError( + f'Invalid upsample method {self.upsample_cfg["type"]}, ' + 'accepted methods are "deconv", "nearest", "bilinear", ' + '"carafe"') + self.num_convs = num_convs + # WARN: roi_feat_size is reserved and not used + self.roi_feat_size = _pair(roi_feat_size) + self.in_channels = in_channels + self.conv_kernel_size = conv_kernel_size + self.conv_out_channels = conv_out_channels + self.upsample_method = self.upsample_cfg.get('type') + self.scale_factor = self.upsample_cfg.pop('scale_factor', None) + self.num_classes = num_classes + self.class_agnostic = class_agnostic + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.predictor_cfg = predictor_cfg + self.fp16_enabled = False + self.loss_mask = build_loss(loss_mask) + + self.convs = ModuleList() + for i in range(self.num_convs): + in_channels = ( + self.in_channels if i == 0 else self.conv_out_channels) + padding = (self.conv_kernel_size - 1) // 2 + self.convs.append( + ConvModule_Norm( + in_channels, + self.conv_out_channels, + self.conv_kernel_size, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + upsample_in_channels = ( + self.conv_out_channels if self.num_convs > 0 else in_channels) + upsample_cfg_ = self.upsample_cfg.copy() + if self.upsample_method is None: + self.upsample = None + elif self.upsample_method == 'deconv': + upsample_cfg_.update( + in_channels=upsample_in_channels, + out_channels=self.conv_out_channels, + kernel_size=self.scale_factor, + stride=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + elif self.upsample_method == 'carafe': + upsample_cfg_.update( + channels=upsample_in_channels, scale_factor=self.scale_factor) + self.upsample = build_upsample_layer(upsample_cfg_) + else: + # suppress warnings + align_corners = (None + if self.upsample_method == 'nearest' else False) + upsample_cfg_.update( + scale_factor=self.scale_factor, + mode=self.upsample_method, + align_corners=align_corners) + self.upsample = build_upsample_layer(upsample_cfg_) + + out_channels = 1 if self.class_agnostic else self.num_classes + logits_in_channel = ( + self.conv_out_channels + if self.upsample_method == 'deconv' else upsample_in_channels) + self.conv_logits = build_conv_layer(self.predictor_cfg, + logits_in_channel, out_channels, 1) + self.relu = nn.ReLU(inplace=True) + self.debug_imgs = None + + def init_weights(self): + super(FCNMaskNHead, self).init_weights() + for m in [self.upsample, self.conv_logits]: + if m is None: + continue + elif isinstance(m, CARAFEPack): + m.init_weights() + elif hasattr(m, 'weight') and hasattr(m, 'bias'): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + nn.init.constant_(m.bias, 0) + + @auto_fp16() + def forward(self, x): + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_pred = self.conv_logits(x) + return mask_pred + + def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg): + pos_proposals = [res.pos_bboxes for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, + gt_masks, rcnn_train_cfg) + return mask_targets + + @force_fp32(apply_to=('mask_pred', )) + def loss(self, mask_pred, mask_targets, labels): + """ + Example: + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. + >>> # There are lots of variations depending on the configuration + >>> self = FCNMaskHead(num_classes=C, num_convs=1) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_pred = self.forward(inputs) + >>> sf = self.scale_factor + >>> labels = torch.randint(0, C, size=(N,)) + >>> # With the default properties the mask targets should indicate + >>> # a (potentially soft) single-class label + >>> mask_targets = torch.rand(N, H * sf, W * sf) + >>> loss = self.loss(mask_pred, mask_targets, labels) + >>> print('loss = {!r}'.format(loss)) + """ + loss = dict() + if mask_pred.size(0) == 0: + loss_mask = mask_pred.sum() + else: + if self.class_agnostic: + loss_mask = self.loss_mask(mask_pred, mask_targets, + torch.zeros_like(labels)) + else: + loss_mask = self.loss_mask(mask_pred, mask_targets, labels) + loss['loss_mask'] = loss_mask + return loss + + def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, + ori_shape, scale_factor, rescale): + """Get segmentation masks from mask_pred and bboxes. + + Args: + mask_pred (Tensor or ndarray): shape (n, #class, h, w). + For single-scale testing, mask_pred is the direct output of + model, whose type is Tensor, while for multi-scale testing, + it will be converted to numpy array outside of this method. + det_bboxes (Tensor): shape (n, 4/5) + det_labels (Tensor): shape (n, ) + rcnn_test_cfg (dict): rcnn testing config + ori_shape (Tuple): original image height and width, shape (2,) + scale_factor(ndarray | Tensor): If ``rescale is True``, box + coordinates are divided by this scale factor to fit + ``ori_shape``. + rescale (bool): If True, the resulting masks will be rescaled to + ``ori_shape``. + + Returns: + list[list]: encoded masks. The c-th item in the outer list + corresponds to the c-th class. Given the c-th outer list, the + i-th item in that inner list is the mask for the i-th box with + class label c. + + Example: + >>> import mmcv + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. + >>> self = FCNMaskHead(num_classes=C, num_convs=0) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_pred = self.forward(inputs) + >>> # Each input is associated with some bounding box + >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N) + >>> det_labels = torch.randint(0, C, size=(N,)) + >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, }) + >>> ori_shape = (H * 4, W * 4) + >>> scale_factor = torch.FloatTensor((1, 1)) + >>> rescale = False + >>> # Encoded masks are a list for each category. + >>> encoded_masks = self.get_seg_masks( + >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape, + >>> scale_factor, rescale + >>> ) + >>> assert len(encoded_masks) == C + >>> assert sum(list(map(len, encoded_masks))) == N + """ + if isinstance(mask_pred, torch.Tensor): + mask_pred = mask_pred.sigmoid() + else: + # In AugTest, has been activated before + mask_pred = det_bboxes.new_tensor(mask_pred) + + device = mask_pred.device + cls_segms = [[] for _ in range(self.num_classes) + ] # BG is not included in num_classes + bboxes = det_bboxes[:, :4] + labels = det_labels + + # In most cases, scale_factor should have been + # converted to Tensor when rescale the bbox + if not isinstance(scale_factor, torch.Tensor): + if isinstance(scale_factor, float): + scale_factor = np.array([scale_factor] * 4) + warn('Scale_factor should be a Tensor or ndarray ' + 'with shape (4,), float would be deprecated. ') + assert isinstance(scale_factor, np.ndarray) + scale_factor = torch.Tensor(scale_factor) + + if rescale: + img_h, img_w = ori_shape[:2] + bboxes = bboxes / scale_factor.to(bboxes) + else: + w_scale, h_scale = scale_factor[0], scale_factor[1] + img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32) + img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32) + + N = len(mask_pred) + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == 'cpu': + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs minimal number of + # operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + # the types of img_w and img_h are np.int32, + # when the image resolution is large, + # the calculation of num_chunks will overflow. + # so we need to change the types of img_w and img_h to int. + # See https://github.com/open-mmlab/mmdetection/pull/5191 + num_chunks = int( + np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT + / GPU_MEM_LIMIT)) + # assert (num_chunks <= N), 'Default GPU_MEM_LIMIT is too small; try increasing it' + assert num_chunks <= N, 'Default GPU_MEM_LIMIT is too small; try increasing it' + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) + + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + N, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8) + + if not self.class_agnostic: + mask_pred = mask_pred[range(N), labels][:, None] + + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_pred[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == 'cpu') + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + else: + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + im_mask[(inds, ) + spatial_inds] = masks_chunk + + for i in range(N): + cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy()) + return cls_segms + + def onnx_export(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg, + ori_shape, **kwargs): + """Get segmentation masks from mask_pred and bboxes. + + Args: + mask_pred (Tensor): shape (n, #class, h, w). + det_bboxes (Tensor): shape (n, 4/5) + det_labels (Tensor): shape (n, ) + rcnn_test_cfg (dict): rcnn testing config + ori_shape (Tuple): original image height and width, shape (2,) + + Returns: + Tensor: a mask of shape (N, img_h, img_w). + """ + + mask_pred = mask_pred.sigmoid() + bboxes = det_bboxes[:, :4] + labels = det_labels + # No need to consider rescale and scale_factor while exporting to ONNX + img_h, img_w = ori_shape[:2] + threshold = rcnn_test_cfg.mask_thr_binary + if not self.class_agnostic: + box_inds = torch.arange(mask_pred.shape[0]) + mask_pred = mask_pred[box_inds, labels][:, None] + masks, _ = _do_paste_mask( + mask_pred, bboxes, img_h, img_w, skip_empty=False) + if threshold >= 0: + # should convert to float to avoid problems in TRT + masks = (masks >= threshold).to(dtype=torch.float) + return masks + + +def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): + """Paste instance masks according to boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. + device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, + min=0).to(dtype=torch.int32) + x1_int = torch.clamp( + boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp( + boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + # IsInf op is not supported with ONNX<=1.7.0 + if not torch.onnx.is_in_onnx_export(): + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = F.grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py new file mode 100644 index 00000000..34f240c6 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/__init__.py @@ -0,0 +1,6 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from .checkpoint import load_checkpoint +from .convModule_norm import ConvModule_Norm + +__all__ = ['load_checkpoint', 'ConvModule_Norm'] diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py new file mode 100644 index 00000000..7833f592 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/checkpoint.py @@ -0,0 +1,559 @@ +# Copyright (c) Open-MMLab. All rights reserved. +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import mmcv +import torch +import torchvision +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.runner import get_dist_info +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.utils import model_zoo + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + print('finish load') + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None, + load_ema=True): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if load_ema and 'state_dict_ema' in checkpoint: + state_dict = checkpoint['state_dict_ema'] + # logger.info(f'loading from state_dict_ema') + logger.info('loading from state_dict_ema') + elif 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + # logger.info(f'loading from state_dict') + logger.info('loading from state_dict') + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + # logger.info(f'loading from model') + logger.info('loading from model') + print('loading from model') + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = { + k.replace('encoder.', ''): v + for k, v in state_dict.items() if k.startswith('encoder.') + } + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2) + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + + if 'relative_position_bias_table' in key: + state_dict.pop(key) + + if '.q_bias' in key: + q_bias = state_dict[key] + v_bias = state_dict[key.replace('q_bias', 'v_bias')] + qkv_bias = torch.cat([q_bias, torch.zeros_like(q_bias), v_bias], 0) + state_dict[key.replace('q_bias', 'qkv.bias')] = qkv_bias + + if '.v.bias' in key: + continue + + all_keys = list(state_dict.keys()) + new_state_dict = {} + for key in all_keys: + if 'qkv.bias' in key: + value = state_dict[key] + dim = value.shape[0] + selected_dim = (dim * 2) // 3 + new_state_dict[key.replace( + 'qkv.bias', 'pos_bias')] = state_dict[key][:selected_dim] + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + if table_key not in model.state_dict().keys(): + logger.warning( + 'relative_position_bias_table exits in pretrained model but not in current one, pass' + ) + continue + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + else: + if L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0) + rank, _ = get_dist_info() + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + H, W = model.patch_embed.patch_shape + num_patches = model.patch_embed.num_patches + num_extra_tokens = 1 + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + if rank == 0: + print('Position interpolate from %dx%d to %dx%d' % + (orig_size, orig_size, H, W)) + # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(H, W), mode='bicubic', align_corners=False) + new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() diff --git a/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py b/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py new file mode 100644 index 00000000..a15780f7 --- /dev/null +++ b/modelscope/models/cv/object_detection/mmdet_ms/utils/convModule_norm.py @@ -0,0 +1,30 @@ +# Implementation in this file is modified based on ViTAE-Transformer +# Originally Apache 2.0 License and publicly avaialbe at https://github.com/ViTAE-Transformer/ViTDet +from mmcv.cnn import ConvModule + + +class ConvModule_Norm(ConvModule): + + def __init__(self, in_channels, out_channels, kernel, **kwargs): + super().__init__(in_channels, out_channels, kernel, **kwargs) + + self.normType = kwargs.get('norm_cfg', {'type': ''}) + if self.normType is not None: + self.normType = self.normType['type'] + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + if 'LN' in self.normType: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = x.permute(0, 3, 1, 2).contiguous() + else: + x = self.norm(x) + elif layer == 'act' and activate and self.with_activation: + x = self.activate(x) + return x diff --git a/modelscope/models/cv/object_detection/yolox_pai.py b/modelscope/models/cv/object_detection/yolox_pai.py new file mode 100644 index 00000000..46bd4e3c --- /dev/null +++ b/modelscope/models/cv/object_detection/yolox_pai.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.detection.detectors import YOLOX as _YOLOX + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.image_object_detection, module_name=Models.yolox) +@MODELS.register_module( + group_key=Tasks.image_object_detection, + module_name=Models.image_object_detection_auto) +class YOLOX(EasyCVBaseModel, _YOLOX): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + _YOLOX.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/product_retrieval_embedding/__init__.py b/modelscope/models/cv/product_retrieval_embedding/__init__.py new file mode 100644 index 00000000..2cbc9099 --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .item_model import ProductRetrievalEmbedding + +else: + _import_structure = { + 'item_model': ['ProductRetrievalEmbedding'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/product_retrieval_embedding/item_detection.py b/modelscope/models/cv/product_retrieval_embedding/item_detection.py new file mode 100644 index 00000000..2002c6cb --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_detection.py @@ -0,0 +1,522 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import cv2 +import numpy as np + + +class YOLOXONNX(object): + """ + Product detection model with onnx inference + """ + + def __init__(self, onnx_path, multi_detect=False): + """Create product detection model + Args: + onnx_path: onnx model path for product detection + multi_detect: detection parameter, should be set as False + + """ + self.input_reso = 416 + self.iou_thr = 0.45 + self.score_thr = 0.3 + self.img_shape = tuple([self.input_reso, self.input_reso, 3]) + self.num_classes = 13 + self.onnx_path = onnx_path + import onnxruntime as ort + options = ort.SessionOptions() + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + self.ort_session = ort.InferenceSession( + self.onnx_path, sess_options=options) + self.with_p6 = False + self.multi_detect = multi_detect + + def format_judge(self, img): + m_min_width = 100 + m_min_height = 100 + + height, width, c = img.shape + + if width * height > 1024 * 1024: + if height > width: + long_side = height + short_side = width + long_ratio = float(long_side) / 1024.0 + short_ratio = float(short_side) / float(m_min_width) + else: + long_side = width + short_side = height + long_ratio = float(long_side) / 1024.0 + short_ratio = float(short_side) / float(m_min_height) + + if long_side == height: + if long_ratio < short_ratio: + height_new = 1024 + width_new = (int)((1024 * width) / height) + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + height_new = (int)((m_min_width * height) / width) + width_new = m_min_width + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + + elif long_side == width: + if long_ratio < short_ratio: + height_new = (int)((1024 * height) / width) + width_new = 1024 + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + width_new = (int)((m_min_height * width) / height) + height_new = m_min_height + + img_res = cv2.resize(img, (width_new, height_new), + cv2.INTER_LINEAR) + else: + img_res = img + + return img_res + + def preprocess(self, image, input_size, swap=(2, 0, 1)): + """ + Args: + image, cv2 image with BGR format + input_size, model input size + """ + if len(image.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32) + padded_img[:int(img.shape[0] * r), :int(img.shape[1] + * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + def cal_iou(self, val1, val2): + x11, y11, x12, y12 = val1 + x21, y21, x22, y22 = val2 + + leftX = max(x11, x21) + topY = max(y11, y21) + rightX = min(x12, x22) + bottomY = min(y12, y22) + if rightX < leftX or bottomY < topY: + return 0 + area = float((rightX - leftX) * (bottomY - topY)) + barea = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - area + if barea <= 0: + return 0 + return area / barea + + def nms(self, boxes, scores, nms_thr): + """ + Single class NMS implemented in Numpy. + """ + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + + def multiclass_nms(self, boxes, scores, nms_thr, score_thr): + """ + Multiclass NMS implemented in Numpy + """ + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = self.nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate([ + valid_boxes[keep], valid_scores[keep, None], cls_inds + ], 1) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + + def postprocess(self, outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + + if not p6: + strides = [8, 16, 32] + else: + strides = [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + + def get_new_box_order(self, bboxes, labels, img_h, img_w): + """ + refine bbox score + """ + bboxes = np.hstack((bboxes, np.zeros((bboxes.shape[0], 1)))) + scores = bboxes[:, 4] + order = scores.argsort()[::-1] + bboxes_temp = bboxes[order] + labels_temp = labels[order] + bboxes = np.empty((0, 6)) + # import pdb;pdb.set_trace() + bboxes = np.vstack((bboxes, bboxes_temp[0].tolist())) + labels = np.empty((0, )) + + labels = np.hstack((labels, [labels_temp[0]])) + for i in range(1, bboxes_temp.shape[0]): + iou_max = 0 + for j in range(bboxes.shape[0]): + iou_temp = self.cal_iou(bboxes_temp[i][:4], bboxes[j][:4]) + if (iou_temp > iou_max): + iou_max = iou_temp + if (iou_max < 0.45): + bboxes = np.vstack((bboxes, bboxes_temp[i].tolist())) + labels = np.hstack((labels, [labels_temp[i]])) + + num_03 = scores > 0.3 + num_03 = num_03.sum() + num_out = max(num_03, 1) + bboxes = bboxes[:num_out, :] + labels = labels[:num_out] + + return bboxes, labels + + def forward(self, img_input, cid='0', sub_class=False): + """ + forward for product detection + """ + input_shape = self.img_shape + + img, ratio = self.preprocess(img_input, input_shape) + img_h, img_w = img_input.shape[:2] + + ort_inputs = { + self.ort_session.get_inputs()[0].name: img[None, :, :, :] + } + + output = self.ort_session.run(None, ort_inputs) + + predictions = self.postprocess(output[0], input_shape, self.with_p6)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2. + boxes_xyxy /= ratio + dets = self.multiclass_nms( + boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + + if dets is None: + top1_bbox_str = str(0) + ',' + str(img_w) + ',' + str( + 0) + ',' + str(img_h) + crop_img = img_input.copy() + coord = top1_bbox_str + else: + bboxes = dets[:, :5] + labels = dets[:, 5] + + if not self.multi_detect: + cid = int(cid) + if (not sub_class): + if cid > -1: + if cid == 0: # cloth + cid_ind1 = np.where(labels < 3) + cid_ind2 = np.where(labels == 9) + cid_ind = np.hstack((cid_ind1[0], cid_ind2[0])) + scores = bboxes[cid_ind, -1] # 0, 1, 2, 9 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 3: # bag + cid_ind = np.where(labels == 3) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 4: # shoe + cid_ind = np.where(labels == 4) + scores = bboxes[cid_ind, -1] # 4 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: # other + cid_ind5 = np.where(labels == 5) + cid_ind6 = np.where(labels == 6) + cid_ind7 = np.where(labels == 7) + cid_ind8 = np.where(labels == 8) + cid_ind10 = np.where(labels == 10) + cid_ind11 = np.where(labels == 11) + cid_ind12 = np.where(labels == 12) + cid_ind = np.hstack( + (cid_ind5[0], cid_ind6[0], cid_ind7[0], + cid_ind8[0], cid_ind10[0], cid_ind11[0], + cid_ind12[0])) + scores = bboxes[cid_ind, -1] # 5,6,7,8,10,11,12 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + else: + if cid > -1: + if cid == 0: # upper + cid_ind = np.where(labels == 0) + + scores = bboxes[cid_ind, -1] # 0, 1, 2, 9 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 1: # skirt + cid_ind = np.where(labels == 1) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 2: # lower + cid_ind = np.where(labels == 2) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 3: # bag + cid_ind = np.where(labels == 3) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 4: # shoe + cid_ind = np.where(labels == 4) + scores = bboxes[cid_ind, -1] # 4 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 5: # access + cid_ind = np.where(labels == 5) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 7: # beauty + cid_ind = np.where(labels == 6) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 9: # furniture + cid_ind = np.where(labels == 8) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + elif cid == 21: # underwear + cid_ind = np.where(labels == 9) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + elif cid == 22: # digital + cid_ind = np.where(labels == 11) + scores = bboxes[cid_ind, -1] # 3 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: # other + cid_ind5 = np.where(labels == 7) # bottle + cid_ind6 = np.where(labels == 10) # toy + cid_ind7 = np.where(labels == 12) # toy + cid_ind = np.hstack( + (cid_ind5[0], cid_ind6[0], cid_ind7[0])) + scores = bboxes[cid_ind, -1] # 5,6,7 + + if scores.size > 0: + + bboxes = bboxes[cid_ind] + labels = labels[cid_ind] + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + else: + bboxes, labels = self.get_new_box_order( + bboxes, labels, img_h, img_w) + top1_bbox = bboxes[0].astype(np.int32) + top1_bbox[0] = min(max(0, top1_bbox[0]), img_input.shape[1] - 1) + top1_bbox[1] = min(max(0, top1_bbox[1]), img_input.shape[0] - 1) + top1_bbox[2] = max(min(img_input.shape[1] - 1, top1_bbox[2]), 0) + top1_bbox[3] = max(min(img_input.shape[0] - 1, top1_bbox[3]), 0) + if not self.multi_detect: + + top1_bbox_str = str(top1_bbox[0]) + ',' + str( + top1_bbox[2]) + ',' + str(top1_bbox[1]) + ',' + str( + top1_bbox[3]) # x1, x2, y1, y2 + crop_img = img_input[top1_bbox[1]:top1_bbox[3], + top1_bbox[0]:top1_bbox[2], :] + coord = top1_bbox_str + coord = '' + for i in range(0, len(bboxes)): + top_bbox = bboxes[i].astype(np.int32) + top_bbox[0] = min( + max(0, top_bbox[0]), img_input.shape[1] - 1) + top_bbox[1] = min( + max(0, top_bbox[1]), img_input.shape[0] - 1) + top_bbox[2] = max( + min(img_input.shape[1] - 1, top_bbox[2]), 0) + top_bbox[3] = max( + min(img_input.shape[0] - 1, top_bbox[3]), 0) + coord = coord + str(top_bbox[0]) + ',' + str( + top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str( + top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str( + bboxes[i][5]) + ';' + + else: + coord = '' + for i in range(0, len(bboxes)): + top_bbox = bboxes[i].astype(np.int32) + top_bbox[0] = min( + max(0, top_bbox[0]), img_input.shape[1] - 1) + top_bbox[1] = min( + max(0, top_bbox[1]), img_input.shape[0] - 1) + top_bbox[2] = max( + min(img_input.shape[1] - 1, top_bbox[2]), 0) + top_bbox[3] = max( + min(img_input.shape[0] - 1, top_bbox[3]), 0) + coord = coord + str(top_bbox[0]) + ',' + str( + top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str( + top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str( + bboxes[i][5]) + ';' # x1, x2, y1, y2, conf + crop_img = img_input[top1_bbox[1]:top1_bbox[3], + top1_bbox[0]:top1_bbox[2], :] + + crop_img = cv2.resize(crop_img, (224, 224)) + + return coord, crop_img # return top1 image and coord diff --git a/modelscope/models/cv/product_retrieval_embedding/item_embedding.py b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py new file mode 100644 index 00000000..ea9ec846 --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_embedding.py @@ -0,0 +1,154 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import cv2 +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +def gn_init(m, zero_init=False): + assert isinstance(m, nn.GroupNorm) + m.weight.data.fill_(0. if zero_init else 1.) + m.bias.data.zero_() + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + """Bottleneck for resnet-style networks + Args: + inplanes: input channel number + planes: output channel number + """ + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.GroupNorm(32, planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.GroupNorm(32, planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.GroupNorm(32, planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + gn_init(self.bn1) + gn_init(self.bn2) + gn_init(self.bn3, zero_init=True) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """ + resnet-style network with group normalization + """ + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.GroupNorm(32, 64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1) + + self.gap = nn.AvgPool2d((14, 14)) + self.reduce_conv = nn.Conv2d(2048, 512, kernel_size=1) + + gn_init(self.bn1) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.AvgPool2d(stride, stride), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False), + nn.GroupNorm(32, planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.gap(x) + x = self.reduce_conv(x) # 512 + + x = x.view(x.size(0), -1) # 512 + return F.normalize(x, p=2, dim=1) + + +def preprocess(img): + """ + preprocess the image with cv2-bgr style to tensor + """ + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + + img_size = 224 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_new = cv2.resize( + img, (img_size, img_size), interpolation=cv2.INTER_LINEAR) + content = np.array(img_new).astype(np.float32) + content = (content / 255.0 - mean) / std + # transpose + img_new = content.transpose(2, 0, 1) + img_new = img_new[np.newaxis, :, :, :] + return img_new + + +def resnet50_embed(): + """ + create resnet50 network with group normalization + """ + net = ResNet(Bottleneck, [3, 4, 6, 3]) + return net diff --git a/modelscope/models/cv/product_retrieval_embedding/item_model.py b/modelscope/models/cv/product_retrieval_embedding/item_model.py new file mode 100644 index 00000000..3964efbe --- /dev/null +++ b/modelscope/models/cv/product_retrieval_embedding/item_model.py @@ -0,0 +1,116 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os.path as osp +from typing import Any, Dict + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.product_retrieval_embedding.item_detection import \ + YOLOXONNX +from modelscope.models.cv.product_retrieval_embedding.item_embedding import ( + preprocess, resnet50_embed) +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import create_device +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ProductRetrievalEmbedding'] + + +@MODELS.register_module( + Tasks.product_retrieval_embedding, + module_name=Models.product_retrieval_embedding) +class ProductRetrievalEmbedding(TorchModel): + + def __init__(self, model_dir, device='cpu', **kwargs): + super().__init__(model_dir=model_dir, device=device, **kwargs) + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' not in list(own_state.keys())[0]: + name = name.replace('.module.', '.') + if (name in own_state) and (own_state[name].shape + == param.shape): + own_state[name].copy_(param) + copied_keys.append(name) + + def load_pretrained(model, src_params): + if 'state_dict' in src_params: + src_params = src_params['state_dict'] + own_state = model.state_dict() + filter_param(src_params, own_state) + model.load_state_dict(own_state) + + self.device = create_device( + device) # device.type == "cpu" or device.type == "cuda" + self.use_gpu = self.device.type == 'cuda' + + # config the model path + self.local_model_dir = model_dir + + # init feat model + self.preprocess_for_embed = preprocess # input is cv2 bgr format + model_feat = resnet50_embed() + src_params = torch.load( + osp.join(self.local_model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + 'cpu') + load_pretrained(model_feat, src_params) + if self.use_gpu: + model_feat.to(self.device) + logger.info('Use GPU: {}'.format(self.device)) + else: + logger.info('Use CPU for inference') + + self.model_feat = model_feat + + # init det model + self.model_det = YOLOXONNX( + onnx_path=osp.join(self.local_model_dir, 'onnx_detection.onnx'), + multi_detect=False) + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + detection and feature extraction for input product image + """ + # input should be cv2 bgr format + assert 'img' in input.keys() + + def set_phase(model, is_train): + if is_train: + model.train() + else: + model.eval() + + is_train = False + set_phase(self.model_feat, is_train) + img = input['img'] # for detection + cid = '3' # preprocess detection category bag + # transform img(tensor) to numpy array with bgr + if isinstance(img, torch.Tensor): + img = img.data.cpu().numpy() + res, crop_img = self.model_det.forward(img, + cid) # detect with bag category + crop_img = self.preprocess_for_embed(crop_img) # feat preprocess + input_tensor = torch.from_numpy(crop_img.astype(np.float32)) + device = next(self.model_feat.parameters()).device + use_gpu = device.type == 'cuda' + with torch.no_grad(): + if use_gpu: + input_tensor = input_tensor.to(device) + out_embedding = self.model_feat(input_tensor) + out_embedding = out_embedding.cpu().numpy()[ + 0, :] # feature array with 512 elements + + output = {OutputKeys.IMG_EMBEDDING: None} + output[OutputKeys.IMG_EMBEDDING] = out_embedding + return output diff --git a/modelscope/models/cv/product_segmentation/__init__.py b/modelscope/models/cv/product_segmentation/__init__.py new file mode 100644 index 00000000..e87c8db1 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .seg_infer import F3NetProductSegmentation + +else: + _import_structure = {'seg_infer': ['F3NetProductSegmentation']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/product_segmentation/net.py b/modelscope/models/cv/product_segmentation/net.py new file mode 100644 index 00000000..454c99d8 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/net.py @@ -0,0 +1,197 @@ +# The implementation here is modified based on F3Net, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/weijun88/F3Net + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Bottleneck(nn.Module): + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=(3 * dilation - 1) // 2, + bias=False, + dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.downsample = downsample + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x)), inplace=True) + out = F.relu(self.bn2(self.conv2(out)), inplace=True) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + x = self.downsample(x) + return F.relu(out + x, inplace=True) + + +class ResNet(nn.Module): + + def __init__(self): + super(ResNet, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self.make_layer(64, 3, stride=1, dilation=1) + self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) + self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) + self.layer4 = self.make_layer(512, 3, stride=2, dilation=1) + + def make_layer(self, planes, blocks, stride, dilation): + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * 4, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(planes * 4)) + layers = [ + Bottleneck( + self.inplanes, planes, stride, downsample, dilation=dilation) + ] + self.inplanes = planes * 4 + for _ in range(1, blocks): + layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) + return nn.Sequential(*layers) + + def forward(self, x): + x = x.reshape(1, 3, 448, 448) + out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) + out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) + out2 = self.layer1(out1) + out3 = self.layer2(out2) + out4 = self.layer3(out3) + out5 = self.layer4(out4) + return out2, out3, out4, out5 + + +class CFM(nn.Module): + + def __init__(self): + super(CFM, self).__init__() + self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1h = nn.BatchNorm2d(64) + self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2h = nn.BatchNorm2d(64) + self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3h = nn.BatchNorm2d(64) + self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4h = nn.BatchNorm2d(64) + + self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1v = nn.BatchNorm2d(64) + self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2v = nn.BatchNorm2d(64) + self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3v = nn.BatchNorm2d(64) + self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4v = nn.BatchNorm2d(64) + + def forward(self, left, down): + if down.size()[2:] != left.size()[2:]: + down = F.interpolate(down, size=left.size()[2:], mode='bilinear') + out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True) + out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True) + out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True) + out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True) + fuse = out2h * out2v + out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h + out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True) + out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v + out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True) + return out4h, out4v + + +class Decoder(nn.Module): + + def __init__(self): + super(Decoder, self).__init__() + self.cfm45 = CFM() + self.cfm34 = CFM() + self.cfm23 = CFM() + + def forward(self, out2h, out3h, out4h, out5v, fback=None): + if fback is not None: + refine5 = F.interpolate( + fback, size=out5v.size()[2:], mode='bilinear') + refine4 = F.interpolate( + fback, size=out4h.size()[2:], mode='bilinear') + refine3 = F.interpolate( + fback, size=out3h.size()[2:], mode='bilinear') + refine2 = F.interpolate( + fback, size=out2h.size()[2:], mode='bilinear') + out5v = out5v + refine5 + out4h, out4v = self.cfm45(out4h + refine4, out5v) + out3h, out3v = self.cfm34(out3h + refine3, out4v) + out2h, pred = self.cfm23(out2h + refine2, out3v) + else: + out4h, out4v = self.cfm45(out4h, out5v) + out3h, out3v = self.cfm34(out3h, out4v) + out2h, pred = self.cfm23(out2h, out3v) + return out2h, out3h, out4h, out5v, pred + + +class F3Net(nn.Module): + + def __init__(self): + super(F3Net, self).__init__() + self.bkbone = ResNet() + self.squeeze5 = nn.Sequential( + nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze4 = nn.Sequential( + nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze3 = nn.Sequential( + nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze2 = nn.Sequential( + nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + + self.decoder1 = Decoder() + self.decoder2 = Decoder() + self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + def forward(self, x, shape=None): + x = x.reshape(1, 3, 448, 448) + out2h, out3h, out4h, out5v = self.bkbone(x) + out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3( + out3h), self.squeeze4(out4h), self.squeeze5(out5v) + out2h, out3h, out4h, out5v, pred1 = self.decoder1( + out2h, out3h, out4h, out5v) + out2h, out3h, out4h, out5v, pred2 = self.decoder2( + out2h, out3h, out4h, out5v, pred1) + + shape = x.size()[2:] if shape is None else shape + pred1 = F.interpolate( + self.linearp1(pred1), size=shape, mode='bilinear') + pred2 = F.interpolate( + self.linearp2(pred2), size=shape, mode='bilinear') + + out2h = F.interpolate( + self.linearr2(out2h), size=shape, mode='bilinear') + out3h = F.interpolate( + self.linearr3(out3h), size=shape, mode='bilinear') + out4h = F.interpolate( + self.linearr4(out4h), size=shape, mode='bilinear') + out5h = F.interpolate( + self.linearr5(out5v), size=shape, mode='bilinear') + return pred1, pred2, out2h, out3h, out4h, out5h diff --git a/modelscope/models/cv/product_segmentation/seg_infer.py b/modelscope/models/cv/product_segmentation/seg_infer.py new file mode 100644 index 00000000..8814d619 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/seg_infer.py @@ -0,0 +1,76 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .net import F3Net + +logger = get_logger() + + +def load_state_dict(model_dir, device): + _dict = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + state_dict = {} + for k, v in _dict.items(): + if k.startswith('module'): + k = k[7:] + state_dict[k] = v + return state_dict + + +@MODELS.register_module( + Tasks.product_segmentation, module_name=Models.product_segmentation) +class F3NetForProductSegmentation(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = F3Net() + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + + self.params = load_state_dict(model_dir, self.device) + self.model.load_state_dict(self.params) + self.model.to(self.device) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + pred_result = self.model(x) + return pred_result + + +mean, std = np.array([[[124.55, 118.90, + 102.94]]]), np.array([[[56.77, 55.97, 57.50]]]) + + +def inference(model, device, img): + img = img.cpu().numpy() + img = (img - mean) / std + img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR) + img = torch.from_numpy(img) + img = img.permute(2, 0, 1) + img = img.to(device).float() + outputs = model(img) + out = outputs[0] + pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy() + pred[pred < 20] = 0 + pred = pred[:, :, np.newaxis] + pred = np.round(pred) + logger.info('Inference Done') + return pred diff --git a/modelscope/models/cv/realtime_object_detection/__init__.py b/modelscope/models/cv/realtime_object_detection/__init__.py new file mode 100644 index 00000000..66156977 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .realtime_detector import RealtimeDetector + from .realtime_video_detector import RealtimeVideoDetector +else: + _import_structure = { + 'realtime_detector': ['RealtimeDetector'], + 'realtime_video_detector': ['RealtimeVideoDetector'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/realtime_object_detection/realtime_detector.py b/modelscope/models/cv/realtime_object_detection/realtime_detector.py new file mode 100644 index 00000000..2b4b3f8c --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/realtime_detector.py @@ -0,0 +1,90 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import logging as logger +import os +import os.path as osp +import time + +import cv2 +import json +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .yolox.data.data_augment import ValTransform +from .yolox.exp import get_exp_by_name +from .yolox.utils import postprocess + + +@MODELS.register_module( + group_key=Tasks.image_object_detection, + module_name=Models.realtime_object_detection) +class RealtimeDetector(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # model type + self.exp = get_exp_by_name(self.config.model_type) + + # build model + self.model = self.exp.get_model() + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + ckpt = torch.load(model_path, map_location='cpu') + + # load the model state dict + self.model.load_state_dict(ckpt['model']) + self.model.eval() + + # params setting + self.exp.num_classes = self.config.num_classes + self.confthre = self.config.conf_thr + self.num_classes = self.exp.num_classes + self.nmsthre = self.exp.nmsthre + self.test_size = self.exp.test_size + self.preproc = ValTransform(legacy=False) + self.label_mapping = self.config['labels'] + + def inference(self, img): + with torch.no_grad(): + outputs = self.model(img) + return outputs + + def forward(self, inputs): + return self.inference(inputs) + + def preprocess(self, img): + img = LoadImage.convert_to_ndarray(img) + height, width = img.shape[:2] + self.ratio = min(self.test_size[0] / img.shape[0], + self.test_size[1] / img.shape[1]) + + img, _ = self.preproc(img, None, self.test_size) + img = torch.from_numpy(img).unsqueeze(0) + img = img.float() + + return img + + def postprocess(self, input): + outputs = postprocess( + input, + self.num_classes, + self.confthre, + self.nmsthre, + class_agnostic=True) + + if len(outputs) == 1: + bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio + scores = outputs[0][:, 5].cpu().numpy() + labels = outputs[0][:, 6].cpu().int().numpy() + pred_label_names = [] + for lab in labels: + pred_label_names.append(self.label_mapping[lab]) + + return bboxes, scores, pred_label_names diff --git a/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py b/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py new file mode 100644 index 00000000..3830fb42 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/realtime_video_detector.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import logging as logger +import os +import os.path as osp +import time + +import cv2 +import json +import torch +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .utils import timestamp_format +from .yolox.data.data_augment import ValTransform +from .yolox.exp import get_exp_by_name +from .yolox.utils import postprocess + + +@MODELS.register_module( + group_key=Tasks.video_object_detection, + module_name=Models.realtime_video_object_detection) +class RealtimeVideoDetector(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # model type + self.exp = get_exp_by_name(self.config.model_type) + + # build model + self.model = self.exp.get_model() + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + ckpt = torch.load(model_path, map_location='cpu') + + # load the model state dict + self.model.load_state_dict(ckpt['model']) + self.model.eval() + + # params setting + self.exp.num_classes = self.config.num_classes + self.confthre = self.config.conf_thr + self.num_classes = self.exp.num_classes + self.nmsthre = self.exp.nmsthre + self.test_size = self.exp.test_size + self.preproc = ValTransform(legacy=False) + self.current_buffer = None + self.label_mapping = self.config['labels'] + + def inference(self, img): + with torch.no_grad(): + outputs, self.current_buffer = self.model( + img, buffer=self.current_buffer, mode='on_pipe') + return outputs + + def forward(self, inputs): + return self.inference_video(inputs) + + def preprocess(self, img): + img = LoadImage.convert_to_ndarray(img) + height, width = img.shape[:2] + self.ratio = min(self.test_size[0] / img.shape[0], + self.test_size[1] / img.shape[1]) + + img, _ = self.preproc(img, None, self.test_size) + img = torch.from_numpy(img).unsqueeze(0) + img = img.float() + + # Video decoding and preprocessing automatically are not supported by Pipeline/Model + # Sending preprocessed video frame tensor to GPU buffer self-adaptively + if next(self.model.parameters()).is_cuda: + img = img.to(next(self.model.parameters()).device) + return img + + def postprocess(self, input): + outputs = postprocess( + input, + self.num_classes, + self.confthre, + self.nmsthre, + class_agnostic=True) + + if len(outputs) == 1: + bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio + scores = outputs[0][:, 5].cpu().numpy() + labels = outputs[0][:, 6].cpu().int().numpy() + pred_label_names = [] + for lab in labels: + pred_label_names.append(self.label_mapping[lab]) + + return bboxes, scores, pred_label_names + + def inference_video(self, v_path): + outputs = [] + desc = 'Detecting video: {}'.format(v_path) + for frame_idx, (frame, result) in enumerate( + tqdm(self.inference_video_iter(v_path), desc=desc)): + result = result + (timestamp_format(seconds=frame_idx + / self.fps), ) + outputs.append(result) + + return outputs + + def inference_video_iter(self, v_path): + capture = cv2.VideoCapture(v_path) + self.fps = capture.get(cv2.CAP_PROP_FPS) + while capture.isOpened(): + ret, frame = capture.read() + if not ret: + break + output = self.preprocess(frame) + output = self.inference(output) + output = self.postprocess(output) + yield frame, output diff --git a/modelscope/models/cv/realtime_object_detection/utils.py b/modelscope/models/cv/realtime_object_detection/utils.py new file mode 100644 index 00000000..c3d7a4c6 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math + + +def timestamp_format(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time diff --git a/modelscope/models/cv/realtime_object_detection/yolox/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py b/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py new file mode 100644 index 00000000..b52a65fe --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/data/data_augment.py @@ -0,0 +1,69 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX +""" +Data augmentation functionality. Passed as callable transformations to +Dataset classes. + +The data augmentation procedures were interpreted from @weiliu89's SSD paper +http://arxiv.org/abs/1512.02325 +""" + +import math +import random + +import cv2 +import numpy as np + +from ..utils import xyxy2cxcywh + + +def preproc(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones( + (input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + +class ValTransform: + """ + Defines the transformations that should be applied to test PIL image + for input into the network + + dimension -> tensorize -> color adj + + Arguments: + resize (int): input dimension to SSD + rgb_means ((int,int,int)): average RGB of the dataset + (104,117,123) + swap ((int,int,int)): final order of channels + + Returns: + transform (transform) : callable transform to be applied to test/val + data + """ + + def __init__(self, swap=(2, 0, 1), legacy=False): + self.swap = swap + self.legacy = legacy + + # assume input is cv2 img for now + def __call__(self, img, res, input_size): + img, _ = preproc(img, input_size, self.swap) + if self.legacy: + img = img[::-1, :, :].copy() + img /= 255.0 + img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) + img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) + return img, np.zeros((1, 5)) diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py new file mode 100644 index 00000000..e8e3be15 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .base_exp import BaseExp +from .build import get_exp_by_name +from .yolox_base import Exp diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py new file mode 100644 index 00000000..a4278cbf --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/base_exp.py @@ -0,0 +1,12 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from abc import ABCMeta, abstractmethod + +from torch.nn import Module + + +class BaseExp(metaclass=ABCMeta): + + @abstractmethod + def get_model(self) -> Module: + pass diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py new file mode 100644 index 00000000..5865c53b --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/build.py @@ -0,0 +1,20 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os +import sys + + +def get_exp_by_name(exp_name): + exp = exp_name.replace('-', + '_') # convert string like "yolox-s" to "yolox_s" + if exp == 'yolox_s': + from .default import YoloXSExp as YoloXExp + elif exp == 'yolox_nano': + from .default import YoloXNanoExp as YoloXExp + elif exp == 'yolox_tiny': + from .default import YoloXTinyExp as YoloXExp + elif exp == 'streamyolo': + from .default import StreamYoloExp as YoloXExp + else: + pass + return YoloXExp() diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py new file mode 100644 index 00000000..cfec836c --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX +from .streamyolo import StreamYoloExp +from .yolox_nano import YoloXNanoExp +from .yolox_s import YoloXSExp +from .yolox_tiny import YoloXTinyExp diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py new file mode 100644 index 00000000..5a62c8fc --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/streamyolo.py @@ -0,0 +1,43 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import os +import sys + +import torch + +from ..yolox_base import Exp as YoloXExp + + +class StreamYoloExp(YoloXExp): + + def __init__(self): + super(YoloXExp, self).__init__() + self.depth = 1.0 + self.width = 1.0 + self.num_classes = 8 + self.test_size = (600, 960) + self.test_conf = 0.3 + self.nmsthre = 0.65 + + def get_model(self): + from ...models import StreamYOLO, DFPPAFPN, TALHead + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if getattr(self, 'model', None) is None: + in_channels = [256, 512, 1024] + backbone = DFPPAFPN( + self.depth, self.width, in_channels=in_channels) + head = TALHead( + self.num_classes, + self.width, + in_channels=in_channels, + gamma=1.0, + ignore_thr=0.5, + ignore_value=1.6) + self.model = StreamYOLO(backbone, head) + + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py new file mode 100644 index 00000000..7bada485 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_nano.py @@ -0,0 +1,47 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +import torch.nn as nn + +from ..yolox_base import Exp as YoloXExp + + +class YoloXNanoExp(YoloXExp): + + def __init__(self): + super(YoloXNanoExp, self).__init__() + self.depth = 0.33 + self.width = 0.25 + self.input_size = (416, 416) + self.test_size = (416, 416) + + def get_model(self, sublinear=False): + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if 'model' not in self.__dict__: + from ...models import YOLOX, YOLOPAFPN, YOLOXHead + in_channels = [256, 512, 1024] + # NANO model use depthwise = True, which is main difference. + backbone = YOLOPAFPN( + self.depth, + self.width, + in_channels=in_channels, + act=self.act, + depthwise=True, + ) + head = YOLOXHead( + self.num_classes, + self.width, + in_channels=in_channels, + act=self.act, + depthwise=True) + self.model = YOLOX(backbone, head) + self.model.apply(init_yolo) + self.model.head.initialize_biases(1e-2) + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py new file mode 100644 index 00000000..5a123b37 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_s.py @@ -0,0 +1,13 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +from ..yolox_base import Exp as YoloXExp + + +class YoloXSExp(YoloXExp): + + def __init__(self): + super(YoloXSExp, self).__init__() + self.depth = 0.33 + self.width = 0.50 diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py new file mode 100644 index 00000000..a80d0f2d --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/default/yolox_tiny.py @@ -0,0 +1,20 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import os + +from ..yolox_base import Exp as YoloXExp + + +class YoloXTinyExp(YoloXExp): + + def __init__(self): + super(YoloXTinyExp, self).__init__() + self.depth = 0.33 + self.width = 0.375 + self.input_size = (416, 416) + self.mosaic_scale = (0.5, 1.5) + self.random_size = (10, 20) + self.test_size = (416, 416) + self.exp_name = os.path.split( + os.path.realpath(__file__))[1].split('.')[0] + self.enable_mixup = False diff --git a/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py new file mode 100644 index 00000000..c5159a9f --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/exp/yolox_base.py @@ -0,0 +1,58 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX +import os +import random + +import torch +import torch.distributed as dist +import torch.nn as nn + +from .base_exp import BaseExp + + +class Exp(BaseExp): + + def __init__(self): + super().__init__() + + # ---------------- model config ---------------- # + # detect classes number of model + self.num_classes = 80 + # factor of model depth + self.depth = 1.00 + # factor of model width + self.width = 1.00 + # activation name. For example, if using "relu", then "silu" will be replaced to "relu". + self.act = 'silu' + # ----------------- testing config ------------------ # + # output image size during evaluation/test + self.test_size = (640, 640) + # confidence threshold during evaluation/test, + # boxes whose scores are less than test_conf will be filtered + self.test_conf = 0.01 + # nms threshold + self.nmsthre = 0.65 + + def get_model(self): + from ..models import YOLOX, YOLOPAFPN, YOLOXHead + + def init_yolo(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + if getattr(self, 'model', None) is None: + in_channels = [256, 512, 1024] + backbone = YOLOPAFPN( + self.depth, self.width, in_channels=in_channels, act=self.act) + head = YOLOXHead( + self.num_classes, + self.width, + in_channels=in_channels, + act=self.act) + self.model = YOLOX(backbone, head) + + self.model.apply(init_yolo) + self.model.head.initialize_biases(1e-2) + self.model.train() + return self.model diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py new file mode 100644 index 00000000..d2e889f1 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/__init__.py @@ -0,0 +1,10 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .darknet import CSPDarknet, Darknet +from .dfp_pafpn import DFPPAFPN +from .streamyolo import StreamYOLO +from .tal_head import TALHead +from .yolo_fpn import YOLOFPN +from .yolo_head import YOLOXHead +from .yolo_pafpn import YOLOPAFPN +from .yolox import YOLOX diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py b/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py new file mode 100644 index 00000000..8ece2a1e --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/darknet.py @@ -0,0 +1,189 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from torch import nn + +from .network_blocks import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, + SPPBottleneck) + + +class Darknet(nn.Module): + # number of blocks from dark2 to dark5. + depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]} + + def __init__( + self, + depth, + in_channels=3, + stem_out_channels=32, + out_features=('dark3', 'dark4', 'dark5'), + ): + """ + Args: + depth (int): depth of darknet used in model, usually use [21, 53] for this param. + in_channels (int): number of input channels, for example, use 3 for RGB image. + stem_out_channels (int): number of output channels of darknet stem. + It decides channels of darknet layer2 to layer5. + out_features (Tuple[str]): desired output layer name. + """ + super().__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + self.stem = nn.Sequential( + BaseConv( + in_channels, stem_out_channels, ksize=3, stride=1, + act='lrelu'), + *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2), + ) + in_channels = stem_out_channels * 2 # 64 + + num_blocks = Darknet.depth2blocks[depth] + # create darknet with `stem_out_channels` and `num_blocks` layers. + # to make model structure more clear, we don't use `for` statement in python. + self.dark2 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[0], stride=2)) + in_channels *= 2 # 128 + self.dark3 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[1], stride=2)) + in_channels *= 2 # 256 + self.dark4 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[2], stride=2)) + in_channels *= 2 # 512 + + self.dark5 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[3], stride=2), + *self.make_spp_block([in_channels, in_channels * 2], + in_channels * 2), + ) + + def make_group_layer(self, + in_channels: int, + num_blocks: int, + stride: int = 1): + 'starts with conv layer then has `num_blocks` `ResLayer`' + return [ + BaseConv( + in_channels, + in_channels * 2, + ksize=3, + stride=stride, + act='lrelu'), + *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)], + ] + + def make_spp_block(self, filters_list, in_filters): + m = nn.Sequential(*[ + BaseConv(in_filters, filters_list[0], 1, stride=1, act='lrelu'), + BaseConv( + filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), + SPPBottleneck( + in_channels=filters_list[1], + out_channels=filters_list[0], + activation='lrelu', + ), + BaseConv( + filters_list[0], filters_list[1], 3, stride=1, act='lrelu'), + BaseConv( + filters_list[1], filters_list[0], 1, stride=1, act='lrelu'), + ]) + return m + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + return {k: v for k, v in outputs.items() if k in self.out_features} + + +class CSPDarknet(nn.Module): + + def __init__( + self, + dep_mul, + wid_mul, + out_features=('dark3', 'dark4', 'dark5'), + depthwise=False, + act='silu', + ): + super().__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + # stem + self.stem = Focus(3, base_channels, ksize=3, act=act) + + # dark2 + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer( + base_channels * 2, + base_channels * 2, + n=base_depth, + depthwise=depthwise, + act=act, + ), + ) + + # dark3 + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer( + base_channels * 4, + base_channels * 4, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark4 + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer( + base_channels * 8, + base_channels * 8, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark5 + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck( + base_channels * 16, base_channels * 16, activation=act), + CSPLayer( + base_channels * 16, + base_channels * 16, + n=base_depth, + shortcut=False, + depthwise=depthwise, + act=act, + ), + ) + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + return {k: v for k, v in outputs.items() if k in self.out_features} diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py new file mode 100644 index 00000000..01284791 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/dfp_pafpn.py @@ -0,0 +1,307 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .darknet import CSPDarknet +from .network_blocks import BaseConv, CSPLayer, DWConv + + +class DFPPAFPN(nn.Module): + """ + YOLOv3 model. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + ): + super().__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), + int(in_channels[1] * width), + 1, + 1, + act=act) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), + int(in_channels[0] * width), + 1, + 1, + act=act) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + self.jian2 = Conv( + in_channels=int(in_channels[0] * width), + out_channels=int(in_channels[0] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + self.jian1 = Conv( + in_channels=int(in_channels[1] * width), + out_channels=int(in_channels[1] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + self.jian0 = Conv( + in_channels=int(in_channels[2] * width), + out_channels=int(in_channels[2] * width) // 2, + ksize=1, + stride=1, + act=act, + ) + + def off_forward(self, input): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. + """ + + # backbone + rurrent_out_features = self.backbone(torch.split(input, 3, dim=1)[0]) + rurrent_features = [rurrent_out_features[f] for f in self.in_features] + [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features + + rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0) # 1024->512/32 + rurrent_f_out0 = F.interpolate( + rurrent_fpn_out0, size=rurrent_x1.shape[2:4], + mode='nearest') # 512/16 + rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], + 1) # 512->1024/16 + rurrent_f_out0 = self.C3_p4(rurrent_f_out0) # 1024->512/16 + + rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0) # 512->256/16 + rurrent_f_out1 = F.interpolate( + rurrent_fpn_out1, size=rurrent_x2.shape[2:4], + mode='nearest') # 256/8 + rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], + 1) # 256->512/8 + rurrent_pan_out2 = self.C3_p3(rurrent_f_out1) # 512->256/8 + + rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2) # 256->256/16 + rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], + 1) # 256->512/16 + rurrent_pan_out1 = self.C3_n3(rurrent_p_out1) # 512->512/16 + + rurrent_p_out0 = self.bu_conv1(rurrent_pan_out1) # 512->512/32 + rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], + 1) # 512->1024/32 + rurrent_pan_out0 = self.C3_n4(rurrent_p_out0) # 1024->1024/32 + + ##### + + support_out_features = self.backbone(torch.split(input, 3, dim=1)[1]) + support_features = [support_out_features[f] for f in self.in_features] + [support_x2, support_x1, support_x0] = support_features + + support_fpn_out0 = self.lateral_conv0(support_x0) # 1024->512/32 + support_f_out0 = F.interpolate( + support_fpn_out0, size=support_x1.shape[2:4], + mode='nearest') # 512/16 + support_f_out0 = torch.cat([support_f_out0, support_x1], + 1) # 512->1024/16 + support_f_out0 = self.C3_p4(support_f_out0) # 1024->512/16 + + support_fpn_out1 = self.reduce_conv1(support_f_out0) # 512->256/16 + support_f_out1 = F.interpolate( + support_fpn_out1, size=support_x2.shape[2:4], + mode='nearest') # 256/8 + support_f_out1 = torch.cat([support_f_out1, support_x2], + 1) # 256->512/8 + support_pan_out2 = self.C3_p3(support_f_out1) # 512->256/8 + + support_p_out1 = self.bu_conv2(support_pan_out2) # 256->256/16 + support_p_out1 = torch.cat([support_p_out1, support_fpn_out1], + 1) # 256->512/16 + support_pan_out1 = self.C3_n3(support_p_out1) # 512->512/16 + + support_p_out0 = self.bu_conv1(support_pan_out1) # 512->512/32 + support_p_out0 = torch.cat([support_p_out0, support_fpn_out0], + 1) # 512->1024/32 + support_pan_out0 = self.C3_n4(support_p_out0) # 1024->1024/32 + + # 0.5 channel + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(support_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(support_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(support_pan_out0)], + dim=1) + rurrent_pan_out0 + + outputs = (pan_out2, pan_out1, pan_out0) + + return outputs + + def online_forward(self, input, buffer=None, node='star'): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. + """ + + # backbone + rurrent_out_features = self.backbone(input) + rurrent_features = [rurrent_out_features[f] for f in self.in_features] + [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features + + rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0) # 1024->512/32 + rurrent_f_out0 = F.interpolate( + rurrent_fpn_out0, size=rurrent_x1.shape[2:4], + mode='nearest') # 512/16 + rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], + 1) # 512->1024/16 + rurrent_f_out0 = self.C3_p4(rurrent_f_out0) # 1024->512/16 + + rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0) # 512->256/16 + rurrent_f_out1 = F.interpolate( + rurrent_fpn_out1, size=rurrent_x2.shape[2:4], + mode='nearest') # 256/8 + rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], + 1) # 256->512/8 + rurrent_pan_out2 = self.C3_p3(rurrent_f_out1) # 512->256/8 + + rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2) # 256->256/16 + rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], + 1) # 256->512/16 + rurrent_pan_out1 = self.C3_n3(rurrent_p_out1) # 512->512/16 + + rurrent_p_out0 = self.bu_conv1(rurrent_pan_out1) # 512->512/32 + rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], + 1) # 512->1024/32 + rurrent_pan_out0 = self.C3_n4(rurrent_p_out0) # 1024->1024/32 + + ##### + if node == 'star': + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(rurrent_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(rurrent_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(rurrent_pan_out0)], + dim=1) + rurrent_pan_out0 + elif node == 'buffer': + + [support_pan_out2, support_pan_out1, support_pan_out0] = buffer + + pan_out2 = torch.cat( + [self.jian2(rurrent_pan_out2), + self.jian2(support_pan_out2)], + dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat( + [self.jian1(rurrent_pan_out1), + self.jian1(support_pan_out1)], + dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat( + [self.jian0(rurrent_pan_out0), + self.jian0(support_pan_out0)], + dim=1) + rurrent_pan_out0 + + outputs = (pan_out2, pan_out1, pan_out0) + + buffer_ = (rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0) + + return outputs, buffer_ + + def forward(self, input, buffer=None, mode='off_pipe'): + + if mode == 'off_pipe': + # Glops caculate mode + if input.size()[1] == 3: + input = torch.cat([input, input], dim=1) + output = self.off_forward(input) + # offline train mode + elif input.size()[1] == 6: + output = self.off_forward(input) + + return output + + elif mode == 'on_pipe': + # online star state + if buffer is None: + output, buffer_ = self.online_forward(input, node='star') + # online inference + else: + assert len(buffer) == 3 + assert input.size()[1] == 3 + output, buffer_ = self.online_forward( + input, buffer=buffer, node='buffer') + + return output, buffer_ diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py new file mode 100644 index 00000000..88bd55c7 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/network_blocks.py @@ -0,0 +1,212 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX +import torch +import torch.nn as nn + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +class BaseConv(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__(self, + in_channels, + out_channels, + ksize, + stride, + groups=1, + bias=False, + act='silu'): + super(BaseConv, self).__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class DWConv(nn.Module): + """Depthwise Conv + Conv""" + + def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): + super(DWConv, self).__init__() + self.dconv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + act=act, + ) + self.pconv = BaseConv( + in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + 'Residual layer with `in_channels` inputs.' + + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv( + in_channels, mid_channels, ksize=1, stride=1, act='lrelu') + self.layer2 = BaseConv( + mid_channels, in_channels, ksize=3, stride=1, act='lrelu') + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + activation='silu'): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv( + conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv3 = BaseConv( + 2 * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck( + hidden_channels, + hidden_channels, + shortcut, + 1.0, + depthwise, + act=act) for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, + in_channels, + out_channels, + ksize=1, + stride=1, + act='silu'): + super().__init__() + self.conv = BaseConv( + in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py b/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py new file mode 100644 index 00000000..b3ec3504 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/streamyolo.py @@ -0,0 +1,41 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch.nn as nn + +from .dfp_pafpn import DFPPAFPN +from .tal_head import TALHead + + +class StreamYOLO(nn.Module): + """ + YOLOX model module. The module list is defined by create_yolov3_modules function. + The network returns loss values from three YOLO layers during training + and detection results during test. + """ + + def __init__(self, backbone=None, head=None): + super().__init__() + if backbone is None: + backbone = DFPPAFPN() + if head is None: + head = TALHead(20) + + self.backbone = backbone + self.head = head + + def forward(self, x, targets=None, buffer=None, mode='off_pipe'): + # fpn output content features of [dark3, dark4, dark5] + assert mode in ['off_pipe', 'on_pipe'] + + if mode == 'off_pipe': + fpn_outs = self.backbone(x, buffer=buffer, mode='off_pipe') + if self.training: + pass + else: + outputs = self.head(fpn_outs, imgs=x) + + return outputs + elif mode == 'on_pipe': + fpn_outs, buffer_ = self.backbone(x, buffer=buffer, mode='on_pipe') + outputs = self.head(fpn_outs) + + return outputs, buffer_ diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py b/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py new file mode 100644 index 00000000..7a82f8c6 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/tal_head.py @@ -0,0 +1,170 @@ +# The implementation is based on StreamYOLO, available at https://github.com/yancie-yjr/StreamYOLO +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_blocks import BaseConv, DWConv + + +class TALHead(nn.Module): + + def __init__( + self, + num_classes, + width=1.0, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + act='silu', + depthwise=False, + gamma=1.5, + ignore_thr=0.2, + ignore_value=0.2, + ): + """ + Args: + act (str): activation type of conv. Defalut value: "silu". + depthwise (bool): wheather apply depthwise conv in conv branch. Defalut value: False. + """ + super().__init__() + + self.gamma = gamma + self.ignore_thr = ignore_thr + self.ignore_value = ignore_value + + self.n_anchors = 1 + self.num_classes = num_classes + self.decode_in_inference = True # for deploy, set to False + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + Conv = DWConv if depthwise else BaseConv + + for i in range(len(in_channels)): + self.stems.append( + BaseConv( + in_channels=int(in_channels[i] * width), + out_channels=int(256 * width), + ksize=1, + stride=1, + act=act, + )) + self.cls_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.reg_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.cls_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * self.num_classes, + kernel_size=1, + stride=1, + padding=0, + )) + self.reg_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + )) + self.obj_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * 1, + kernel_size=1, + stride=1, + padding=0, + )) + + self.strides = strides + self.grids = [torch.zeros(1)] * len(in_channels) + self.expanded_strides = [None] * len(in_channels) + + def forward(self, xin, labels=None, imgs=None): + outputs = [] + for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( + zip(self.cls_convs, self.reg_convs, self.strides, xin)): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + cls_output = self.cls_preds[k](cls_feat) + + reg_feat = reg_conv(reg_x) + reg_output = self.reg_preds[k](reg_feat) + obj_output = self.obj_preds[k](reg_feat) + + if self.training: + pass + + else: + output = torch.cat( + [reg_output, + obj_output.sigmoid(), + cls_output.sigmoid()], 1) + + outputs.append(output) + + if self.training: + pass + else: + self.hw = [x.shape[-2:] for x in outputs] + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], + dim=2).permute(0, 2, 1) + if self.decode_in_inference: + return self.decode_outputs(outputs, dtype=xin[0].type()) + else: + return outputs + + def decode_outputs(self, outputs, dtype): + grids = [] + strides = [] + for (hsize, wsize), stride in zip(self.hw, self.strides): + yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + strides.append(torch.full((*shape, 1), stride)) + + grids = torch.cat(grids, dim=1).type(dtype) + strides = torch.cat(strides, dim=1).type(dtype) + + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py new file mode 100644 index 00000000..0cbebb09 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_fpn.py @@ -0,0 +1,80 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torch.nn as nn + +from .darknet import Darknet +from .network_blocks import BaseConv + + +class YOLOFPN(nn.Module): + """ + YOLOFPN module. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=53, + in_features=['dark3', 'dark4', 'dark5'], + ): + super(YOLOFPN, self).__init__() + + self.backbone = Darknet(depth) + self.in_features = in_features + + # out 1 + self.out1_cbl = self._make_cbl(512, 256, 1) + self.out1 = self._make_embedding([256, 512], 512 + 256) + + # out 2 + self.out2_cbl = self._make_cbl(256, 128, 1) + self.out2 = self._make_embedding([128, 256], 256 + 128) + + # upsample + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + def _make_cbl(self, _in, _out, ks): + return BaseConv(_in, _out, ks, stride=1, act='lrelu') + + def _make_embedding(self, filters_list, in_filters): + m = nn.Sequential(*[ + self._make_cbl(in_filters, filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + self._make_cbl(filters_list[0], filters_list[1], 3), + self._make_cbl(filters_list[1], filters_list[0], 1), + ]) + return m + + def load_pretrained_model(self, filename='./weights/darknet53.mix.pth'): + with open(filename, 'rb') as f: + state_dict = torch.load(f, map_location='cpu') + print('loading pretrained weights...') + self.backbone.load_state_dict(state_dict) + + def forward(self, inputs): + """ + Args: + inputs (Tensor): input image. + + Returns: + Tuple[Tensor]: FPN output features.. + """ + # backbone + out_features = self.backbone(inputs) + x2, x1, x0 = [out_features[f] for f in self.in_features] + + # yolo branch 1 + x1_in = self.out1_cbl(x0) + x1_in = self.upsample(x1_in) + x1_in = torch.cat([x1_in, x1], 1) + out_dark4 = self.out1(x1_in) + + # yolo branch 2 + x2_in = self.out2_cbl(out_dark4) + x2_in = self.upsample(x2_in) + x2_in = torch.cat([x2_in, x2], 1) + out_dark3 = self.out2(x2_in) + + outputs = (out_dark3, out_dark4, x0) + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py new file mode 100644 index 00000000..1eef93a4 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_head.py @@ -0,0 +1,182 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import bboxes_iou, meshgrid +from .network_blocks import BaseConv, DWConv + + +class YOLOXHead(nn.Module): + + def __init__( + self, + num_classes, + width=1.0, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + act='silu', + depthwise=False, + ): + """ + Args: + act (str): activation type of conv. Defalut value: "silu". + depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False. + """ + super(YOLOXHead, self).__init__() + + self.n_anchors = 1 + self.num_classes = num_classes + self.decode_in_inference = True # for deploy, set to False + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + Conv = DWConv if depthwise else BaseConv + + for i in range(len(in_channels)): + self.stems.append( + BaseConv( + in_channels=int(in_channels[i] * width), + out_channels=int(256 * width), + ksize=1, + stride=1, + act=act, + )) + self.cls_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.reg_convs.append( + nn.Sequential(*[ + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + Conv( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ])) + self.cls_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * self.num_classes, + kernel_size=1, + stride=1, + padding=0, + )) + self.reg_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + )) + self.obj_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * 1, + kernel_size=1, + stride=1, + padding=0, + )) + + self.use_l1 = False + self.l1_loss = nn.L1Loss(reduction='none') + self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction='none') + # self.iou_loss = IOUloss(reduction="none") + self.strides = strides + self.grids = [torch.zeros(1)] * len(in_channels) + + def initialize_biases(self, prior_prob): + for conv in self.cls_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + for conv in self.obj_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def forward(self, xin, labels=None, imgs=None): + outputs = [] + + for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( + zip(self.cls_convs, self.reg_convs, self.strides, xin)): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + cls_output = self.cls_preds[k](cls_feat) + + reg_feat = reg_conv(reg_x) + reg_output = self.reg_preds[k](reg_feat) + obj_output = self.obj_preds[k](reg_feat) + + if self.training: + pass + else: + output = torch.cat( + [reg_output, + obj_output.sigmoid(), + cls_output.sigmoid()], 1) + + outputs.append(output) + + if self.training: + pass + else: + self.hw = [x.shape[-2:] for x in outputs] + # [batch, n_anchors_all, 85] + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], + dim=2).permute(0, 2, 1) + if self.decode_in_inference: + return self.decode_outputs(outputs, dtype=xin[0].type()) + else: + return outputs + + def decode_outputs(self, outputs, dtype): + grids = [] + strides = [] + for (hsize, wsize), stride in zip(self.hw, self.strides): + yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + strides.append(torch.full((*shape, 1), stride)) + + grids = torch.cat(grids, dim=1).type(dtype) + strides = torch.cat(strides, dim=1).type(dtype) + + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py new file mode 100644 index 00000000..cd4258bf --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolo_pafpn.py @@ -0,0 +1,126 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torch.nn as nn + +from .darknet import CSPDarknet +from .network_blocks import BaseConv, CSPLayer, DWConv + + +class YOLOPAFPN(nn.Module): + """ + YOLOv3 model. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + ): + super(YOLOPAFPN, self).__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), + int(in_channels[1] * width), + 1, + 1, + act=act) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), + int(in_channels[0] * width), + 1, + 1, + act=act) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + def forward(self, input): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. + """ + + # backbone + out_features = self.backbone(input) + features = [out_features[f] for f in self.in_features] + [x2, x1, x0] = features + + fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 + f_out0 = self.upsample(fpn_out0) # 512/16 + f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 + f_out0 = self.C3_p4(f_out0) # 1024->512/16 + + fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 + f_out1 = self.upsample(fpn_out1) # 256/8 + f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 + pan_out2 = self.C3_p3(f_out1) # 512->256/8 + + p_out1 = self.bu_conv2(pan_out2) # 256->256/16 + p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 + pan_out1 = self.C3_n3(p_out1) # 512->512/16 + + p_out0 = self.bu_conv1(pan_out1) # 512->512/32 + p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 + pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 + + outputs = (pan_out2, pan_out1, pan_out0) + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py b/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py new file mode 100644 index 00000000..181c368b --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/models/yolox.py @@ -0,0 +1,33 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch.nn as nn + +from .yolo_head import YOLOXHead +from .yolo_pafpn import YOLOPAFPN + + +class YOLOX(nn.Module): + """ + YOLOX model module. The module list is defined by create_yolov3_modules function. + The network returns loss values from three YOLO layers during training + and detection results during test. + """ + + def __init__(self, backbone=None, head=None): + super(YOLOX, self).__init__() + if backbone is None: + backbone = YOLOPAFPN() + if head is None: + head = YOLOXHead(80) + + self.backbone = backbone + self.head = head + + def forward(self, x, targets=None): + fpn_outs = self.backbone(x) + if self.training: + raise NotImplementedError('Training is not supported yet!') + else: + outputs = self.head(fpn_outs) + + return outputs diff --git a/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py b/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py new file mode 100644 index 00000000..2c1ea489 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/utils/__init__.py @@ -0,0 +1,5 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +from .boxes import * # noqa + +__all__ = ['bboxes_iou', 'meshgrid', 'postprocess', 'xyxy2cxcywh', 'xyxy2xywh'] diff --git a/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py b/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py new file mode 100644 index 00000000..b29a3a04 --- /dev/null +++ b/modelscope/models/cv/realtime_object_detection/yolox/utils/boxes.py @@ -0,0 +1,107 @@ +# The implementation is based on YOLOX, available at https://github.com/Megvii-BaseDetection/YOLOX + +import torch +import torchvision + +_TORCH_VER = [int(x) for x in torch.__version__.split('.')[:2]] + + +def meshgrid(*tensors): + if _TORCH_VER >= [1, 10]: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) + + +def xyxy2xywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + return bboxes + + +def xyxy2cxcywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5 + bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5 + return bboxes + + +def postprocess(prediction, + num_classes, + conf_thre=0.7, + nms_thre=0.45, + class_agnostic=False): + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + for i, image_pred in enumerate(prediction): + + # If none are remaining => process next image + if not image_pred.size(0): + continue + # Get score and class with highest confidence + class_conf, class_pred = torch.max( + image_pred[:, 5:5 + num_classes], 1, keepdim=True) + + conf_mask = image_pred[:, 4] * class_conf.squeeze() + conf_mask = (conf_mask >= conf_thre).squeeze() + # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) + detections = torch.cat( + (image_pred[:, :5], class_conf, class_pred.float()), 1) + detections = detections[conf_mask] + if not detections.size(0): + continue + + if class_agnostic: + nms_out_index = torchvision.ops.nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + nms_thre, + ) + else: + nms_out_index = torchvision.ops.batched_nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + detections[:, 6], + nms_thre, + ) + + detections = detections[nms_out_index] + if output[i] is None: + output[i] = detections + else: + output[i] = torch.cat((output[i], detections)) + + return output + + +def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError + + if xyxy: + tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) + return area_i / (area_a[:, None] + area_b - area_i) diff --git a/modelscope/models/cv/referring_video_object_segmentation/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/__init__.py new file mode 100644 index 00000000..4c97bd7b --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model import ReferringVideoObjectSegmentation + +else: + _import_structure = { + 'model': ['ReferringVideoObjectSegmentation'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/referring_video_object_segmentation/model.py b/modelscope/models/cv/referring_video_object_segmentation/model.py new file mode 100644 index 00000000..91f7ea91 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/model.py @@ -0,0 +1,142 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .utils import (MTTR, A2DSentencesPostProcess, HungarianMatcher, + ReferYoutubeVOSPostProcess, SetCriterion, + flatten_temporal_batch_dims, + nested_tensor_from_videos_list) + +logger = get_logger() + + +@MODELS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentation(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + self.model = MTTR(**self.cfg.model) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + params_dict = torch.load(model_path, map_location='cpu') + if 'model_state_dict' in params_dict.keys(): + params_dict = params_dict['model_state_dict'] + self.model.load_state_dict(params_dict, strict=True) + + self.set_postprocessor(self.cfg.pipeline.dataset_name) + self.set_criterion() + + def set_device(self, device, name): + self.device = device + self._device_name = name + + def set_postprocessor(self, dataset_name): + if 'a2d_sentences' in dataset_name or 'jhmdb_sentences' in dataset_name: + self.postprocessor = A2DSentencesPostProcess() # fine-tune + elif 'ref_youtube_vos' in dataset_name: + self.postprocessor = ReferYoutubeVOSPostProcess() # inference + else: + assert False, f'postprocessing for dataset: {dataset_name} is not supported' + + def forward(self, inputs: Dict[str, Any]): + samples = inputs['samples'] + targets = inputs['targets'] + text_queries = inputs['text_queries'] + + valid_indices = torch.tensor( + [i for i, t in enumerate(targets) if None not in t]) + targets = [targets[i] for i in valid_indices.tolist()] + if self._device_name == 'gpu': + samples = samples.to(self.device) + valid_indices = valid_indices.to(self.device) + if isinstance(text_queries, tuple): + text_queries = list(text_queries) + + outputs = self.model(samples, valid_indices, text_queries) + losses = -1 + if self.training: + loss_dict = self.criterion(outputs, targets) + weight_dict = self.criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] + for k in loss_dict.keys() if k in weight_dict) + + predictions = [] + if not self.training: + outputs.pop('aux_outputs', None) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + processed_outputs = self.postprocessor( + outputs, + resized_padded_sample_size=samples.tensors.shape[-2:], + resized_sample_sizes=[t['size'] for t in targets], + orig_sample_sizes=[t['orig_size'] for t in targets]) + image_ids = [t['image_id'] for t in targets] + predictions = [] + for p, image_id in zip(processed_outputs, image_ids): + for s, m in zip(p['scores'], p['rle_masks']): + predictions.append({ + 'image_id': image_id, + 'category_id': + 1, # dummy label, as categories are not predicted in ref-vos + 'segmentation': m, + 'score': s.item() + }) + + re = dict(pred=predictions, loss=losses) + return re + + def inference(self, **kwargs): + window = kwargs['window'] + text_query = kwargs['text_query'] + video_metadata = kwargs['metadata'] + + window = nested_tensor_from_videos_list([window]) + valid_indices = torch.arange(len(window.tensors)) + if self._device_name == 'gpu': + valid_indices = valid_indices.cuda() + outputs = self.model(window, valid_indices, [text_query]) + window_masks = self.postprocessor( + outputs, [video_metadata], + window.tensors.shape[-2:])[0]['pred_masks'] + return window_masks + + def postprocess(self, inputs: Dict[str, Any], **kwargs): + return inputs + + def set_criterion(self): + matcher = HungarianMatcher( + cost_is_referred=self.cfg.matcher.set_cost_is_referred, + cost_dice=self.cfg.matcher.set_cost_dice) + weight_dict = { + 'loss_is_referred': self.cfg.loss.is_referred_loss_coef, + 'loss_dice': self.cfg.loss.dice_loss_coef, + 'loss_sigmoid_focal': self.cfg.loss.sigmoid_focal_loss_coef + } + + if self.cfg.loss.aux_loss: + aux_weight_dict = {} + for i in range(self.cfg.model.num_decoder_layers - 1): + aux_weight_dict.update( + {k + f'_{i}': v + for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + self.criterion = SetCriterion( + matcher=matcher, + weight_dict=weight_dict, + eos_coef=self.cfg.loss.eos_coef) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py new file mode 100644 index 00000000..fbb75b00 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .criterion import SetCriterion, flatten_temporal_batch_dims +from .matcher import HungarianMatcher +from .misc import interpolate, nested_tensor_from_videos_list +from .mttr import MTTR +from .postprocessing import A2DSentencesPostProcess, ReferYoutubeVOSPostProcess diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py new file mode 100644 index 00000000..afa384c1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/backbone.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +import torchvision +from einops import rearrange +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter + +from .misc import NestedTensor, is_main_process +from .swin_transformer import SwinTransformer3D + + +class VideoSwinTransformerBackbone(nn.Module): + """ + A wrapper which allows using Video-Swin Transformer as a temporal encoder for MTTR. + Check out video-swin's original paper at: https://arxiv.org/abs/2106.13230 for more info about this architecture. + Only the 'tiny' version of video swin was tested and is currently supported in our project. + Additionally, we slightly modify video-swin to make it output per-frame embeddings as required by MTTR (check our + paper's supplementary for more details), and completely discard of its 4th block. + """ + + def __init__(self, backbone_pretrained, backbone_pretrained_path, + train_backbone, running_mode, **kwargs): + super(VideoSwinTransformerBackbone, self).__init__() + # patch_size is (1, 4, 4) instead of the original (2, 4, 4). + # this prevents swinT's original temporal downsampling so we can get per-frame features. + swin_backbone = SwinTransformer3D( + patch_size=(1, 4, 4), + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=(8, 7, 7), + drop_path_rate=0.1, + patch_norm=True) + if backbone_pretrained and running_mode == 'train': + state_dict = torch.load(backbone_pretrained_path)['state_dict'] + # extract swinT's kinetics-400 pretrained weights and ignore the rest (prediction head etc.) + state_dict = { + k[9:]: v + for k, v in state_dict.items() if 'backbone.' in k + } + + # sum over the patch embedding weight temporal dim [96, 3, 2, 4, 4] --> [96, 3, 1, 4, 4] + patch_embed_weight = state_dict['patch_embed.proj.weight'] + patch_embed_weight = patch_embed_weight.sum(dim=2, keepdims=True) + state_dict['patch_embed.proj.weight'] = patch_embed_weight + swin_backbone.load_state_dict(state_dict) + + self.patch_embed = swin_backbone.patch_embed + self.pos_drop = swin_backbone.pos_drop + self.layers = swin_backbone.layers[:-1] + self.downsamples = nn.ModuleList() + for layer in self.layers: + self.downsamples.append(layer.downsample) + layer.downsample = None + self.downsamples[ + -1] = None # downsampling after the last layer is not necessary + + self.layer_output_channels = [ + swin_backbone.embed_dim * 2**i for i in range(len(self.layers)) + ] + self.train_backbone = train_backbone + if not train_backbone: + for parameter in self.parameters(): + parameter.requires_grad_(False) + + def forward(self, samples: NestedTensor): + vid_frames = rearrange(samples.tensors, 't b c h w -> b c t h w') + + vid_embeds = self.patch_embed(vid_frames) + vid_embeds = self.pos_drop(vid_embeds) + layer_outputs = [] # layer outputs before downsampling + for layer, downsample in zip(self.layers, self.downsamples): + vid_embeds = layer(vid_embeds.contiguous()) + layer_outputs.append(vid_embeds) + if downsample: + vid_embeds = rearrange(vid_embeds, 'b c t h w -> b t h w c') + vid_embeds = downsample(vid_embeds) + vid_embeds = rearrange(vid_embeds, 'b t h w c -> b c t h w') + layer_outputs = [ + rearrange(o, 'b c t h w -> t b c h w') for o in layer_outputs + ] + + outputs = [] + orig_pad_mask = samples.mask + for l_out in layer_outputs: + pad_mask = F.interpolate( + orig_pad_mask.float(), size=l_out.shape[-2:]).to(torch.bool) + outputs.append(NestedTensor(l_out, pad_mask)) + return outputs + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + BatchNorm2d where the batch statistics and the affine parameters are fixed. + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer('weight', torch.ones(n)) + self.register_buffer('bias', torch.zeros(n)) + self.register_buffer('running_mean', torch.zeros(n)) + self.register_buffer('running_var', torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class ResNetBackbone(nn.Module): + """ + Modified from DETR https://github.com/facebookresearch/detr + ResNet backbone with frozen BatchNorm. + """ + + def __init__(self, + backbone_name: str = 'resnet50', + train_backbone: bool = True, + dilation: bool = True, + **kwargs): + super(ResNetBackbone, self).__init__() + backbone = getattr(torchvision.models, backbone_name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=FrozenBatchNorm2d) + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + return_layers = { + 'layer1': '0', + 'layer2': '1', + 'layer3': '2', + 'layer4': '3' + } + self.body = IntermediateLayerGetter( + backbone, return_layers=return_layers) + output_channels = 512 if backbone_name in ('resnet18', + 'resnet34') else 2048 + self.layer_output_channels = [ + output_channels // 8, output_channels // 4, output_channels // 2, + output_channels + ] + + def forward(self, tensor_list: NestedTensor): + t, b, _, _, _ = tensor_list.tensors.shape + video_frames = rearrange(tensor_list.tensors, + 't b c h w -> (t b) c h w') + padding_masks = rearrange(tensor_list.mask, 't b h w -> (t b) h w') + features_list = self.body(video_frames) + out = [] + for _, f in features_list.items(): + resized_padding_masks = F.interpolate( + padding_masks[None].float(), + size=f.shape[-2:]).to(torch.bool)[0] + f = rearrange(f, '(t b) c h w -> t b c h w', t=t, b=b) + resized_padding_masks = rearrange( + resized_padding_masks, '(t b) h w -> t b h w', t=t, b=b) + out.append(NestedTensor(f, resized_padding_masks)) + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def init_backbone(backbone_name, **kwargs): + if backbone_name == 'swin-t': + return VideoSwinTransformerBackbone(**kwargs) + elif 'resnet' in backbone_name: + return ResNetBackbone(backbone_name, **kwargs) + assert False, f'error: backbone "{backbone_name}" is not supported' diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py new file mode 100644 index 00000000..a4d2f0ff --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/criterion.py @@ -0,0 +1,198 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +import torch +from torch import nn + +from .misc import (get_world_size, interpolate, is_dist_avail_and_initialized, + nested_tensor_from_tensor_list) +from .segmentation import dice_loss, sigmoid_focal_loss + + +class SetCriterion(nn.Module): + """ This class computes the loss for MTTR. + The process happens in two steps: + 1) we compute the hungarian assignment between the ground-truth and predicted sequences. + 2) we supervise each pair of matched ground-truth / prediction sequences (mask + reference prediction) + """ + + def __init__(self, matcher, weight_dict, eos_coef): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the un-referred category + """ + super().__init__() + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + # make sure that only loss functions with non-zero weights are computed: + losses_to_compute = [] + if weight_dict['loss_dice'] > 0 or weight_dict[ + 'loss_sigmoid_focal'] > 0: + losses_to_compute.append('masks') + if weight_dict['loss_is_referred'] > 0: + losses_to_compute.append('is_referred') + self.losses = losses_to_compute + + def forward(self, outputs, targets): + aux_outputs_list = outputs.pop('aux_outputs', None) + # compute the losses for the output of the last decoder layer: + losses = self.compute_criterion( + outputs, targets, losses_to_compute=self.losses) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate decoder layer. + if aux_outputs_list is not None: + aux_losses_to_compute = self.losses.copy() + for i, aux_outputs in enumerate(aux_outputs_list): + losses_dict = self.compute_criterion(aux_outputs, targets, + aux_losses_to_compute) + losses_dict = {k + f'_{i}': v for k, v in losses_dict.items()} + losses.update(losses_dict) + + return losses + + def compute_criterion(self, outputs, targets, losses_to_compute): + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs, targets) + + # T & B dims are flattened so loss functions can be computed per frame (but with same indices per video). + # also, indices are repeated so so the same indices can be used for frames of the same video. + T = len(targets) + outputs, targets = flatten_temporal_batch_dims(outputs, targets) + # repeat the indices list T times so the same indices can be used for each video frame + indices = T * indices + + # Compute the average number of target masks across all nodes, for normalization purposes + num_masks = sum(len(t['masks']) for t in targets) + num_masks = torch.as_tensor([num_masks], + dtype=torch.float, + device=indices[0][0].device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_masks) + num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in losses_to_compute: + losses.update( + self.get_loss( + loss, outputs, targets, indices, num_masks=num_masks)) + return losses + + def loss_is_referred(self, outputs, targets, indices, **kwargs): + device = outputs['pred_is_referred'].device + bs = outputs['pred_is_referred'].shape[0] + pred_is_referred = outputs['pred_is_referred'].log_softmax( + dim=-1) # note that log-softmax is used here + target_is_referred = torch.zeros_like(pred_is_referred) + # extract indices of object queries that where matched with text-referred target objects + query_referred_indices = self._get_query_referred_indices( + indices, targets) + # by default penalize compared to the no-object class (last token) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible_per_frame = torch.stack( + [t['is_ref_inst_visible'] for t in targets]) + ref_inst_visible_frame_indices = is_ref_inst_visible_per_frame.nonzero( + ).squeeze() + # keep only the matched query indices of the frames in which the referred object is visible: + visible_query_referred_indices = query_referred_indices[ + ref_inst_visible_frame_indices] + target_is_referred[ref_inst_visible_frame_indices, + visible_query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[torch.arange(bs), + query_referred_indices] = torch.tensor( + [1.0, 0.0], device=device) + loss = -(pred_is_referred * target_is_referred).sum(-1) + # apply no-object class weights: + eos_coef = torch.full(loss.shape, self.eos_coef, device=loss.device) + eos_coef[torch.arange(bs), query_referred_indices] = 1.0 + loss = loss * eos_coef + bs = len(indices) + loss = loss.sum() / bs # sum and normalize the loss by the batch size + losses = {'loss_is_referred': loss} + return losses + + def loss_masks(self, outputs, targets, indices, num_masks, **kwargs): + assert 'pred_masks' in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs['pred_masks'] + src_masks = src_masks[src_idx] + masks = [t['masks'] for t in targets] + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate( + src_masks[:, None], + size=target_masks.shape[-2:], + mode='bilinear', + align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + 'loss_sigmoid_focal': + sigmoid_focal_loss(src_masks, target_masks, num_masks), + 'loss_dice': + dice_loss(src_masks, target_masks, num_masks), + } + return losses + + @staticmethod + def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + @staticmethod + def _get_tgt_permutation_idx(indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + @staticmethod + def _get_query_referred_indices(indices, targets): + """ + extract indices of object queries that where matched with text-referred target objects + """ + query_referred_indices = [] + for (query_idxs, target_idxs), target in zip(indices, targets): + ref_query_idx = query_idxs[torch.where( + target_idxs == target['referred_instance_idx'])[0]] + query_referred_indices.append(ref_query_idx) + query_referred_indices = torch.cat(query_referred_indices) + return query_referred_indices + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + 'masks': self.loss_masks, + 'is_referred': self.loss_is_referred, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + +def flatten_temporal_batch_dims(outputs, targets): + for k in outputs.keys(): + if isinstance(outputs[k], torch.Tensor): + outputs[k] = outputs[k].flatten(0, 1) + else: # list + outputs[k] = [i for step_t in outputs[k] for i in step_t] + targets = [ + frame_t_target for step_t in targets for frame_t_target in step_t + ] + return outputs, targets diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py new file mode 100644 index 00000000..4f9b88e5 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/matcher.py @@ -0,0 +1,163 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# Module to compute the matching cost and solve the corresponding LSAP. + +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from .misc import interpolate, nested_tensor_from_tensor_list + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_is_referred: float = 1, cost_dice: float = 1): + """Creates the matcher + + Params: + cost_is_referred: This is the relative weight of the reference cost in the total matching cost + cost_dice: This is the relative weight of the dice cost in the total matching cost + """ + super().__init__() + self.cost_is_referred = cost_is_referred + self.cost_dice = cost_dice + assert cost_is_referred != 0 or cost_dice != 0, 'all costs cant be 0' + + @torch.inference_mode() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: A dict that contains at least these entries: + "pred_is_referred": Tensor of dim [time, batch_size, num_queries, 2] with the reference logits + "pred_masks": Tensor of dim [time, batch_size, num_queries, H, W] with the predicted masks logits + + targets: A list of lists of targets (outer - time steps, inner - batch samples). each target is a dict + which contain mask and reference ground truth information for a single frame. + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_masks) + """ + t, bs, num_queries = outputs['pred_masks'].shape[:3] + + # We flatten to compute the cost matrices in a batch + out_masks = outputs['pred_masks'].flatten( + 1, 2) # [t, batch_size * num_queries, mask_h, mask_w] + + # preprocess and concat the target masks + tgt_masks = [[ + m for v in t_step_batch for m in v['masks'].unsqueeze(1) + ] for t_step_batch in targets] + # pad the target masks to a uniform shape + tgt_masks, valid = list( + zip(*[ + nested_tensor_from_tensor_list(t).decompose() + for t in tgt_masks + ])) + tgt_masks = torch.stack(tgt_masks).squeeze(2) + + # upsample predicted masks to target mask size + out_masks = interpolate( + out_masks, + size=tgt_masks.shape[-2:], + mode='bilinear', + align_corners=False) + + # Compute the soft-tokens cost: + if self.cost_is_referred > 0: + cost_is_referred = compute_is_referred_cost(outputs, targets) + else: + cost_is_referred = 0 + + # Compute the DICE coefficient between the masks: + if self.cost_dice > 0: + cost_dice = -dice_coef(out_masks, tgt_masks) + else: + cost_dice = 0 + + # Final cost matrix + C = self.cost_is_referred * cost_is_referred + self.cost_dice * cost_dice + C = C.view(bs, num_queries, -1).cpu() + + num_traj_per_batch = [ + len(v['masks']) for v in targets[0] + ] # number of instance trajectories in each batch + indices = [ + linear_sum_assignment(c[i]) + for i, c in enumerate(C.split(num_traj_per_batch, -1)) + ] + device = out_masks.device + return [(torch.as_tensor(i, dtype=torch.int64, device=device), + torch.as_tensor(j, dtype=torch.int64, device=device)) + for i, j in indices] + + +def dice_coef(inputs, targets, smooth=1.0): + """ + Compute the DICE coefficient, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid().flatten(2).unsqueeze(2) + targets = targets.flatten(2).unsqueeze(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + coef = (numerator + smooth) / (denominator + smooth) + coef = coef.mean( + 0) # average on the temporal dim to get instance trajectory scores + return coef + + +def compute_is_referred_cost(outputs, targets): + pred_is_referred = outputs['pred_is_referred'].flatten(1, 2).softmax( + dim=-1) # [t, b*nq, 2] + device = pred_is_referred.device + t = pred_is_referred.shape[0] + # number of instance trajectories in each batch + num_traj_per_batch = torch.tensor([len(v['masks']) for v in targets[0]], + device=device) + total_trajectories = num_traj_per_batch.sum() + # note that ref_indices are shared across time steps: + ref_indices = torch.tensor( + [v['referred_instance_idx'] for v in targets[0]], device=device) + # convert ref_indices to fit flattened batch targets: + ref_indices += torch.cat( + (torch.zeros(1, dtype=torch.long, + device=device), num_traj_per_batch.cumsum(0)[:-1])) + # number of instance trajectories in each batch + target_is_referred = torch.zeros((t, total_trajectories, 2), device=device) + # 'no object' class by default (for un-referred objects) + target_is_referred[:, :, :] = torch.tensor([0.0, 1.0], device=device) + if 'is_ref_inst_visible' in targets[0][ + 0]: # visibility labels are available per-frame for the referred object: + is_ref_inst_visible = torch.stack([ + torch.stack([t['is_ref_inst_visible'] for t in t_step]) + for t_step in targets + ]).permute(1, 0) + for ref_idx, is_visible in zip(ref_indices, is_ref_inst_visible): + is_visible = is_visible.nonzero().squeeze() + target_is_referred[is_visible, + ref_idx, :] = torch.tensor([1.0, 0.0], + device=device) + else: # assume that the referred object is visible in every frame: + target_is_referred[:, ref_indices, :] = torch.tensor([1.0, 0.0], + device=device) + cost_is_referred = -(pred_is_referred.unsqueeze(2) + * target_is_referred.unsqueeze(1)).sum(dim=-1).mean( + dim=0) + return cost_is_referred diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py new file mode 100644 index 00000000..ecf34b8c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/misc.py @@ -0,0 +1,234 @@ +# Modified from DETR https://github.com/facebookresearch/detr +# Misc functions. +# Mostly copy-paste from torchvision references. + +import pickle +from typing import List, Optional + +import torch +import torch.distributed as dist +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +from torch import Tensor + +if float(torchvision.__version__.split('.')[1]) < 7.0: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device='cuda') + size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append( + torch.empty((max_size, ), dtype=torch.uint8, device='cuda')) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size, ), dtype=torch.uint8, device='cuda') + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + """ + This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their + padding masks (true for padding areas, false otherwise). + """ + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img) + m[:img.shape[1], :img.shape[2]] = False + return NestedTensor(tensor, mask) + + +def nested_tensor_from_videos_list(videos_list: List[Tensor]): + """ + This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded + videos (shape [T, B, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape + [T, B, PH, PW]. + """ + max_size = _max_by_axis([list(img.shape) for img in videos_list]) + padded_batch_shape = [len(videos_list)] + max_size + b, t, c, h, w = padded_batch_shape + dtype = videos_list[0].dtype + device = videos_list[0].device + padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device) + videos_pad_masks = torch.ones((b, t, h, w), + dtype=torch.bool, + device=device) + for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, + padded_videos, + videos_pad_masks): + pad_vid_frames[:vid_frames.shape[0], :, :vid_frames. + shape[2], :vid_frames.shape[3]].copy_(vid_frames) + vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames. + shape[3]] = False + # transpose the temporal and batch dims and create a NestedTensor: + return NestedTensor( + padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1)) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def interpolate(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__.split('.')[1]) < 7.0: + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, + mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, + mode, align_corners) diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py new file mode 100644 index 00000000..e603df6c --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/mttr.py @@ -0,0 +1,128 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from .backbone import init_backbone +from .misc import NestedTensor +from .multimodal_transformer import MultimodalTransformer +from .segmentation import FPNSpatialDecoder + + +class MTTR(nn.Module): + """ The main module of the Multimodal Tracking Transformer """ + + def __init__(self, + num_queries, + mask_kernels_dim=8, + aux_loss=False, + **kwargs): + """ + Parameters: + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + MTTR can detect in a single image. In our paper we use 50 in all settings. + mask_kernels_dim: dim of the segmentation kernels and of the feature maps outputted by the spatial decoder. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.backbone = init_backbone(**kwargs) + self.transformer = MultimodalTransformer(**kwargs) + d_model = self.transformer.d_model + self.is_referred_head = nn.Linear( + d_model, + 2) # binary 'is referred?' prediction head for object queries + self.instance_kernels_head = MLP( + d_model, d_model, output_dim=mask_kernels_dim, num_layers=2) + self.obj_queries = nn.Embedding( + num_queries, d_model) # pos embeddings for the object queries + self.vid_embed_proj = nn.Conv2d( + self.backbone.layer_output_channels[-1], d_model, kernel_size=1) + self.spatial_decoder = FPNSpatialDecoder( + d_model, self.backbone.layer_output_channels[:-1][::-1], + mask_kernels_dim) + self.aux_loss = aux_loss + + def forward(self, samples: NestedTensor, valid_indices, text_queries): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: Batched frames of shape [time x batch_size x 3 x H x W] + - samples.mask: A binary mask of shape [time x batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_is_referred": The reference prediction logits for all queries. + Shape: [time x batch_size x num_queries x 2] + - "pred_masks": The mask logits for all queries. + Shape: [time x batch_size x num_queries x H_mask x W_mask] + - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of + dictionaries containing the two above keys for each decoder layer. + """ + backbone_out = self.backbone(samples) + # keep only the valid frames (frames which are annotated): + # (for example, in a2d-sentences only the center frame in each window is annotated). + for layer_out in backbone_out: + layer_out.tensors = layer_out.tensors.index_select( + 0, valid_indices) + layer_out.mask = layer_out.mask.index_select(0, valid_indices) + bbone_final_layer_output = backbone_out[-1] + vid_embeds, vid_pad_mask = bbone_final_layer_output.decompose() + + T, B, _, _, _ = vid_embeds.shape + vid_embeds = rearrange(vid_embeds, 't b c h w -> (t b) c h w') + vid_embeds = self.vid_embed_proj(vid_embeds) + vid_embeds = rearrange( + vid_embeds, '(t b) c h w -> t b c h w', t=T, b=B) + + transformer_out = self.transformer(vid_embeds, vid_pad_mask, + text_queries, + self.obj_queries.weight) + # hs is: [L, T, B, N, D] where L is number of decoder layers + # vid_memory is: [T, B, D, H, W] + # txt_memory is a list of length T*B of [S, C] where S might be different for each sentence + # encoder_middle_layer_outputs is a list of [T, B, H, W, D] + hs, vid_memory, txt_memory = transformer_out + + vid_memory = rearrange(vid_memory, 't b d h w -> (t b) d h w') + bbone_middle_layer_outputs = [ + rearrange(o.tensors, 't b d h w -> (t b) d h w') + for o in backbone_out[:-1][::-1] + ] + decoded_frame_features = self.spatial_decoder( + vid_memory, bbone_middle_layer_outputs) + decoded_frame_features = rearrange( + decoded_frame_features, '(t b) d h w -> t b d h w', t=T, b=B) + instance_kernels = self.instance_kernels_head(hs) # [L, T, B, N, C] + # output masks is: [L, T, B, N, H_mask, W_mask] + output_masks = torch.einsum('ltbnc,tbchw->ltbnhw', instance_kernels, + decoded_frame_features) + outputs_is_referred = self.is_referred_head(hs) # [L, T, B, N, 2] + + layer_outputs = [] + for pm, pir in zip(output_masks, outputs_is_referred): + layer_out = {'pred_masks': pm, 'pred_is_referred': pir} + layer_outputs.append(layer_out) + out = layer_outputs[ + -1] # the output for the last decoder layer is used by default + if self.aux_loss: + out['aux_outputs'] = layer_outputs[:-1] + return out + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py new file mode 100644 index 00000000..39962715 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/multimodal_transformer.py @@ -0,0 +1,440 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# MTTR Multimodal Transformer class. +# Modified from DETR https://github.com/facebookresearch/detr + +import copy +import os +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor, nn +from transformers import RobertaModel, RobertaTokenizerFast + +from .position_encoding_2d import PositionEmbeddingSine2D + +os.environ[ + 'TOKENIZERS_PARALLELISM'] = 'false' # this disables a huggingface tokenizer warning (printed every epoch) + + +class MultimodalTransformer(nn.Module): + + def __init__(self, + num_encoder_layers=3, + num_decoder_layers=3, + text_encoder_type='roberta-base', + freeze_text_encoder=True, + **kwargs): + super().__init__() + self.d_model = kwargs['d_model'] + encoder_layer = TransformerEncoderLayer(**kwargs) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + decoder_layer = TransformerDecoderLayer(**kwargs) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + norm=nn.LayerNorm(self.d_model), + return_intermediate=True) + self.pos_encoder_2d = PositionEmbeddingSine2D() + self._reset_parameters() + + self.text_encoder = RobertaModel.from_pretrained(text_encoder_type) + self.text_encoder.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... + self.tokenizer = RobertaTokenizerFast.from_pretrained( + text_encoder_type) + self.freeze_text_encoder = freeze_text_encoder + if freeze_text_encoder: + for p in self.text_encoder.parameters(): + p.requires_grad_(False) + + self.txt_proj = FeatureResizer( + input_feat_size=self.text_encoder.config.hidden_size, + output_feat_size=self.d_model, + dropout=kwargs['dropout'], + ) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, vid_embeds, vid_pad_mask, text_queries, obj_queries): + device = vid_embeds.device + t, b, _, h, w = vid_embeds.shape + + txt_memory, txt_pad_mask = self.forward_text(text_queries, device) + # add temporal dim to txt memory & padding mask: + txt_memory = repeat(txt_memory, 's b c -> s (t b) c', t=t) + txt_pad_mask = repeat(txt_pad_mask, 'b s -> (t b) s', t=t) + + vid_embeds = rearrange(vid_embeds, 't b c h w -> (h w) (t b) c') + # Concat the image & text embeddings on the sequence dimension + encoder_src_seq = torch.cat((vid_embeds, txt_memory), dim=0) + seq_mask = torch.cat( + (rearrange(vid_pad_mask, 't b h w -> (t b) (h w)'), txt_pad_mask), + dim=1) + # vid_pos_embed is: [T*B, H, W, d_model] + vid_pos_embed = self.pos_encoder_2d( + rearrange(vid_pad_mask, 't b h w -> (t b) h w'), self.d_model) + # use zeros in place of pos embeds for the text sequence: + pos_embed = torch.cat( + (rearrange(vid_pos_embed, 't_b h w c -> (h w) t_b c'), + torch.zeros_like(txt_memory)), + dim=0) + + memory = self.encoder( + encoder_src_seq, src_key_padding_mask=seq_mask, + pos=pos_embed) # [S, T*B, C] + vid_memory = rearrange( + memory[:h * w, :, :], + '(h w) (t b) c -> t b c h w', + h=h, + w=w, + t=t, + b=b) + txt_memory = memory[h * w:, :, :] + txt_memory = rearrange(txt_memory, 's t_b c -> t_b s c') + txt_memory = [ + t_mem[~pad_mask] + for t_mem, pad_mask in zip(txt_memory, txt_pad_mask) + ] # remove padding + + # add T*B dims to query embeds (was: [N, C], where N is the number of object queries): + obj_queries = repeat(obj_queries, 'n c -> n (t b) c', t=t, b=b) + tgt = torch.zeros_like(obj_queries) # [N, T*B, C] + + # hs is [L, N, T*B, C] where L is number of layers in the decoder + hs = self.decoder( + tgt, + memory, + memory_key_padding_mask=seq_mask, + pos=pos_embed, + query_pos=obj_queries) + hs = rearrange(hs, 'l n (t b) c -> l t b n c', t=t, b=b) + return hs, vid_memory, txt_memory + + def forward_text(self, text_queries, device): + tokenized_queries = self.tokenizer.batch_encode_plus( + text_queries, padding='longest', return_tensors='pt') + tokenized_queries = tokenized_queries.to(device) + with torch.inference_mode(mode=self.freeze_text_encoder): + encoded_text = self.text_encoder(**tokenized_queries) + # Transpose memory because pytorch's attention expects sequence first + tmp_last_hidden_state = encoded_text.last_hidden_state.clone() + txt_memory = rearrange(tmp_last_hidden_state, 'b s c -> s b c') + txt_memory = self.txt_proj( + txt_memory) # change text embeddings dim to model dim + # Invert attention mask that we get from huggingface because its the opposite in pytorch transformer + txt_pad_mask = tokenized_queries.attention_mask.ne(1).bool() # [B, S] + return txt_memory, txt_pad_mask + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, + k, + value=src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, + k, + value=src2, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, + d_model, + nheads, + dim_feedforward=2048, + dropout=0.1, + activation='relu', + normalize_before=False, + **kwargs): + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention( + d_model, nheads, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, + k, + value=tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, + pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class FeatureResizer(nn.Module): + """ + This class takes as input a set of embeddings of dimension C1 and outputs a set of + embedding of dimension C2, after a linear transformation, dropout and normalization (LN). + """ + + def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): + super().__init__() + self.do_ln = do_ln + # Object feature encoding + self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) + self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, encoder_features): + x = self.fc(encoder_features) + if self.do_ln: + x = self.layer_norm(x) + output = self.dropout(x) + return output + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(F'activation should be relu/gelu, not {activation}.') diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py new file mode 100644 index 00000000..f9ef05a1 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/position_encoding_2d.py @@ -0,0 +1,57 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr +# 2D sine positional encodings for the visual features in the multimodal transformer. + +import math + +import torch +from torch import Tensor, nn + + +class PositionEmbeddingSine2D(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, temperature=10000, normalize=True, scale=None): + super().__init__() + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask: Tensor, hidden_dim: int): + """ + @param mask: a tensor of shape [B, H, W] + @param hidden_dim: int + @return: + """ + num_pos_feats = hidden_dim // 2 + + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + return pos diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py new file mode 100644 index 00000000..64582140 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/postprocessing.py @@ -0,0 +1,119 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR + +import numpy as np +import pycocotools.mask as mask_util +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class A2DSentencesPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(A2DSentencesPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, resized_padded_sample_size, + resized_sample_sizes, orig_sample_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + resized_padded_sample_size: size of samples (input to model) after size augmentation + padding. + resized_sample_sizes: size of samples after size augmentation but without padding. + orig_sample_sizes: original size of the samples (no augmentations or padding) + """ + pred_is_referred = outputs['pred_is_referred'] + prob = F.softmax(pred_is_referred, dim=-1) + scores = prob[..., 0] + pred_masks = outputs['pred_masks'] + pred_masks = F.interpolate( + pred_masks, + size=resized_padded_sample_size, + mode='bilinear', + align_corners=False) + pred_masks = (pred_masks.sigmoid() > 0.5) + processed_pred_masks, rle_masks = [], [] + for f_pred_masks, resized_size, orig_size in zip( + pred_masks, resized_sample_sizes, orig_sample_sizes): + f_mask_h, f_mask_w = resized_size # resized shape without padding + # remove the samples' padding + f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, : + f_mask_w].unsqueeze(1) + # resize the samples back to their original dataset (target) size for evaluation + f_pred_masks_processed = F.interpolate( + f_pred_masks_no_pad.float(), size=orig_size, mode='nearest') + f_pred_rle_masks = [ + mask_util.encode( + np.array( + mask[0, :, :, np.newaxis], dtype=np.uint8, + order='F'))[0] + for mask in f_pred_masks_processed.cpu() + ] + processed_pred_masks.append(f_pred_masks_processed) + rle_masks.append(f_pred_rle_masks) + predictions = [{ + 'scores': s, + 'masks': m, + 'rle_masks': rle + } for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] + return predictions + + +class ReferYoutubeVOSPostProcess(nn.Module): + """ + This module converts the model's output into the format expected by the coco api for the given task + """ + + def __init__(self): + super(ReferYoutubeVOSPostProcess, self).__init__() + + @torch.inference_mode() + def forward(self, outputs, videos_metadata, samples_shape_with_padding): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + videos_metadata: a dictionary with each video's metadata. + samples_shape_with_padding: size of the batch frames with padding. + """ + pred_is_referred = outputs['pred_is_referred'] + prob_is_referred = F.softmax(pred_is_referred, dim=-1) + # note we average on the temporal dim to compute score per trajectory: + trajectory_scores = prob_is_referred[..., 0].mean(dim=0) + pred_trajectory_indices = torch.argmax(trajectory_scores, dim=-1) + pred_masks = rearrange(outputs['pred_masks'], + 't b nq h w -> b t nq h w') + # keep only the masks of the chosen trajectories: + b = pred_masks.shape[0] + pred_masks = pred_masks[torch.arange(b), :, pred_trajectory_indices] + # resize the predicted masks to the size of the model input (which might include padding) + pred_masks = F.interpolate( + pred_masks, + size=samples_shape_with_padding, + mode='bilinear', + align_corners=False) + # apply a threshold to create binary masks: + pred_masks = (pred_masks.sigmoid() > 0.5) + # remove the padding per video (as videos might have different resolutions and thus different padding): + preds_by_video = [] + for video_pred_masks, video_metadata in zip(pred_masks, + videos_metadata): + # size of the model input batch frames without padding: + resized_h, resized_w = video_metadata['resized_frame_size'] + video_pred_masks = video_pred_masks[:, :resized_h, : + resized_w].unsqueeze( + 1) # remove the padding + # resize the masks back to their original frames dataset size for evaluation: + original_frames_size = video_metadata['original_frame_size'] + tuple_size = tuple(original_frames_size.cpu().numpy()) + video_pred_masks = F.interpolate( + video_pred_masks.float(), size=tuple_size, mode='nearest') + video_pred_masks = video_pred_masks.to(torch.uint8).cpu() + # combine the predicted masks and the video metadata to create a final predictions dict: + video_pred = {**video_metadata, **{'pred_masks': video_pred_masks}} + preds_by_video.append(video_pred) + return preds_by_video diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py new file mode 100644 index 00000000..b3228820 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/segmentation.py @@ -0,0 +1,137 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class FPNSpatialDecoder(nn.Module): + """ + An FPN-like spatial decoder. Generates high-res, semantically rich features which serve as the base for creating + instance segmentation masks. + """ + + def __init__(self, context_dim, fpn_dims, mask_kernels_dim=8): + super().__init__() + + inter_dims = [ + context_dim, context_dim // 2, context_dim // 4, context_dim // 8, + context_dim // 16 + ] + self.lay1 = torch.nn.Conv2d(context_dim, inter_dims[0], 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, inter_dims[0]) + self.lay2 = torch.nn.Conv2d(inter_dims[0], inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.context_dim = context_dim + + self.add_extra_layer = len(fpn_dims) == 3 + if self.add_extra_layer: + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + self.lay5 = torch.nn.Conv2d( + inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d( + inter_dims[4], mask_kernels_dim, 3, padding=1) + else: + self.out_lay = torch.nn.Conv2d( + inter_dims[3], mask_kernels_dim, 3, padding=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: Tensor, layer_features: List[Tensor]): + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(layer_features[0]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(layer_features[1]) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + if self.add_extra_layer: + cur_fpn = self.adapter3(layer_features[2]) + x = cur_fpn + F.interpolate( + x, size=cur_fpn.shape[-2:], mode='nearest') + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + def num_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss(inputs, + targets, + num_masks, + alpha: float = 0.25, + gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits( + inputs, targets, reduction='none') + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t)**gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks diff --git a/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py new file mode 100644 index 00000000..faaf6e10 --- /dev/null +++ b/modelscope/models/cv/referring_video_object_segmentation/utils/swin_transformer.py @@ -0,0 +1,732 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from Video-Swin-Transformer https://github.com/SwinTransformer/Video-Swin-Transformer + +from functools import lru_cache, reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], + window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, + 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], + W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention3D(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + wd, wh, ww = window_size + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * wd - 1) * (2 * wh - 1) * (2 * ww - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, + coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] + - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=(2, 7, 7), + shift_size=(0, 0, 0), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[ + 0], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[1] < self.window_size[ + 1], 'shift_size must in 0-window_size' + assert 0 <= self.shift_size[2] < self.window_size[ + 2], 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition(shifted_x, + window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C, ))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, + Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +# cache each stage results +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], + -shift_size[0]), slice( + -shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], + -shift_size[1]), slice( + -shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], + -shift_size[2]), slice( + -shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, + window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1, 7, 7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) for i in range(depth) + ]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, + self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad( + x, + (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + + +class SwinTransformer3D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4, 4, 4), + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if i_layer < self.num_layers - 1 else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self, logger): + """Inflate the swin2d parameters to swin3d. + + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + + Args: + logger (logging.Logger): The logger used to print + debugging infomation. + """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict[ + 'patch_embed.proj.weight'].unsqueeze(2).repeat( + 1, 1, self.patch_size[0], 1, 1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + wd = self.window_size[0] + if nH1 != nH2: + logger.warning(f'Error in loading {k}, passing') + else: + if L1 != L2: + S1 = int(L1**0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute( + 1, 0).view(1, nH1, S1, S1), + size=(2 * self.window_size[1] - 1, + 2 * self.window_size[2] - 1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( + nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat( + 2 * wd - 1, 1) + + msg = self.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer3D, self).train(mode) + self._freeze_stages() diff --git a/modelscope/models/cv/salient_detection/__init__.py b/modelscope/models/cv/salient_detection/__init__.py new file mode 100644 index 00000000..b3b5b5fa --- /dev/null +++ b/modelscope/models/cv/salient_detection/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .salient_model import SalientDetection + +else: + _import_structure = { + 'salient_model': ['SalientDetection'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/salient_detection/models/__init__.py b/modelscope/models/cv/salient_detection/models/__init__.py new file mode 100644 index 00000000..8ea7a5d3 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/__init__.py @@ -0,0 +1,3 @@ +# The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License +# source code avaiable via https://github.com/xuebinqin/U-2-Net +from .u2net import U2NET diff --git a/modelscope/models/cv/salient_detection/models/u2net.py b/modelscope/models/cv/salient_detection/models/u2net.py new file mode 100644 index 00000000..05dbf7ad --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/u2net.py @@ -0,0 +1,301 @@ +# The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License +# source code avaiable via https://github.com/xuebinqin/U-2-Net +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class REBNCONV(nn.Module): + + def __init__(self, in_ch=3, out_ch=3, dirate=1): + super(REBNCONV, self).__init__() + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + return xout + + +def _upsample_like(src, tar): + """upsample tensor 'src' to have the same spatial size with tensor 'tar'.""" + src = F.upsample(src, size=tar.shape[2:], mode='bilinear') + return src + + +class RSU7(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU7, self).__init__() + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + hx6 = self.rebnconv6(hx) + hx7 = self.rebnconv7(hx6) + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + return hx1d + hxin + + +class RSU6(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + hx5 = self.rebnconv5(hx) + hx6 = self.rebnconv6(hx5) + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + return hx1d + hxin + + +class RSU5(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + hx4 = self.rebnconv4(hx) + hx5 = self.rebnconv5(hx4) + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + return hx1d + hxin + + +class RSU4(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + hx3 = self.rebnconv3(hx) + hx4 = self.rebnconv4(hx3) + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + return hx1d + hxin + + +class RSU4F(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + + hx = x + hxin = self.rebnconvin(hx) + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + hx4 = self.rebnconv4(hx3) + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + return hx1d + hxin + + +class U2NET(nn.Module): + + def __init__(self, in_ch=3, out_ch=1): + super(U2NET, self).__init__() + + # encoder + self.stage1 = RSU7(in_ch, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.stage6 = RSU4F(512, 256, 512) + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1) + + def forward(self, x): + + hx = x + hx1 = self.stage1(hx) + hx = self.pool12(hx1) + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + d1 = self.side1(hx1d) + d2 = self.side2(hx2d) + d2 = _upsample_like(d2, d1) + d3 = self.side3(hx3d) + d3 = _upsample_like(d3, d1) + d4 = self.side4(hx4d) + d4 = _upsample_like(d4, d1) + d5 = self.side5(hx5d) + d5 = _upsample_like(d5, d1) + d6 = self.side6(hx6) + d6 = _upsample_like(d6, d1) + d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) + return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid( + d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid( + d5), torch.sigmoid(d6) diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py new file mode 100644 index 00000000..73c3c3fb --- /dev/null +++ b/modelscope/models/cv/salient_detection/salient_model.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .models import U2NET + + +@MODELS.register_module( + Tasks.semantic_segmentation, module_name=Models.detection) +class SalientDetection(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + self.model = U2NET(3, 1) + checkpoint = torch.load(model_path, map_location='cpu') + self.transform_input = transforms.Compose([ + transforms.Resize((320, 320)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.model.load_state_dict(checkpoint) + self.model.eval() + + def inference(self, data): + """data is tensor 3 * H * W ---> return tensor H * W .""" + data = data.unsqueeze(0) + if next(self.model.parameters()).is_cuda: + data = data.to( + torch.device([next(self.model.parameters()).device][0])) + + with torch.no_grad(): + results = self.model(data) + + if next(self.model.parameters()).is_cuda: + return results[0][0, 0, :, :].cpu() + return results[0][0, 0, :, :] + + def preprocess(self, image): + """image is numpy.""" + data = self.transform_input(Image.fromarray(image)) + return data.float() + + def postprocess(self, inputs): + """resize .""" + data = inputs['data'] + w = inputs['img_w'] + h = inputs['img_h'] + data_norm = (data - torch.min(data)) / ( + torch.max(data) - torch.min(data)) + data_norm_np = (data_norm.numpy() * 255).astype('uint8') + data_norm_rst = cv2.resize(data_norm_np, (w, h)) + + return data_norm_rst diff --git a/modelscope/models/cv/shop_segmentation/__init__.py b/modelscope/models/cv/shop_segmentation/__init__.py new file mode 100644 index 00000000..072628bd --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .shop_seg_base import SHOPSEG + +else: + _import_structure = {'shop_seg_base': ['SHOPSEG']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/shop_segmentation/common.py b/modelscope/models/cv/shop_segmentation/common.py new file mode 100644 index 00000000..8cb940a5 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/common.py @@ -0,0 +1,57 @@ +# Base modules are adapted from https://github.com/open-mmlab/mmcv/, +# originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +# https://github.com/open-mmlab/mmsegmentation/, +# originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +# and adapted from https://github.com/raoyongming/DenseCLIP/, +# originally MIT License, Copyright (c) 2022 Rao, Yongming. + +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > input_w: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/modelscope/models/cv/shop_segmentation/head_fpn.py b/modelscope/models/cv/shop_segmentation/head_fpn.py new file mode 100644 index 00000000..cad389c7 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/head_fpn.py @@ -0,0 +1,120 @@ +# Base modules are adapted from https://github.com/open-mmlab/mmcv/, +# originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +# https://github.com/open-mmlab/mmsegmentation/, +# originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +# and adapted from https://github.com/raoyongming/DenseCLIP/, +# originally MIT License, Copyright (c) 2022 Rao, Yongming. + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from timm.models.layers import drop, drop_path, trunc_normal_ + +from .common import Upsample, resize + + +class FPNHead(nn.Module): + """Panoptic Feature Pyramid Networks. + This head is the implementation of `Semantic FPN + `_. + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, + channels, + num_classes, + dropout_ratio=0.1, + feature_strides=[4, 8, 16, 32], + align_corners=False, + **kwargs): + super(FPNHead, self).__init__() + self.act_cfg = dict(type='ReLU') + self.channels = channels + self.conv_cfg = None + self.norm_cfg = None + self.norm_cfg = dict(type='BN2d', requires_grad=True) + self.align_corners = align_corners + self.dropout_ratio = dropout_ratio + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.in_index = [0, 1, 2, 3] + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + self.apply(self._init_weights) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + inputs = [inputs[i] for i in self.in_index] + return inputs + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def forward(self, inputs): + x = self._transform_inputs(inputs) + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) diff --git a/modelscope/models/cv/shop_segmentation/models.py b/modelscope/models/cv/shop_segmentation/models.py new file mode 100644 index 00000000..3880d074 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/models.py @@ -0,0 +1,899 @@ +# Base modules are adapted from https://github.com/open-mmlab/mmcv/, +# originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +# https://github.com/open-mmlab/mmsegmentation/, +# originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +# and adapted from https://github.com/raoyongming/DenseCLIP/, +# originally MIT License, Copyright (c) 2022 Rao, Yongming. + +import math +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop, drop_path, trunc_normal_ +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + self.embed_dim = embed_dim + self.spacial_dim = spacial_dim + + def forward(self, x): + B, C, H, W = x.shape + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + + cls_pos = self.positional_embedding[0:1, :] + spatial_pos = F.interpolate( + self.positional_embedding[1:, ].reshape(1, self.spacial_dim, + self.spacial_dim, + self.embed_dim).permute( + 0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape(self.embed_dim, H * W).permute(1, 0) + positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0) + + x = x + positional_embedding[:, None, :] + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + x = x.permute(1, 2, 0) + global_feat = x[:, :, 0] + feature_map = x[:, :, 1:].reshape(B, -1, H, W) + return global_feat, feature_map + + +class CLIPResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim=512, + input_resolution=224, + width=64, + pretrained=None, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in CLIPResNet') + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + + outs = [] + x = self.layer1(x) + outs.append(x) + x = self.layer2(x) + outs.append(x) + x = self.layer3(x) + outs.append(x) + x = self.layer4(x) + outs.append(x) + + return tuple(outs) + + +class CLIPResNetWithAttention(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim=1024, + input_resolution=224, + width=64, + pretrained=None, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, 32, + output_dim) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + if 'positional_embedding' in new_k: + if self.attnpool.positional_embedding.shape != state_dict[ + new_k].shape: + print( + f'Resize the pos_embed shape from {state_dict[new_k].shape}' + f' to {self.attnpool.positional_embedding.shape}' + ) + cls_pos = state_dict[new_k][0:1, :] + H = W = self.input_resolution // 32 + old_h = int( + math.sqrt(state_dict[new_k][1:, ].shape[0])) + spatial_pos = F.interpolate( + state_dict[new_k][1:, ].reshape( + 1, old_h, old_h, + cls_pos.shape[1]).permute(0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape( + cls_pos.shape[1], H * W).permute(1, 0) + positional_embedding = torch.cat( + [cls_pos, spatial_pos], dim=0) + state_dict[new_k] = positional_embedding + assert self.attnpool.positional_embedding.shape == state_dict[ + new_k].shape + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in CLIPResNet') + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + + outs = [] + x = self.layer1(x) + outs.append(x) + x = self.layer2(x) + outs.append(x) + x = self.layer3(x) + outs.append(x) + x = self.layer4(x) + outs.append(x) + + x_global, x_local = self.attnpool(x) + outs.append([x_global, x_local]) + + return tuple(outs) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path=0.): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.drop_path(self.attention(self.ln_1(x))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + drop_path_rate=0.): + super().__init__() + self.width = width + self.layers = layers + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers) + ] # stochastic depth decay rule + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) + for i in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q, k, v): + B, N, C = q.shape + assert k.shape == v.shape + B, M, C = k.shape + q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads) + k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads) + v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads) + + attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale + + attn = attn.softmax(dim=-1) + + x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class TransformerDecoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dropout=0.1, + ): + super().__init__() + self.self_attn = Attention(d_model, nhead, proj_drop=dropout) + self.cross_attn = Attention(d_model, nhead, proj_drop=dropout) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.mlp = nn.Sequential( + nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), + nn.Linear(d_model * 4, d_model)) + + def forward(self, x, mem): + q = k = v = self.norm1(x) + x = x + self.self_attn(q, k, v) + q = self.norm2(x) + x = x + self.cross_attn(q, mem, mem) + x = x + self.dropout(self.mlp(self.norm3(x))) + return x + + +class CLIPVisionTransformer(nn.Module): + + def __init__(self, + input_resolution=224, + patch_size=32, + width=768, + layers=12, + heads=12, + output_dim=512, + drop_path_rate=0.0, + out_indices=[3, 5, 7, 11], + pretrained=None, + get_embeddings=False, + **kwargs): + super().__init__() + self.pretrained = pretrained + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.spatial_size = input_resolution // patch_size + self.ln_pre = LayerNorm(width) + self.get_embeddings = get_embeddings + + self.transformer = Transformer( + width, layers, heads, drop_path_rate=drop_path_rate) + + self.out_indices = out_indices + + if get_embeddings: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + embed_dim = width + + if patch_size == 16: + self.fpn1 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + nn.BatchNorm2d(embed_dim), + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn3 = nn.GroupNorm(1, embed_dim) + + self.fpn4 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=2, stride=2)) + + elif patch_size == 8: + self.fpn1 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.GroupNorm(1, embed_dim) + + self.fpn3 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.fpn4 = nn.Sequential( + nn.GroupNorm(1, embed_dim), + nn.MaxPool2d(kernel_size=4, stride=4), + ) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('visual.'): + new_k = k.replace('visual.', '') + state_dict[new_k] = checkpoint[k] + + if 'positional_embedding' in state_dict.keys(): + if self.positional_embedding.shape != state_dict[ + 'positional_embedding'].shape: + print( + f'Resize the pos_embed shape from {state_dict["positional_embedding"].shape} to' + f' {self.positional_embedding.shape}') + cls_pos = state_dict['positional_embedding'][0:1, :] + spatial_pos = F.interpolate( + state_dict['positional_embedding'][1:, ].reshape( + 1, 14, 14, 768).permute(0, 3, 1, 2), + size=(self.spatial_size, self.spatial_size), + mode='bilinear') + spatial_pos = spatial_pos.reshape( + 768, + self.spatial_size * self.spatial_size).permute(1, 0) + positional_embedding = torch.cat([cls_pos, spatial_pos], + dim=0) + state_dict['positional_embedding'] = positional_embedding + assert self.positional_embedding.shape == state_dict[ + 'positional_embedding'].shape + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in vision transformer') + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + B, C, H, W = x.shape + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x1 = self.class_embedding.to(x.dtype) + x2 = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([x1 + x2, x], dim=1) + pos = self.positional_embedding.to(x.dtype) + cls_pos = pos[0, :] + self.class_embedding.to(x.dtype) + spatial_pos = F.interpolate( + pos[1:, ].reshape(1, self.spatial_size, self.spatial_size, + C).permute(0, 3, 1, 2), + size=(H, W), + mode='bilinear') + spatial_pos = spatial_pos.reshape(1, C, H * W).permute(0, 2, 1) + pos = torch.cat([cls_pos.reshape(1, 1, C), spatial_pos], dim=1) + x = x + pos + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + + gradientcheckpoint = False + + features = [] + for i, blk in enumerate(self.transformer.resblocks): + if gradientcheckpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + if i in self.out_indices: + xp = x.permute(1, 0, 2)[:, + 1:, :].permute(0, 2, + 1).reshape(B, -1, H, W) + features.append(xp.contiguous()) + + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + if self.get_embeddings: + x = x.permute(1, 0, 2) + x = self.ln_post(x) + x = x @ self.proj + + global_embedding = x[:, 0] + visual_embedding = x[:, 1:].reshape(B, H, W, + -1).permute(0, 3, 1, + 2) # B C H W + + features.append([global_embedding, visual_embedding]) + + return tuple(features) + + +class CLIPTextEncoder(nn.Module): + + def __init__(self, + context_length=77, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=1024, + out_dim=256, + pretrained=None, + **kwargs): + super().__init__() + + self.pretrained = pretrained + + self.context_length = context_length + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('transformer.'): + state_dict[k] = checkpoint[k] + + if k == 'positional_embedding' or k == 'text_projection' or k.startswith( + 'token_embedding') or k.startswith('ln_final'): + if k == 'positional_embedding' and checkpoint[k].size( + 0) > self.context_length: + checkpoint[k] = checkpoint[k][:self.context_length] + print('positional_embedding is tuncated from 77 to', + self.context_length) + state_dict[k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in text encoder') + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text): + x = self.token_embedding(text) + x = x + self.positional_embedding + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_final(x) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1), ...] @ self.text_projection + return x + + +class CLIPTextContextEncoder(nn.Module): + + def __init__(self, + context_length=22, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=1024, + out_dim=256, + pretrained=None, + **kwargs): + super().__init__() + + self.pretrained = pretrained + + self.context_length = context_length + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.embed_dim = embed_dim + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + + def init_weights(self, pretrained=None): + pretrained = pretrained or self.pretrained + if isinstance(pretrained, str): + checkpoint = torch.jit.load( + pretrained, map_location='cpu').float().state_dict() + + state_dict = {} + + for k in checkpoint.keys(): + if k.startswith('transformer.'): + state_dict[k] = checkpoint[k] + + if k == 'positional_embedding' or k == 'text_projection' or k.startswith( + 'token_embedding') or k.startswith('ln_final'): + if k == 'positional_embedding' and checkpoint[k].size( + 0) > self.context_length: + checkpoint[k] = checkpoint[k][:self.context_length] + print('positional_embedding is tuncated from 77 to', + self.context_length) + state_dict[k] = checkpoint[k] + + u, w = self.load_state_dict(state_dict, False) + print(u, w, 'are misaligned params in text encoder') + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, context=None): + x_text = self.token_embedding(text) # n_clas, n_text, C + K, N1, C = x_text.shape # 150类 * 5??? * 512 + B, N2, C = context.shape # 1 * 8 * 512 + + eos_indx = text.argmax(dim=-1) + N2 + eos_indx = eos_indx.reshape(1, K).expand(B, K).reshape(-1) + + x_text = x_text.reshape(1, K, N1, C).expand(B, K, N1, C) + context = context.reshape(B, 1, N2, C).expand(B, K, N2, C) + + x = torch.cat([x_text[:, :, 0:1], context, x_text[:, :, 1:]], + dim=2).reshape(B * K, N1 + N2, C) + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection + x = x.reshape(B, K, self.embed_dim) + return x + + +class ContextDecoder(nn.Module): + + def __init__(self, + transformer_width=256, + transformer_heads=4, + transformer_layers=6, + visual_dim=1024, + dropout=0.1, + **kwargs): + super().__init__() + + self.memory_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + nn.LayerNorm(transformer_width), + ) + + self.text_proj = nn.Sequential( + nn.LayerNorm(visual_dim), + nn.Linear(visual_dim, transformer_width), + ) + + self.decoder = nn.ModuleList([ + TransformerDecoderLayer(transformer_width, transformer_heads, + dropout) for _ in range(transformer_layers) + ]) + + self.out_proj = nn.Sequential( + nn.LayerNorm(transformer_width), + nn.Linear(transformer_width, visual_dim)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, text, visual): + B, N, C = visual.shape + visual = self.memory_proj(visual) + x = self.text_proj(text) + + for layer in self.decoder: + x = layer(x, visual) + + return self.out_proj(x) diff --git a/modelscope/models/cv/shop_segmentation/neck_fpn.py b/modelscope/models/cv/shop_segmentation/neck_fpn.py new file mode 100644 index 00000000..aa4d7159 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/neck_fpn.py @@ -0,0 +1,215 @@ +# Base modules are adapted from https://github.com/open-mmlab/mmcv/, +# originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +# https://github.com/open-mmlab/mmsegmentation/, +# originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +# and adapted from https://github.com/raoyongming/DenseCLIP/, +# originally MIT License, Copyright (c) 2022 Rao, Yongming. + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from timm.models.layers import drop, drop_path, trunc_normal_ + +from .common import resize + + +class FPN(nn.Module): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest')): + super(FPN, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + self.apply(self._init_weights) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) diff --git a/modelscope/models/cv/shop_segmentation/shop_seg_base.py b/modelscope/models/cv/shop_segmentation/shop_seg_base.py new file mode 100644 index 00000000..34686370 --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/shop_seg_base.py @@ -0,0 +1,155 @@ +# Base modules are adapted from https://github.com/open-mmlab/mmcv/, +# originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +# https://github.com/open-mmlab/mmsegmentation/, +# originally Apache 2.0 License, Copyright (c) 2020-2021 OpenMMLab, +# and adapted from https://github.com/raoyongming/DenseCLIP/, +# originally MIT License, Copyright (c) 2022 Rao, Yongming. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .head_fpn import FPNHead +from .models import (CLIPTextContextEncoder, CLIPVisionTransformer, + ContextDecoder) +from .neck_fpn import FPN +from .utils import SimpleTokenizer, tokenize + + +class SHOPSEG(nn.Module): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__(self, + model_dir, + context_length=22, + context_feature='attention', + score_concat_index=2, + tau=0.07, + token_embed_dim=512, + text_dim=512, + **args): + super(SHOPSEG, self).__init__() + + self.model_dir = model_dir + self.tokenizer = SimpleTokenizer(model_dir + + '/bpe_simple_vocab_16e6.txt.gz') + + backbone = CLIPVisionTransformer( + input_resolution=1024, + patch_size=16, + width=768, + layers=12, + output_dim=512, + drop_path_rate=0.1, + pretrained=False, + get_embeddings=True) + + text_encoder = CLIPTextContextEncoder( + context_length=30, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + embed_dim=512, + pretrained=False) + + context_decoder = ContextDecoder( + transformer_width=256, + transformer_heads=4, + transformer_layers=3, + visual_dim=512, + dropout=0.1) + neck = FPN( + in_channels=[768, 768, 768 + 2, 768], out_channels=256, num_outs=4) + head_fpd = FPNHead(channels=256, num_classes=2) + + self.backbone = backbone + self.text_encoder = text_encoder + self.context_decoder = context_decoder + self.context_length = context_length + self.score_concat_index = score_concat_index + + self.context_feature = context_feature + self.tau = tau + context_length = self.text_encoder.context_length - self.context_length + self.contexts = nn.Parameter( + torch.randn(1, context_length, token_embed_dim)) + nn.init.trunc_normal_(self.contexts) + self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4) + + self.neck = neck + self.head_fpn = head_fpd + + self.tau = 0.07 + + def encode_text(self, text, context_length): + output = tokenize(self.tokenizer, text, context_length, True) + return output + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + return x + + def after_extract_feat(self, x, name_list): + x_orig = list(x[0:4]) + global_feat, visual_embeddings = x[4] + B, C, H, W = visual_embeddings.shape + if self.context_feature == 'attention': + x1 = global_feat.reshape(B, C, 1) + x2 = visual_embeddings.reshape(B, C, H * W) + visual_context = torch.cat([x1, x2], dim=2).permute(0, 2, 1) + texts = torch.cat([ + self.encode_text(c, context_length=self.context_length) + for c in name_list + ]) + x1 = texts.to(global_feat.device) + x1 = self.text_encoder(x1, self.contexts) + text_embeddings = x1.expand(B, -1, -1) + # update text_embeddings by visual_context! + # (B, 1, C) + text_diff = self.context_decoder(text_embeddings, visual_context) + # (B, K, C) + text_embeddings = text_embeddings + self.gamma * text_diff + + # compute score map and concat + B, K, C = text_embeddings.shape + visual_embeddings = F.normalize(visual_embeddings, dim=1, p=2) + text = F.normalize(text_embeddings, dim=2, p=2) + score_map_list = [] + bsz = B + for i in range(bsz): + ind = 2 * i + sub_text = torch.cat( + [text[i:i + 1, ind:ind + 1], text[i:i + 1, ind + 1:ind + 2]], + dim=1) # 1 * 2 * h * w + + sub_score_map = torch.einsum('bchw,bkc->bkhw', + visual_embeddings[i:i + 1], + sub_text) # 1 * 2 * h * w + score_map_list.append(sub_score_map) + score_map = torch.cat(score_map_list, dim=0) # b * 2 * h * w + x_orig[self.score_concat_index] = torch.cat( + [x_orig[self.score_concat_index], score_map], dim=1) + return x_orig, score_map + + def forward(self, img, text_list=None): + if text_list is None: + bsz = img.size()[0] + text_list = ['foregeound'] * bsz + x = self.extract_feat(img) + _x_orig = [x[i] for i in range(4)] + name_list = [] + for name in text_list: + name_list.append('others') + name_list.append(name[0:20]) + x_orig, score_map = self.after_extract_feat(x, name_list) + x_orig = list(self.neck(x_orig)) + _x_orig = x_orig + pred = self.head_fpn(_x_orig) + return pred diff --git a/modelscope/models/cv/shop_segmentation/shop_seg_model.py b/modelscope/models/cv/shop_segmentation/shop_seg_model.py new file mode 100644 index 00000000..ac0d67fa --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/shop_seg_model.py @@ -0,0 +1,117 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.shop_segmentation import SHOPSEG +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ShopSegmentation'] + + +@MODELS.register_module( + Tasks.shop_segmentation, module_name=Models.shop_segmentation) +class ShopSegmentation(TorchModel): + """ shop segmentation model. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = SHOPSEG(model_dir=model_dir) + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location='cpu') + self.model.load_state_dict(pretrained_params) + self.model.eval() + if device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(device_id)) + logger.info('Use GPU: {}'.format(device_id)) + else: + device_id = -1 + logger.info('Use CPU for inference') + self.device_id = device_id + + def preprocess(self, img, size=1024): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + h, w, c = img.shape + max_hw = max(h, w) + ratio = 1.0 * size / max_hw + crop_h, crop_w = int(ratio * h), int(ratio * w) + pil_img = Image.fromarray(img) + pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) + np_img = np.array(pil_img, dtype=np.float32) / 255. + + for j in range(3): + np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] + + img_pad = np.zeros((size, size, 3), dtype=np.float32) + img_pad[:crop_h, :crop_w] = np_img + + img_pad = torch.from_numpy(img_pad).permute(2, 0, + 1).unsqueeze(0).float() + return img_pad, h, w, crop_h, crop_w + + def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): + output = np.clip(tensors * 255., a_min=0, a_max=255.) + crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) + + pil_output = Image.fromarray(crop_output) + pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) + np_output = np.array(pil_output, dtype=np.uint8) + + np_output[np_output < 128] = 0 + np_output[np_output >= 128] = 255 + np_output = np.uint8(np_output) + return np_output + + def forward(self, image): + """ + image should be numpy array, dtype=np.uint8, shape: height*width*3 + """ + image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( + image, size=1024) + pred = self.inference(image_tensor) + msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=1024) + + outputs = {OutputKeys.MASKS: msk} + return outputs + + def inference(self, image): + """ + image should be tensor, 1 * 3 * 1024 * 1024 + """ + with torch.no_grad(): + if self.device_id == -1: + output = self.model(image) + else: + device = torch.device('cuda', self.device_id) + output = self.model(image.to(device)) + output = F.interpolate(output, size=(1024, 1024), mode='bilinear') + output = F.softmax(output, dim=1) + output = torch.argmax(output, dim=1) + output = output[0] + if self.device_id == -1: + pred = output.data.numpy() + else: + pred = output.data.cpu().numpy() + + del output + return pred diff --git a/modelscope/models/cv/shop_segmentation/utils.py b/modelscope/models/cv/shop_segmentation/utils.py new file mode 100644 index 00000000..4035b0ef --- /dev/null +++ b/modelscope/models/cv/shop_segmentation/utils.py @@ -0,0 +1,198 @@ +# CLIP Tokenizer +# Adapted from https://github.com/openai/CLIP. +# Originally MIT License, Copyright (c) 2021 OpenAI. + +import gzip +import html +import os +from functools import lru_cache +from typing import Any, List, Union + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + error_list = [] + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception as err: + error_list.append(err) + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +def tokenize(tokenizer, + texts, + context_length: int = 77, + truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = tokenizer.encoder['<|startoftext|>'] + eot_token = tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/models/cv/skin_retouching/__init__.py b/modelscope/models/cv/skin_retouching/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/detection_model/__init__.py b/modelscope/models/cv/skin_retouching/detection_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_module.py b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py new file mode 100644 index 00000000..5db9c44c --- /dev/null +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_module.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn + + +class ConvBNActiv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(ConvBNActiv, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, images): + + outputs = self.conv(images) + if hasattr(self, 'bn'): + outputs = self.bn(outputs) + if hasattr(self, 'activation'): + outputs = self.activation(outputs) + + return outputs diff --git a/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py new file mode 100644 index 00000000..c0be1a52 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/detection_model/detection_unet_in.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..weights_init import weights_init +from .detection_module import ConvBNActiv + + +class DetectionUNet(nn.Module): + + def __init__(self, + n_channels, + n_classes, + up_sampling_node='nearest', + init_weights=True): + super(DetectionUNet, self).__init__() + + self.n_classes = n_classes + self.up_sampling_node = up_sampling_node + + self.ec_images_1 = ConvBNActiv( + n_channels, 64, bn=False, sample='down-3') + self.ec_images_2 = ConvBNActiv(64, 128, sample='down-3') + self.ec_images_3 = ConvBNActiv(128, 256, sample='down-3') + self.ec_images_4 = ConvBNActiv(256, 512, sample='down-3') + self.ec_images_5 = ConvBNActiv(512, 512, sample='down-3') + self.ec_images_6 = ConvBNActiv(512, 512, sample='down-3') + + self.dc_images_6 = ConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_5 = ConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_4 = ConvBNActiv(512 + 256, 256, activ='leaky') + self.dc_images_3 = ConvBNActiv(256 + 128, 128, activ='leaky') + self.dc_images_2 = ConvBNActiv(128 + 64, 64, activ='leaky') + self.dc_images_1 = nn.Conv2d(64 + n_channels, n_classes, kernel_size=1) + + if init_weights: + self.apply(weights_init()) + + def forward(self, input_images): + + ec_images = {} + + ec_images['ec_images_0'] = input_images + ec_images['ec_images_1'] = self.ec_images_1(input_images) + ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) + ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) + ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) + ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) + ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) + # -------------- + # images decoder + # -------------- + logits = ec_images['ec_images_6'] + + for _ in range(6, 0, -1): + + ec_images_skip = 'ec_images_{:d}'.format(_ - 1) + dc_conv = 'dc_images_{:d}'.format(_) + + logits = F.interpolate( + logits, scale_factor=2, mode=self.up_sampling_node) + logits = torch.cat((logits, ec_images[ec_images_skip]), dim=1) + + logits = getattr(self, dc_conv)(logits) + + return logits diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/__init__.py b/modelscope/models/cv/skin_retouching/inpainting_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py new file mode 100644 index 00000000..8b3eb2fc --- /dev/null +++ b/modelscope/models/cv/skin_retouching/inpainting_model/gconv.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn + + +class GatedConvBNActiv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(GatedConvBNActiv, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + + images = self.conv(x) + gates = self.sigmoid(self.gate(x)) + + if hasattr(self, 'bn'): + images = self.bn(images) + if hasattr(self, 'activation'): + images = self.activation(images) + + images = images * gates + + return images + + +class GatedConvBNActiv2(nn.Module): + + def __init__(self, + in_channels, + out_channels, + bn=True, + sample='none-3', + activ='relu', + bias=False): + super(GatedConvBNActiv2, self).__init__() + + if sample == 'down-7': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=bias) + elif sample == 'down-5': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=2, + padding=2, + bias=bias) + elif sample == 'down-3': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=bias) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + self.gate = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + self.conv_skip = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias) + + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + if activ == 'relu': + self.activation = nn.ReLU() + elif activ == 'leaky': + self.activation = nn.LeakyReLU(negative_slope=0.2) + + self.sigmoid = nn.Sigmoid() + + def forward(self, f_up, f_skip, mask): + x = torch.cat((f_up, f_skip, mask), dim=1) + images = self.conv(x) + images_skip = self.conv_skip(f_skip) + gates = self.sigmoid(self.gate(x)) + + if hasattr(self, 'bn'): + images = self.bn(images) + images_skip = self.bn(images_skip) + if hasattr(self, 'activation'): + images = self.activation(images) + images_skip = self.activation(images_skip) + + images = images * gates + images_skip * (1 - gates) + + return images diff --git a/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py new file mode 100644 index 00000000..dd220dd6 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/inpainting_model/inpainting_unet.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.skin_retouching.inpainting_model.gconv import \ + GatedConvBNActiv +from ..weights_init import weights_init + + +class RetouchingNet(nn.Module): + + def __init__(self, + in_channels=3, + out_channels=3, + up_sampling_node='nearest', + init_weights=True): + super(RetouchingNet, self).__init__() + + self.freeze_ec_bn = False + self.up_sampling_node = up_sampling_node + + self.ec_images_1 = GatedConvBNActiv( + in_channels, 64, bn=False, sample='down-3') + self.ec_images_2 = GatedConvBNActiv(64, 128, sample='down-3') + self.ec_images_3 = GatedConvBNActiv(128, 256, sample='down-3') + self.ec_images_4 = GatedConvBNActiv(256, 512, sample='down-3') + self.ec_images_5 = GatedConvBNActiv(512, 512, sample='down-3') + self.ec_images_6 = GatedConvBNActiv(512, 512, sample='down-3') + + self.dc_images_6 = GatedConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_5 = GatedConvBNActiv(512 + 512, 512, activ='leaky') + self.dc_images_4 = GatedConvBNActiv(512 + 256, 256, activ='leaky') + self.dc_images_3 = GatedConvBNActiv(256 + 128, 128, activ='leaky') + self.dc_images_2 = GatedConvBNActiv(128 + 64, 64, activ='leaky') + self.dc_images_1 = GatedConvBNActiv( + 64 + in_channels, + out_channels, + bn=False, + sample='none-3', + activ=None, + bias=True) + + self.tanh = nn.Tanh() + + if init_weights: + self.apply(weights_init()) + + def forward(self, input_images, input_masks): + + ec_images = {} + + ec_images['ec_images_0'] = torch.cat((input_images, input_masks), + dim=1) + ec_images['ec_images_1'] = self.ec_images_1(ec_images['ec_images_0']) + ec_images['ec_images_2'] = self.ec_images_2(ec_images['ec_images_1']) + ec_images['ec_images_3'] = self.ec_images_3(ec_images['ec_images_2']) + + ec_images['ec_images_4'] = self.ec_images_4(ec_images['ec_images_3']) + ec_images['ec_images_5'] = self.ec_images_5(ec_images['ec_images_4']) + ec_images['ec_images_6'] = self.ec_images_6(ec_images['ec_images_5']) + + # -------------- + # images decoder + # -------------- + dc_images = ec_images['ec_images_6'] + for _ in range(6, 0, -1): + ec_images_skip = 'ec_images_{:d}'.format(_ - 1) + dc_conv = 'dc_images_{:d}'.format(_) + + dc_images = F.interpolate( + dc_images, scale_factor=2, mode=self.up_sampling_node) + dc_images = torch.cat((dc_images, ec_images[ec_images_skip]), + dim=1) + + dc_images = getattr(self, dc_conv)(dc_images) + + outputs = self.tanh(dc_images) + + return outputs + + def train(self, mode=True): + + super().train(mode) + + if self.freeze_ec_bn: + for name, module in self.named_modules(): + if isinstance(module, nn.BatchNorm2d): + module.eval() diff --git a/modelscope/models/cv/skin_retouching/retinaface/__init__.py b/modelscope/models/cv/skin_retouching/retinaface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/skin_retouching/retinaface/box_utils.py b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py new file mode 100644 index 00000000..a4aeffd1 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py @@ -0,0 +1,272 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import List, Tuple, Union + +import numpy as np +import torch + + +def point_form(boxes: torch.Tensor) -> torch.Tensor: + """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \ + ground truth data. + + Args: + boxes: center-size default boxes from priorbox layers. + Return: + boxes: Converted x_min, y_min, x_max, y_max form of boxes. + """ + return torch.cat( + (boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2), + dim=1) + + +def center_size(boxes: torch.Tensor) -> torch.Tensor: + """Convert prior_boxes to (cx, cy, w, h) representation for comparison to center-size form ground truth data. + Args: + boxes: point_form boxes + Return: + boxes: Converted x_min, y_min, x_max, y_max form of boxes. + """ + return torch.cat( + ((boxes[:, 2:] + boxes[:, :2]) / 2, boxes[:, 2:] - boxes[:, :2]), + dim=1) + + +def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: + """ We resize both tensors to [A,B,2] without new malloc: + [A, 2] -> [A, 1, 2] -> [A, B, 2] + [B, 2] -> [1, B, 2] -> [A, B, 2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: bounding boxes, Shape: [A, 4]. + box_b: bounding boxes, Shape: [B, 4]. + Return: + intersection area, Shape: [A, B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap is simply the intersection over + union of two boxes. Here we operate on ground truth boxes and default boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: Ground truth bounding boxes, Shape: [num_objects,4] + box_b: Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) + area_a = area_a.unsqueeze(1).expand_as(inter) # [A,B] + area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]) + area_b = area_b.unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union + + +def matrix_iof(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """ + return iof of a and b, numpy version for data augmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match( + threshold: float, + box_gt: torch.Tensor, + priors: torch.Tensor, + variances: List[float], + labels_gt: torch.Tensor, + landmarks_gt: torch.Tensor, + box_t: torch.Tensor, + label_t: torch.Tensor, + landmarks_t: torch.Tensor, + batch_id: int, +) -> None: + """Match each prior box with the ground truth box of the highest jaccard overlap, encode the bounding + boxes, then return the matched indices corresponding to both confidence and location preds. + + Args: + threshold: The overlap threshold used when matching boxes. + box_gt: Ground truth boxes, Shape: [num_obj, 4]. + priors: Prior boxes from priorbox layers, Shape: [n_priors, 4]. + variances: Variances corresponding to each prior coord, Shape: [num_priors, 4]. + labels_gt: All the class labels for the image, Shape: [num_obj, 2]. + landmarks_gt: Ground truth landms, Shape [num_obj, 10]. + box_t: Tensor to be filled w/ endcoded location targets. + label_t: Tensor to be filled w/ matched indices for labels predictions. + landmarks_t: Tensor to be filled w/ endcoded landmarks targets. + batch_id: current batch index + Return: + The matched indices corresponding to 1)location 2)confidence 3)landmarks preds. + """ + # Compute iou between gt and priors + overlaps = jaccard(box_gt, point_form(priors)) + # (Bipartite Matching) + # [1, num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + box_t[batch_id] = 0 + label_t[batch_id] = 0 + return + + # [1, num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, + 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): + best_truth_idx[best_prior_idx[j]] = j + + matches = box_gt[best_truth_idx] # Shape: [num_priors, 4] + labels = labels_gt[best_truth_idx] # Shape: [num_priors] + # label as background + labels[best_truth_overlap < threshold] = 0 + loc = encode(matches, priors, variances) + + matches_landm = landmarks_gt[best_truth_idx] + landmarks_gt = encode_landm(matches_landm, priors, variances) + box_t[batch_id] = loc # [num_priors, 4] encoded offsets to learn + label_t[batch_id] = labels # [num_priors] top class label for each prior + landmarks_t[batch_id] = landmarks_gt + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= variances[0] * priors[:, 2:] + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm( + matched: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Encode the variances from the priorbox layers into the ground truth boxes we have matched + (based on jaccard overlap) with the prior boxes. + Args: + matched: Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: Variances of priorboxes + Return: + encoded landmarks, Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), + 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy = g_cxcy // variances[0] * priors[:, :, 2:] + # return target for smooth_l1_loss + return g_cxcy.reshape(g_cxcy.size(0), -1) + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Decode locations from predictions using priors to undo the encoding we did for offset regression at train time. + Args: + loc: location predictions for loc layers, + Shape: [num_priors, 4] + priors: Prior boxes in center-offset form. + Shape: [num_priors, 4]. + variances: Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + ( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]), + ), + 1, + ) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm( + pre: torch.Tensor, priors: torch.Tensor, + variances: Union[List[float], Tuple[float, float]]) -> torch.Tensor: + """Decode landmarks from predictions using priors to undo the encoding we did for offset regression at train time. + Args: + pre: landmark predictions for loc layers, + Shape: [num_priors, 10] + priors: Prior boxes in center-offset form. + Shape: [num_priors, 4]. + variances: Variances of priorboxes + Return: + decoded landmark predictions + """ + return torch.cat( + ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ), + dim=1, + ) + + +def log_sum_exp(x: torch.Tensor) -> torch.Tensor: + """Utility function for computing log_sum_exp while determining This will be used to determine unaveraged + confidence loss across all examples in a batch. + Args: + x: conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max diff --git a/modelscope/models/cv/skin_retouching/retinaface/net.py b/modelscope/models/cv/skin_retouching/retinaface/net.py new file mode 100644 index 00000000..e9b0297b --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/net.py @@ -0,0 +1,124 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torch import nn + + +def conv_bn(inp: int, + oup: int, + stride: int = 1, + leaky: float = 0) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_bn_no_relu(inp: int, oup: int, stride: int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp: int, + oup: int, + stride: int, + leaky: float = 0) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_dw(inp: int, + oup: int, + stride: int, + leaky: float = 0.1) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel: int, out_channel: int) -> None: + super().__init__() + if out_channel % 4 != 0: + raise ValueError( + f'Expect out channel % 4 == 0, but we got {out_channel % 4}') + + leaky: float = 0 + if out_channel <= 64: + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + conv3X3 = self.conv3X3(x) + + conv5X5_1 = self.conv5X5_1(x) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + + return F.relu(out) + + +class FPN(nn.Module): + + def __init__(self, in_channels_list: List[int], out_channels: int) -> None: + super().__init__() + leaky = 0.0 + if out_channels <= 64: + leaky = 0.1 + + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, x: Dict[str, torch.Tensor]) -> List[torch.Tensor]: + y = list(x.values()) + + output1 = self.output1(y[0]) + output2 = self.output2(y[1]) + output3 = self.output3(y[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + return [output1, output2, output3] diff --git a/modelscope/models/cv/skin_retouching/retinaface/network.py b/modelscope/models/cv/skin_retouching/retinaface/network.py new file mode 100644 index 00000000..3b197ca9 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/network.py @@ -0,0 +1,146 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from typing import Dict, Tuple + +import torch +from torch import nn +from torchvision import models +from torchvision.models import _utils + +from .net import FPN, SSH + + +class ClassHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3) -> None: + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3): + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, in_channels: int = 512, num_anchors: int = 3): + super().__init__() + self.conv1x1 = nn.Conv2d( + in_channels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, name: str, pretrained: bool, in_channels: int, + return_layers: Dict[str, int], out_channels: int) -> None: + super().__init__() + + if name == 'Resnet50': + backbone = models.resnet50(pretrained=pretrained) + else: + raise NotImplementedError( + f'Only Resnet50 backbone is supported but got {name}') + + self.body = _utils.IntermediateLayerGetter(backbone, return_layers) + in_channels_stage2 = in_channels + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, in_channels=out_channels) + self.BboxHead = self._make_bbox_head( + fpn_num=3, in_channels=out_channels) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, in_channels=out_channels) + + @staticmethod + def _make_class_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + classhead = nn.ModuleList() + for _ in range(fpn_num): + classhead.append(ClassHead(in_channels, anchor_num)) + return classhead + + @staticmethod + def _make_bbox_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + bboxhead = nn.ModuleList() + for _ in range(fpn_num): + bboxhead.append(BboxHead(in_channels, anchor_num)) + return bboxhead + + @staticmethod + def _make_landmark_head(fpn_num: int = 3, + in_channels: int = 64, + anchor_num: int = 2) -> nn.ModuleList: + landmarkhead = nn.ModuleList() + for _ in range(fpn_num): + landmarkhead.append(LandmarkHead(in_channels, anchor_num)) + return landmarkhead + + def forward( + self, inputs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = [ + self.LandmarkHead[i](feature) for i, feature in enumerate(features) + ] + ldm_regressions = torch.cat(ldm_regressions, dim=1) + + return bbox_regressions, classifications, ldm_regressions diff --git a/modelscope/models/cv/skin_retouching/retinaface/predict_single.py b/modelscope/models/cv/skin_retouching/retinaface/predict_single.py new file mode 100644 index 00000000..659a1134 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/predict_single.py @@ -0,0 +1,152 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +""" +There is a lot of post processing of the predictions. +""" +from typing import Dict, List, Union + +import albumentations as A +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.ops import nms + +from ..utils import pad_to_size, unpad_from_size +from .box_utils import decode, decode_landm +from .network import RetinaFace +from .prior_box import priorbox +from .utils import tensor_from_rgb_image + + +class Model: + + def __init__(self, max_size: int = 960, device: str = 'cpu') -> None: + self.model = RetinaFace( + name='Resnet50', + pretrained=False, + return_layers={ + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + in_channels=256, + out_channels=256, + ).to(device) + self.device = device + self.transform = A.Compose( + [A.LongestMaxSize(max_size=max_size, p=1), + A.Normalize(p=1)]) + self.max_size = max_size + self.prior_box = priorbox( + min_sizes=[[16, 32], [64, 128], [256, 512]], + steps=[8, 16, 32], + clip=False, + image_size=(self.max_size, self.max_size), + ).to(device) + self.variance = [0.1, 0.2] + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.model.load_state_dict(state_dict) + + def eval(self): + self.model.eval() + + def predict_jsons( + self, + image: np.array, + confidence_threshold: float = 0.7, + nms_threshold: float = 0.4) -> List[Dict[str, Union[List, float]]]: + with torch.no_grad(): + original_height, original_width = image.shape[:2] + + scale_landmarks = torch.from_numpy( + np.tile([self.max_size, self.max_size], + 5)).to(self.device).float() + scale_bboxes = torch.from_numpy( + np.tile([self.max_size, self.max_size], + 2)).to(self.device).float() + + transformed_image = self.transform(image=image)['image'] + + paded = pad_to_size( + target_size=(self.max_size, self.max_size), + image=transformed_image) + + pads = paded['pads'] + + torched_image = tensor_from_rgb_image(paded['image']).to( + self.device) + + loc, conf, land = self.model(torched_image.unsqueeze(0)) + + conf = F.softmax(conf, dim=-1) + + annotations: List[Dict[str, Union[List, float]]] = [] + + boxes = decode(loc.data[0], self.prior_box, self.variance) + + boxes *= scale_bboxes + scores = conf[0][:, 1] + + landmarks = decode_landm(land.data[0], self.prior_box, + self.variance) + landmarks *= scale_landmarks + + # ignore low scores + valid_index = scores > confidence_threshold + boxes = boxes[valid_index] + landmarks = landmarks[valid_index] + scores = scores[valid_index] + + # Sort from high to low + order = scores.argsort(descending=True) + boxes = boxes[order] + landmarks = landmarks[order] + scores = scores[order] + + # do NMS + keep = nms(boxes, scores, nms_threshold) + boxes = boxes[keep, :].int() + + if boxes.shape[0] == 0: + return [{'bbox': [], 'score': -1, 'landmarks': []}] + + landmarks = landmarks[keep] + + scores = scores[keep].cpu().numpy().astype(np.float64) + boxes = boxes.cpu().numpy() + landmarks = landmarks.cpu().numpy() + landmarks = landmarks.reshape([-1, 2]) + + unpadded = unpad_from_size(pads, bboxes=boxes, keypoints=landmarks) + + resize_coeff = max(original_height, original_width) / self.max_size + + boxes = (unpadded['bboxes'] * resize_coeff).astype(int) + landmarks = (unpadded['keypoints'].reshape(-1, 10) + * resize_coeff).astype(int) + + for box_id, bbox in enumerate(boxes): + x_min, y_min, x_max, y_max = bbox + + x_min = np.clip(x_min, 0, original_width - 1) + x_max = np.clip(x_max, x_min + 1, original_width - 1) + + if x_min >= x_max: + continue + + y_min = np.clip(y_min, 0, original_height - 1) + y_max = np.clip(y_max, y_min + 1, original_height - 1) + + if y_min >= y_max: + continue + + annotations += [{ + 'bbox': + bbox.tolist(), + 'score': + scores[box_id], + 'landmarks': + landmarks[box_id].reshape(-1, 2).tolist(), + }] + + return annotations diff --git a/modelscope/models/cv/skin_retouching/retinaface/prior_box.py b/modelscope/models/cv/skin_retouching/retinaface/prior_box.py new file mode 100644 index 00000000..863a676c --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/prior_box.py @@ -0,0 +1,28 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +from itertools import product +from math import ceil + +import torch + + +def priorbox(min_sizes, steps, clip, image_size): + feature_maps = [[ceil(image_size[0] / step), + ceil(image_size[1] / step)] for step in steps] + + anchors = [] + for k, f in enumerate(feature_maps): + t_min_sizes = min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in t_min_sizes: + s_kx = min_size / image_size[1] + s_ky = min_size / image_size[0] + dense_cx = [x * steps[k] / image_size[1] for x in [j + 0.5]] + dense_cy = [y * steps[k] / image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if clip: + output.clamp_(max=1, min=0) + return output diff --git a/modelscope/models/cv/skin_retouching/retinaface/utils.py b/modelscope/models/cv/skin_retouching/retinaface/utils.py new file mode 100644 index 00000000..c6b97484 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/retinaface/utils.py @@ -0,0 +1,70 @@ +# Implementation in this file is modifed from source code avaiable via https://github.com/ternaus/retinaface +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import cv2 +import numpy as np +import torch + + +def load_checkpoint(file_path: Union[Path, str], + rename_in_layers: Optional[dict] = None) -> Dict[str, Any]: + """Loads PyTorch checkpoint, optionally renaming layer names. + Args: + file_path: path to the torch checkpoint. + rename_in_layers: {from_name: to_name} + ex: {"model.0.": "", + "model.": ""} + Returns: + """ + checkpoint = torch.load( + file_path, map_location=lambda storage, loc: storage) + + if rename_in_layers is not None: + model_state_dict = checkpoint['state_dict'] + + result = {} + for key, value in model_state_dict.items(): + for key_r, value_r in rename_in_layers.items(): + key = re.sub(key_r, value_r, key) + + result[key] = value + + checkpoint['state_dict'] = result + + return checkpoint + + +def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor: + image = np.transpose(image, (2, 0, 1)) + return torch.from_numpy(image) + + +def vis_annotations(image: np.ndarray, + annotations: List[Dict[str, Any]]) -> np.ndarray: + vis_image = image.copy() + + for annotation in annotations: + landmarks = annotation['landmarks'] + + colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102), (102, 128, 255), + (0, 255, 255)] + + for landmark_id, (x, y) in enumerate(landmarks): + vis_image = cv2.circle( + vis_image, (x, y), + radius=3, + color=colors[landmark_id], + thickness=3) + + x_min, y_min, x_max, y_max = annotation['bbox'] + + x_min = np.clip(x_min, 0, x_max - 1) + y_min = np.clip(y_min, 0, y_max - 1) + + vis_image = cv2.rectangle( + vis_image, (x_min, y_min), (x_max, y_max), + color=(0, 255, 0), + thickness=2) + return vis_image diff --git a/modelscope/models/cv/skin_retouching/unet_deploy.py b/modelscope/models/cv/skin_retouching/unet_deploy.py new file mode 100755 index 00000000..0ff75b85 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/unet_deploy.py @@ -0,0 +1,144 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .weights_init import weights_init + +warnings.filterwarnings(action='ignore') + + +class double_conv(nn.Module): + '''(conv => BN => ReLU) * 2''' + + def __init__(self, in_ch, out_ch): + super(double_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class inconv(nn.Module): + + def __init__(self, in_ch, out_ch): + super(inconv, self).__init__() + self.conv = double_conv(in_ch, out_ch) + + def forward(self, x): + x = self.conv(x) + return x + + +class down(nn.Module): + + def __init__(self, in_ch, out_ch): + super(down, self).__init__() + self.mpconv = nn.Sequential( + nn.MaxPool2d(2), double_conv(in_ch, out_ch)) + + def forward(self, x): + x = self.mpconv(x) + return x + + +class up(nn.Module): + + def __init__(self, in_ch, out_ch, bilinear=True): + super(up, self).__init__() + + if bilinear: + self.up = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) + + self.conv = double_conv(in_ch, out_ch) + + def forward(self, x1, x2): + x1 = self.up(x1) + + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad( + x1, + (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) + + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + return x + + +class outconv(nn.Module): + + def __init__(self, in_ch, out_ch): + super(outconv, self).__init__() + self.conv = nn.Conv2d(in_ch, out_ch, 1) + + def forward(self, x): + x = self.conv(x) + return x + + +class UNet(nn.Module): + + def __init__(self, + n_channels, + n_classes, + deep_supervision=False, + init_weights=True): + super(UNet, self).__init__() + self.deep_supervision = deep_supervision + self.inc = inconv(n_channels, 64) + self.down1 = down(64, 128) + self.down2 = down(128, 256) + self.down3 = down(256, 512) + self.down4 = down(512, 512) + self.up1 = up(1024, 256) + self.up2 = up(512, 128) + self.up3 = up(256, 64) + self.up4 = up(128, 64) + self.outc = outconv(64, n_classes) + + self.dsoutc4 = outconv(256, n_classes) + self.dsoutc3 = outconv(128, n_classes) + self.dsoutc2 = outconv(64, n_classes) + self.dsoutc1 = outconv(64, n_classes) + + self.sigmoid = nn.Sigmoid() + + if init_weights: + self.apply(weights_init()) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x44 = self.up1(x5, x4) + x33 = self.up2(x44, x3) + x22 = self.up3(x33, x2) + x11 = self.up4(x22, x1) + x0 = self.outc(x11) + x0 = self.sigmoid(x0) + if self.deep_supervision: + x11 = F.interpolate( + self.dsoutc1(x11), x0.shape[2:], mode='bilinear') + x22 = F.interpolate( + self.dsoutc2(x22), x0.shape[2:], mode='bilinear') + x33 = F.interpolate( + self.dsoutc3(x33), x0.shape[2:], mode='bilinear') + x44 = F.interpolate( + self.dsoutc4(x44), x0.shape[2:], mode='bilinear') + + return x0, x11, x22, x33, x44 + else: + return x0 diff --git a/modelscope/models/cv/skin_retouching/utils.py b/modelscope/models/cv/skin_retouching/utils.py new file mode 100644 index 00000000..eb0da6b9 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/utils.py @@ -0,0 +1,328 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import time +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'gen_diffuse_mask', 'get_crop_bbox', 'get_roi_without_padding', + 'patch_aggregation_overlap', 'patch_partition_overlap', 'preprocess_roi', + 'resize_on_long_side', 'roi_to_tensor', 'smooth_border_mg', 'whiten_img' +] + + +def resize_on_long_side(img, long_side=800): + src_height = img.shape[0] + src_width = img.shape[1] + + if src_height > src_width: + scale = long_side * 1.0 / src_height + _img = cv2.resize( + img, (int(src_width * scale), long_side), + interpolation=cv2.INTER_LINEAR) + else: + scale = long_side * 1.0 / src_width + _img = cv2.resize( + img, (long_side, int(src_height * scale)), + interpolation=cv2.INTER_LINEAR) + + return _img, scale + + +def get_crop_bbox(detecting_results): + boxes = [] + for anno in detecting_results: + if anno['score'] == -1: + break + boxes.append({ + 'x1': anno['bbox'][0], + 'y1': anno['bbox'][1], + 'x2': anno['bbox'][2], + 'y2': anno['bbox'][3] + }) + face_count = len(boxes) + + suitable_bboxes = [] + for i in range(face_count): + face_bbox = boxes[i] + + face_bbox_width = abs(face_bbox['x2'] - face_bbox['x1']) + face_bbox_height = abs(face_bbox['y2'] - face_bbox['y1']) + + face_bbox_center = ((face_bbox['x1'] + face_bbox['x2']) / 2, + (face_bbox['y1'] + face_bbox['y2']) / 2) + + square_bbox_length = face_bbox_height if face_bbox_height > face_bbox_width else face_bbox_width + enlarge_ratio = 1.5 + square_bbox_length = int(enlarge_ratio * square_bbox_length) + + sideScale = 1 + + square_bbox = { + 'x1': + int(face_bbox_center[0] - sideScale * square_bbox_length / 2), + 'x2': + int(face_bbox_center[0] + sideScale * square_bbox_length / 2), + 'y1': + int(face_bbox_center[1] - sideScale * square_bbox_length / 2), + 'y2': int(face_bbox_center[1] + sideScale * square_bbox_length / 2) + } + + suitable_bboxes.append(square_bbox) + + return suitable_bboxes + + +def get_roi_without_padding(img, bbox): + crop_t = max(bbox['y1'], 0) + crop_b = min(bbox['y2'], img.shape[0]) + crop_l = max(bbox['x1'], 0) + crop_r = min(bbox['x2'], img.shape[1]) + roi = img[crop_t:crop_b, crop_l:crop_r] + return roi, 0, [crop_t, crop_b, crop_l, crop_r] + + +def roi_to_tensor(img): + img = torch.from_numpy(img.transpose((2, 0, 1)))[None, ...] + + return img + + +def preprocess_roi(img): + img = img.float() / 255.0 + img = (img - 0.5) * 2 + + return img + + +def patch_partition_overlap(image, p1, p2, padding=32): + + B, C, H, W = image.size() + h, w = H // p1, W // p2 + image = F.pad( + image, + pad=(padding, padding, padding, padding, 0, 0), + mode='constant', + value=0) + + patch_list = [] + for i in range(h): + for j in range(w): + patch = image[:, :, p1 * i:p1 * (i + 1) + padding * 2, + p2 * j:p2 * (j + 1) + padding * 2] + patch_list.append(patch) + + output = torch.cat( + patch_list, dim=0) # (b h w) c (p1 + 2 * padding) (p2 + 2 * padding) + return output + + +def patch_aggregation_overlap(image, h, w, padding=32): + + image = image[:, :, padding:-padding, padding:-padding] + + output = rearrange(image, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=h, w=w) + + return output + + +def smooth_border_mg(diffuse_mask, mg): + mg = mg - 0.5 + diffuse_mask = F.interpolate( + diffuse_mask, mg.shape[:2], mode='bilinear')[0].permute(1, 2, 0) + mg = mg * diffuse_mask + mg = mg + 0.5 + return mg + + +def whiten_img(image, skin_mask, whitening_degree, flag_bigKernal=False): + """ + image: rgb + """ + dilate_kernalsize = 30 + if flag_bigKernal: + dilate_kernalsize = 80 + new_kernel1 = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) + new_kernel2 = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize)) + if len(skin_mask.shape) == 3: + skin_mask = skin_mask[:, :, -1] + skin_mask = cv2.dilate(skin_mask, new_kernel1, 1) + skin_mask = cv2.erode(skin_mask, new_kernel2, 1) + skin_mask = cv2.blur(skin_mask, (20, 20)) / 255.0 + skin_mask = skin_mask.squeeze() + skin_mask = torch.from_numpy(skin_mask).to(image.device) + skin_mask = torch.stack([skin_mask, skin_mask, skin_mask], dim=0)[None, + ...] + skin_mask[:, 1:, :, :] *= 0.75 + + whiten_mg = skin_mask * 0.2 * whitening_degree + 0.5 + assert len(whiten_mg.shape) == 4 + whiten_mg = F.interpolate( + whiten_mg, image.shape[:2], mode='bilinear')[0].permute(1, 2, + 0).half() + output_pred = image.half() + output_pred = output_pred / 255.0 + output_pred = ( + -2 * whiten_mg + 1 + ) * output_pred * output_pred + 2 * whiten_mg * output_pred # value: 0~1 + output_pred = output_pred * 255.0 + output_pred = output_pred.byte() + + output_pred = output_pred.cpu().numpy() + return output_pred + + +def gen_diffuse_mask(out_channels=3): + mask_size = 500 + diffuse_with = 20 + a = np.ones(shape=(mask_size, mask_size), dtype=np.float32) + + for i in range(mask_size): + for j in range(mask_size): + if i >= diffuse_with and i <= ( + mask_size - diffuse_with) and j >= diffuse_with and j <= ( + mask_size - diffuse_with): + a[i, j] = 1.0 + elif i <= diffuse_with: + a[i, j] = i * 1.0 / diffuse_with + elif i > (mask_size - diffuse_with): + a[i, j] = (mask_size - i) * 1.0 / diffuse_with + + for i in range(mask_size): + for j in range(mask_size): + if j <= diffuse_with: + a[i, j] = min(a[i, j], j * 1.0 / diffuse_with) + elif j > (mask_size - diffuse_with): + a[i, j] = min(a[i, j], (mask_size - j) * 1.0 / diffuse_with) + a = np.dstack([a] * out_channels) + return a + + +def pad_to_size( + target_size: Tuple[int, int], + image: np.array, + bboxes: Optional[np.ndarray] = None, + keypoints: Optional[np.ndarray] = None, +) -> Dict[str, Union[np.ndarray, Tuple[int, int, int, int]]]: + """Pads the image on the sides to the target_size + + Args: + target_size: (target_height, target_width) + image: + bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] + keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] + + Returns: + { + "image": padded_image, + "pads": (x_min_pad, y_min_pad, x_max_pad, y_max_pad), + "bboxes": shifted_boxes, + "keypoints": shifted_keypoints + } + + """ + target_height, target_width = target_size + + image_height, image_width = image.shape[:2] + + if target_width < image_width: + raise ValueError(f'Target width should bigger than image_width' + f'We got {target_width} {image_width}') + + if target_height < image_height: + raise ValueError(f'Target height should bigger than image_height' + f'We got {target_height} {image_height}') + + if image_height == target_height: + y_min_pad = 0 + y_max_pad = 0 + else: + y_pad = target_height - image_height + y_min_pad = y_pad // 2 + y_max_pad = y_pad - y_min_pad + + if image_width == target_width: + x_min_pad = 0 + x_max_pad = 0 + else: + x_pad = target_width - image_width + x_min_pad = x_pad // 2 + x_max_pad = x_pad - x_min_pad + + result = { + 'pads': (x_min_pad, y_min_pad, x_max_pad, y_max_pad), + 'image': + cv2.copyMakeBorder(image, y_min_pad, y_max_pad, x_min_pad, x_max_pad, + cv2.BORDER_CONSTANT), + } + + if bboxes is not None: + bboxes[:, 0] += x_min_pad + bboxes[:, 1] += y_min_pad + bboxes[:, 2] += x_min_pad + bboxes[:, 3] += y_min_pad + + result['bboxes'] = bboxes + + if keypoints is not None: + keypoints[:, 0] += x_min_pad + keypoints[:, 1] += y_min_pad + + result['keypoints'] = keypoints + + return result + + +def unpad_from_size( + pads: Tuple[int, int, int, int], + image: Optional[np.array] = None, + bboxes: Optional[np.ndarray] = None, + keypoints: Optional[np.ndarray] = None, +) -> Dict[str, np.ndarray]: + """Crops patch from the center so that sides are equal to pads. + + Args: + image: + pads: (x_min_pad, y_min_pad, x_max_pad, y_max_pad) + bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max] + keypoints: np.array with shape (num_keypoints, 2), each row: [x, y] + + Returns: cropped image + + { + "image": cropped_image, + "bboxes": shifted_boxes, + "keypoints": shifted_keypoints + } + + """ + x_min_pad, y_min_pad, x_max_pad, y_max_pad = pads + + result = {} + + if image is not None: + height, width = image.shape[:2] + result['image'] = image[y_min_pad:height - y_max_pad, + x_min_pad:width - x_max_pad] + + if bboxes is not None: + bboxes[:, 0] -= x_min_pad + bboxes[:, 1] -= y_min_pad + bboxes[:, 2] -= x_min_pad + bboxes[:, 3] -= y_min_pad + + result['bboxes'] = bboxes + + if keypoints is not None: + keypoints[:, 0] -= x_min_pad + keypoints[:, 1] -= y_min_pad + + result['keypoints'] = keypoints + + return result diff --git a/modelscope/models/cv/skin_retouching/weights_init.py b/modelscope/models/cv/skin_retouching/weights_init.py new file mode 100644 index 00000000..ae62d4a4 --- /dev/null +++ b/modelscope/models/cv/skin_retouching/weights_init.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn + + +def weights_init(init_type='kaiming', gain=0.02): + + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + elif classname.find('BatchNorm2d') != -1: + nn.init.normal_(m.weight.data, 1.0, gain) + nn.init.constant_(m.bias.data, 0.0) + + return init_func + + +def spectral_norm(module, mode=True): + + if mode: + return nn.utils.spectral_norm(module) + + return module diff --git a/modelscope/models/cv/super_resolution/__init__.py b/modelscope/models/cv/super_resolution/__init__.py new file mode 100644 index 00000000..5065e280 --- /dev/null +++ b/modelscope/models/cv/super_resolution/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .rrdbnet_arch import RRDBNet + +else: + _import_structure = {'rrdbnet_arch': ['RRDBNet']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/super_resolution/arch_util.py b/modelscope/models/cv/super_resolution/arch_util.py new file mode 100644 index 00000000..99711a11 --- /dev/null +++ b/modelscope/models/cv/super_resolution/arch_util.py @@ -0,0 +1,228 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/arch_util.py +import collections.abc +import math +import warnings +from itertools import repeat + +import torch +import torchvision +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError( + f'scale {scale} is not supported. Supported scales: 2^n and 3.' + ) + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + + assert hh % scale == 0 and hw % scale == 0 + + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) diff --git a/modelscope/models/cv/super_resolution/rrdbnet_arch.py b/modelscope/models/cv/super_resolution/rrdbnet_arch.py new file mode 100644 index 00000000..8c84f796 --- /dev/null +++ b/modelscope/models/cv/super_resolution/rrdbnet_arch.py @@ -0,0 +1,131 @@ +# The implementation is adopted from BasicSR, made public available under the Apache 2.0 License +# at https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, + 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, + 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights( + [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, + num_in_ch, + num_out_ch, + scale=4, + num_feat=64, + num_block=23, + num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer( + RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/modelscope/models/cv/text_driven_segmentation/__init__.py b/modelscope/models/cv/text_driven_segmentation/__init__.py new file mode 100644 index 00000000..aefaa698 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .lseg_base import TextDrivenSegmentation diff --git a/modelscope/models/cv/text_driven_segmentation/clip.py b/modelscope/models/cv/text_driven_segmentation/clip.py new file mode 100644 index 00000000..1cec5f39 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/clip.py @@ -0,0 +1,169 @@ +# CLIP +# Adapted from https://github.com/openai/CLIP. +# Originally MIT License, Copyright (c) 2021 OpenAI. + +import hashlib +import os +import urllib +import warnings +from typing import Any, List, Union + +import torch +from PIL import Image +from pkg_resources import packaging +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.7.1'): + warnings.warn('PyTorch version 1.7.1 or higher is recommended') +__all__ = ['load', 'tokenize'] + + +def _convert_image_to_rgb(image): + return image.convert('RGB') + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def load(name: str, + device: Union[str, torch.device] = 'cuda' + if torch.cuda.is_available() else 'cpu', + jit: bool = False, + root: str = None): + + if not jit: + model = build_model().to(device) + if str(device) == 'cpu': + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [ + n for n in device_holder.graph.findAllNodes('prim::Constant') + if 'Device' in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, 'graph') else [] + except RuntimeError: + graphs = [] + + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('prim::Constant'): + if 'value' in node.attributeNames() and str( + node['value']).startswith('cuda'): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == 'cpu': + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, 'graph') else [] + except RuntimeError: + graphs = [] + + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('aten::to'): + inputs = list(node.inputs()) + for i in [ + 1, 2 + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()['value'] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize( + _tokenizer, + texts: Union[str, List[str]], + context_length: int = 77, + truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder['<|startoftext|>'] + eot_token = _tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] + for text in texts] + if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.8.0'): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_base.py b/modelscope/models/cv/text_driven_segmentation/lseg_base.py new file mode 100644 index 00000000..c79861a7 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_base.py @@ -0,0 +1,26 @@ +# Adapted from https://github.com/isl-org/lang-seg. +# Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. + +import torch +import torch.nn as nn + +from .lseg_net import LSeg + + +class TextDrivenSegmentation(nn.Module): + + def __init__(self, model_dir): + super(TextDrivenSegmentation, self).__init__() + self.net = LSeg(model_dir=model_dir) + self.model_dir = model_dir + + def forward(self, img, txt_list): + b = img.size()[0] + batch_name_list = txt_list + xout_list = [] + for i in range(b): + labelset = ['others', batch_name_list[i]] + xout = self.net(img[i:i + 1], labelset=labelset) + xout_list.append(xout) + score_map = torch.cat(xout_list, dim=0) + return score_map diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py b/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py new file mode 100644 index 00000000..56d4a65d --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_blocks.py @@ -0,0 +1,332 @@ +# Adapted from https://github.com/isl-org/lang-seg. +# Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. + +import torch +import torch.nn as nn + +from .lseg_vit import _make_pretrained_clip_vitl16_384, forward_vit + + +def _make_encoder( + backbone, + features, + use_pretrained=True, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout='ignore', + enable_attention_hooks=False, +): + if backbone == 'clip_vitl16_384': + clip_pretrained, pretrained = _make_pretrained_clip_vitl16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch([256, 512, 1024, 1024], + features, + groups=groups, + expand=expand) + else: + raise NotImplementedError(f"Backbone '{backbone}' not implemented") + + return clip_pretrained, pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand is True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode='bilinear', align_corners=True) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + + output = self.out_conv(output) + return output diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_model.py b/modelscope/models/cv/text_driven_segmentation/lseg_model.py new file mode 100644 index 00000000..ec381356 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_model.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.text_driven_segmentation import \ + TextDrivenSegmentation +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() +__all__ = ['TextDrivenSeg'] + + +@MODELS.register_module( + Tasks.text_driven_segmentation, + module_name=Models.text_driven_segmentation) +class TextDrivenSeg(TorchModel): + """ text driven segmentation model. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = TextDrivenSegmentation(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.load_state_dict(pretrained_params) + self.model.eval() + if device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(device_id)) + logger.info('Use GPU: {}'.format(device_id)) + else: + device_id = -1 + logger.info('Use CPU for inference') + self.device_id = device_id + + def preprocess(self, img, size=640): + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + h, w, c = img.shape + max_hw = max(h, w) + ratio = 1.0 * size / max_hw + crop_h, crop_w = int(ratio * h), int(ratio * w) + pil_img = Image.fromarray(img) + pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR) + np_img = np.array(pil_img, dtype=np.float32) / 255. + for j in range(3): + np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j] + img_pad = np.zeros((size, size, 3), dtype=np.float32) + img_pad[:crop_h, :crop_w] = np_img + img_pad = torch.from_numpy(img_pad).permute(2, 0, + 1).unsqueeze(0).float() + return img_pad, h, w, crop_h, crop_w + + def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w): + output = np.clip(tensors * 255., a_min=0, a_max=255.) + crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8) + pil_output = Image.fromarray(crop_output) + pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR) + np_output = np.array(pil_output, dtype=np.uint8) + np_output[np_output < 128] = 0 + np_output[np_output >= 128] = 255 + np_output = np.uint8(np_output) + return np_output + + def forward(self, image, text): + """ + image should be numpy array, dtype=np.uint8, shape: height*width*3 + """ + image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess( + image, size=640) + pred = self.inference(image_tensor, text) + msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=640) + outputs = {OutputKeys.MASKS: msk} + return outputs + + def inference(self, image, text): + """ + image should be tensor, 1 * 3 * 640 * 640 + """ + with torch.no_grad(): + if self.device_id == -1: + output = self.model(image, [text]) + else: + device = torch.device('cuda', self.device_id) + output = self.model(image.to(device), [text]) + output = F.interpolate(output, size=(640, 640), mode='bilinear') + output = F.softmax(output, dim=1) + output = torch.argmax(output, dim=1) + output = output[0] + if self.device_id == -1: + pred = output.data.numpy() + else: + pred = output.data.cpu().numpy() + del output + return pred diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_net.py b/modelscope/models/cv/text_driven_segmentation/lseg_net.py new file mode 100644 index 00000000..541a4a38 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_net.py @@ -0,0 +1,195 @@ +# Adapted from https://github.com/isl-org/lang-seg. +# Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. + +import numpy as np +import torch +import torch.nn as nn + +from . import clip +from .lseg_blocks import (FeatureFusionBlock, FeatureFusionBlock_custom, + Interpolate, _make_encoder, forward_vit) +from .simple_tokenizer import SimpleTokenizer + + +class depthwise_clipseg_conv(nn.Module): + + def __init__(self): + super(depthwise_clipseg_conv, self).__init__() + self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1) + + def depthwise_clipseg(self, x, channels): + x = torch.cat( + [self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], + dim=1) + return x + + def forward(self, x): + channels = x.shape[1] + out = self.depthwise_clipseg(x, channels) + return out + + +class depthwise_conv(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1): + super(depthwise_conv, self).__init__() + self.depthwise = nn.Conv2d( + 1, 1, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x): + # support for 4D tensor with NCHW + C, H, W = x.shape[1:] + x = x.reshape(-1, 1, H, W) + x = self.depthwise(x) + x = x.view(-1, C, H, W) + return x + + +class depthwise_block(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): + super(depthwise_block, self).__init__() + self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'lrelu': + self.activation = nn.LeakyReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + + def forward(self, x, act=True): + x = self.depthwise(x) + if act: + x = self.activation(x) + return x + + +class bottleneck_block(nn.Module): + + def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'): + super(bottleneck_block, self).__init__() + self.depthwise = depthwise_conv(kernel_size=3, stride=1, padding=1) + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'lrelu': + self.activation = nn.LeakyReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + + def forward(self, x, act=True): + sum_layer = x.max(dim=1, keepdim=True)[0] + x = self.depthwise(x) + x = x + sum_layer + if act: + x = self.activation(x) + return x + + +class BaseModel(torch.nn.Module): + + def load(self, path): + """Load model from file. + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if 'optimizer' in parameters: + parameters = parameters['model'] + + self.load_state_dict(parameters) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + activation=nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class LSeg(BaseModel): + + def __init__( + self, + features=256, + backbone='clip_vitl16_384', + readout='project', + use_bn=True, + model_dir=None, + ): + super(LSeg, self).__init__() + hooks = { + 'clip_vitl16_384': [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.clip_pretrained, self.pretrained, self.scratch = _make_encoder( + backbone, + features, + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.logit_scale = nn.Parameter(torch.ones([]) + * np.log(1 / 0.07)).exp() + self.out_c = 512 + self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1) + + self.scratch.output_conv = nn.Sequential( + Interpolate(scale_factor=2, mode='bilinear', align_corners=True), ) + + self.tau = 0.07 + self.model_dir = model_dir + self.tokenizer = SimpleTokenizer(model_dir + + '/bpe_simple_vocab_16e6.txt.gz') + + def forward(self, x, labelset=''): + text = clip.tokenize(self.tokenizer, labelset) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + text = text.to(x.device) + text_features = self.clip_pretrained.encode_text(text) + + image_features = self.scratch.head1(path_1) + + imshape = image_features.shape + image_features = image_features.permute(0, 2, 3, + 1).reshape(-1, self.out_c) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + logits_per_image = image_features @ text_features.t() / self.tau + + out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], + -1).permute(0, 3, 1, 2) + + out = self.scratch.output_conv(out) + + return out diff --git a/modelscope/models/cv/text_driven_segmentation/lseg_vit.py b/modelscope/models/cv/text_driven_segmentation/lseg_vit.py new file mode 100644 index 00000000..5298832f --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/lseg_vit.py @@ -0,0 +1,541 @@ +# Adapted from https://github.com/isl-org/lang-seg. +# Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. + +import math +import types + +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from . import clip + +activations = {} + + +def get_activation(name): + + def hook(model, input, output): + activations[name] = output + + return hook + + +attention = {} + + +def get_attention(name): + + def hook(module, input, output): + x = input[0] + B, N, C = x.shape + qkv = ( + module.qkv(x).reshape(B, N, 3, module.num_heads, + C // module.num_heads).permute( + 2, 0, 3, 1, 4)) + q, k, _ = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * module.scale + + attn = attn.softmax(dim=-1) # [:,:,1,1:] + attention[name] = attn + + return hook + + +def get_mean_attention_map(attn, token, shape): + attn = attn[:, :, token, 1:] + attn = attn.unflatten(2, torch.Size([shape[2] // 16, + shape[3] // 16])).float() + attn = torch.nn.functional.interpolate( + attn, size=shape[2:], mode='bicubic', align_corners=False).squeeze(0) + + all_attn = torch.mean(attn, 0) + + return all_attn + + +class Slice(nn.Module): + + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential( + nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + # encoder + _ = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations['1'] + layer_2 = pretrained.activations['2'] + layer_3 = pretrained.activations['3'] + layer_4 = pretrained.activations['4'] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size([ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ]), + )) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)]( + layer_1) + layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)]( + layer_2) + layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)]( + layer_3) + layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)]( + layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, :self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=(gs_h, gs_w), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], + w // self.patch_size[0]) + + B = x.shape[0] + + if hasattr(self.patch_embed, 'backbone'): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[ + -1] # last feature if backbone outputs list/tuple of features + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, 'dist_token', None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + gradient_checkpoint = False + for blk in self.blocks: + if gradient_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.norm(x) + + return x + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == 'ignore': + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == 'add': + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == 'project': + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float( + ) # Some weights are in torch.half, ensure it's float for sum on CPU + O, II, J, K = conv_weight.shape + if in_chans == 1: + if II > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, II // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if II != 3: + raise NotImplementedError( + 'Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, + 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +@torch.no_grad() +def _load_weights(model, checkpoint_path, prefix=''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_( + adapt_input_conv(stem.conv.weight.shape[1], + _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_( + _n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_( + _n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_( + _n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_( + _n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_( + _n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_( + _n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], + _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p( + w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens', + 1), + model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance( + model.head, nn.Linear + ) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T + for n in ('query', 'key', 'value') + ])) + block.attn.qkv.bias.copy_( + torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) + for n in ('query', 'key', 'value') + ])) + block.attn.proj.weight.copy_( + _n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_( + _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[ + 0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, + 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + +def _make_pretrained_clip_vitl16_384(pretrained, + use_readout='ignore', + hooks=None, + enable_attention_hooks=False): + clip_pretrained, _ = clip.load('ViT-B/32', device='cpu', jit=False) + + # model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + model = timm.create_model('vit_large_patch16_384', pretrained=False) + hooks = [5, 11, 17, 23] if hooks is None else hooks + pretrained = _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + return clip_pretrained, pretrained + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout='ignore', + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook( + get_activation('1')) + pretrained.model.blocks[hooks[1]].register_forward_hook( + get_activation('2')) + pretrained.model.blocks[hooks[2]].register_forward_hook( + get_activation('3')) + pretrained.model.blocks[hooks[3]].register_forward_hook( + get_activation('4')) + + pretrained.activations = activations + + if enable_attention_hooks: + pretrained.model.blocks[hooks[0]].attn.register_forward_hook( + get_attention('attn_1')) + pretrained.model.blocks[hooks[1]].attn.register_forward_hook( + get_attention('attn_2')) + pretrained.model.blocks[hooks[2]].attn.register_forward_hook( + get_attention('attn_3')) + pretrained.model.blocks[hooks[3]].attn.register_forward_hook( + get_attention('attn_4')) + pretrained.attention = attention + + readout_oper = get_readout_oper(vit_features, features, use_readout, + start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, + pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model) + + return pretrained diff --git a/modelscope/models/cv/text_driven_segmentation/model.py b/modelscope/models/cv/text_driven_segmentation/model.py new file mode 100644 index 00000000..f98d480d --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/model.py @@ -0,0 +1,456 @@ +# Adapted from https://github.com/isl-org/lang-seg. +# Originally MIT License, Copyright (c) 2021 Intelligent Systems Lab Org. + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, width, layers, heads, attn_mask=None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x1 = self.class_embedding.to(x.dtype) + x2 = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([x1 + x2, x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(ll): + if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Linear)): + ll.weight.data = ll.weight.data.half() + if ll.bias is not None: + ll.bias.data = ll.bias.data.half() + + if isinstance(ll, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(ll, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(ll, name): + attr = getattr(ll, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(): + model = CLIP(512, 224, 12, 768, 32, 77, 49408, 512, 8, 12) + convert_weights(model) + return model.eval() diff --git a/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py b/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py new file mode 100644 index 00000000..361d67c6 --- /dev/null +++ b/modelscope/models/cv/text_driven_segmentation/simple_tokenizer.py @@ -0,0 +1,155 @@ +# CLIP +# Adapted from https://github.com/openai/CLIP. +# Originally MIT License, Copyright (c) 2021 OpenAI. + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + error_list = [] + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception as err: + new_word.extend(word[i:]) + error_list.append(err) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text diff --git a/modelscope/models/cv/tinynas_classfication/__init__.py b/modelscope/models/cv/tinynas_classfication/__init__.py new file mode 100644 index 00000000..6c2f89ee --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model_zoo import get_zennet + +else: + _import_structure = { + 'model_zoo': ['get_zennet'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/tinynas_classfication/basic_blocks.py b/modelscope/models/cv/tinynas_classfication/basic_blocks.py new file mode 100644 index 00000000..50548dcc --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/basic_blocks.py @@ -0,0 +1,1309 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .global_utils import (create_netblock_list_from_str_inner, + get_right_parentheses_index) + + +class PlainNetBasicBlockClass(nn.Module): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=1, + no_create=False, + block_name=None, + **kwargs): + super(PlainNetBasicBlockClass, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.no_create = no_create + self.block_name = block_name + if self.block_name is None: + self.block_name = 'uuid{}'.format(uuid.uuid4().hex) + + def forward(self, x): + raise RuntimeError('Not implemented') + + def __str__(self): + return type(self).__name__ + '({},{},{})'.format( + self.in_channels, self.out_channels, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride) + + def get_output_resolution(self, input_resolution): + raise RuntimeError('Not implemented') + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert PlainNetBasicBlockClass.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + @classmethod + def is_instance_from_str(cls, s): + if s.startswith(cls.__name__ + '(') and s[-1] == ')': + return True + else: + return False + + +class AdaptiveAvgPool(PlainNetBasicBlockClass): + + def __init__(self, out_channels, output_size, no_create=False, **kwargs): + super(AdaptiveAvgPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.output_size = output_size + self.no_create = no_create + if not no_create: + self.netblock = nn.AdaptiveAvgPool2d( + output_size=(self.output_size, self.output_size)) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{})'.format( + self.out_channels // self.output_size**2, self.output_size) + + def __repr__(self): + return type(self).__name__ + '({}|{},{})'.format( + self.block_name, self.out_channels // self.output_size**2, + self.output_size) + + def get_output_resolution(self, input_resolution): + return self.output_size + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert AdaptiveAvgPool.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('AdaptiveAvgPool('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + out_channels = int(param_str_split[0]) + output_size = int(param_str_split[1]) + return AdaptiveAvgPool( + out_channels=out_channels, + output_size=output_size, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class BN(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(BN, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.BatchNorm2d) + self.in_channels = copy_from.weight.shape[0] + self.out_channels = copy_from.weight.shape[0] + assert out_channels is None or out_channels == self.out_channels + self.netblock = copy_from + + else: + self.in_channels = out_channels + self.out_channels = out_channels + if no_create: + return + else: + self.netblock = nn.BatchNorm2d(num_features=self.out_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'BN({})'.format(self.out_channels) + + def __repr__(self): + return 'BN({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert BN.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('BN('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + out_channels = int(param_str) + return BN( + out_channels=out_channels, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class ConvKX(PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + groups=1, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKX, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + self.groups = copy_from.groups + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + self.netblock = copy_from + else: + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = kernel_size + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=False, + groups=self.groups) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.kernel_size, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + kernel_size = int(split_str[2]) + stride = int(split_str[3]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ConvDW(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvDW, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + assert self.in_channels == self.out_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + + self.netblock = copy_from + else: + + self.in_channels = out_channels + self.out_channels = out_channels + self.stride = stride + self.kernel_size = kernel_size + + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=False, + groups=self.in_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'ConvDW({},{},{})'.format(self.out_channels, self.kernel_size, + self.stride) + + def __repr__(self): + return 'ConvDW({}|{},{},{})'.format(self.block_name, self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ConvDW.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('ConvDW('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + out_channels = int(split_str[0]) + kernel_size = int(split_str[1]) + stride = int(split_str[2]) + return ConvDW( + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ConvKXG2(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG2, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=2, + **kwargs) + + +class ConvKXG4(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG4, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=4, + **kwargs) + + +class ConvKXG8(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG8, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=8, + **kwargs) + + +class ConvKXG16(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG16, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=16, + **kwargs) + + +class ConvKXG32(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG32, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=32, + **kwargs) + + +class Flatten(PlainNetBasicBlockClass): + + def __init__(self, out_channels, no_create=False, **kwargs): + super(Flatten, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.no_create = no_create + + def forward(self, x): + return torch.flatten(x, 1) + + def __str__(self): + return 'Flatten({})'.format(self.out_channels) + + def __repr__(self): + return 'Flatten({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return 1 + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Flatten.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Flatten('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return Flatten( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class Linear(PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + bias=True, + copy_from=None, + no_create=False, + **kwargs): + super(Linear, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Linear) + self.in_channels = copy_from.weight.shape[1] + self.out_channels = copy_from.weight.shape[0] + self.use_bias = copy_from.bias is not None + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + + self.netblock = copy_from + else: + + self.in_channels = in_channels + self.out_channels = out_channels + self.use_bias = bias + if not no_create: + self.netblock = nn.Linear( + self.in_channels, self.out_channels, bias=self.use_bias) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'Linear({},{},{})'.format(self.in_channels, self.out_channels, + int(self.use_bias)) + + def __repr__(self): + return 'Linear({}|{},{},{})'.format(self.block_name, self.in_channels, + self.out_channels, + int(self.use_bias)) + + def get_output_resolution(self, input_resolution): + assert input_resolution == 1 + return 1 + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Linear.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Linear('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + use_bias = int(split_str[2]) + + return Linear( + in_channels=in_channels, + out_channels=out_channels, + bias=use_bias == 1, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class MaxPool(PlainNetBasicBlockClass): + + def __init__(self, + out_channels, + kernel_size, + stride, + no_create=False, + **kwargs): + super(MaxPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = (kernel_size - 1) // 2 + self.no_create = no_create + if not no_create: + self.netblock = nn.MaxPool2d( + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'MaxPool({},{},{})'.format(self.out_channels, self.kernel_size, + self.stride) + + def __repr__(self): + return 'MaxPool({}|{},{},{})'.format(self.block_name, + self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MaxPool.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MaxPool('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + out_channels = int(param_str_split[0]) + kernel_size = int(param_str_split[1]) + stride = int(param_str_split[2]) + return MaxPool( + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class Sequential(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(Sequential, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = block_list[0].in_channels + self.out_channels = block_list[-1].out_channels + self.no_create = no_create + res = 1024 + for block in self.block_list: + res = block.get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output = x + for inner_block in self.block_list: + output = inner_block(output) + return output + + def __str__(self): + s = 'Sequential(' + for inner_block in self.block_list: + s += str(inner_block) + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Sequential.is_instance_from_str(s) + the_right_paraen_idx = get_right_parentheses_index(s) + param_str = s[len('Sequential(') + 1:the_right_paraen_idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, netblocks_dict=bottom_basic_dict, no_create=no_create) + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, '' + return Sequential( + block_list=the_block_list, + no_create=no_create, + block_name=tmp_block_name), '' + + +class MultiSumBlock(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(MultiSumBlock, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = np.max([x.in_channels for x in block_list]) + self.out_channels = np.max([x.out_channels for x in block_list]) + self.no_create = no_create + + res = 1024 + res = self.block_list[0].get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output = self.block_list[0](x) + for inner_block in self.block_list[1:]: + output2 = inner_block(x) + output = output + output2 + return output + + def __str__(self): + s = 'MultiSumBlock({}|'.format(self.block_name) + for inner_block in self.block_list: + s += str(inner_block) + ';' + s = s[:-1] + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = self.block_list[0].get_output_resolution(input_resolution) + for the_block in self.block_list: + assert the_res == the_block.get_output_resolution(input_resolution) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MultiSumBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MultiSumBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_s = param_str + + the_block_list = [] + while len(the_s) > 0: + tmp_block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, netblocks_dict=bottom_basic_dict, no_create=no_create) + the_s = remaining_s + if tmp_block_list is None: + pass + elif len(tmp_block_list) == 1: + the_block_list.append(tmp_block_list[0]) + else: + the_block_list.append( + Sequential(block_list=tmp_block_list, no_create=no_create)) + pass + + if len(the_block_list) == 0: + return None, s[idx + 1:] + + return MultiSumBlock( + block_list=the_block_list, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class MultiCatBlock(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(MultiCatBlock, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = np.max([x.in_channels for x in block_list]) + self.out_channels = np.sum([x.out_channels for x in block_list]) + self.no_create = no_create + + res = 1024 + res = self.block_list[0].get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output_list = [] + for inner_block in self.block_list: + output = inner_block(x) + output_list.append(output) + + return torch.cat(output_list, dim=1) + + def __str__(self): + s = 'MultiCatBlock({}|'.format(self.block_name) + for inner_block in self.block_list: + s += str(inner_block) + ';' + + s = s[:-1] + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = self.block_list[0].get_output_resolution(input_resolution) + for the_block in self.block_list: + assert the_res == the_block.get_output_resolution(input_resolution) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MultiCatBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MultiCatBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_s = param_str + + the_block_list = [] + while len(the_s) > 0: + tmp_block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, netblocks_dict=bottom_basic_dict, no_create=no_create) + the_s = remaining_s + if tmp_block_list is None: + pass + elif len(tmp_block_list) == 1: + the_block_list.append(tmp_block_list[0]) + else: + the_block_list.append( + Sequential(block_list=tmp_block_list, no_create=no_create)) + + if len(the_block_list) == 0: + return None, s[idx + 1:] + + return MultiCatBlock( + block_list=the_block_list, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class RELU(PlainNetBasicBlockClass): + + def __init__(self, out_channels, no_create=False, **kwargs): + super(RELU, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.no_create = no_create + + def forward(self, x): + return F.relu(x) + + def __str__(self): + return 'RELU({})'.format(self.out_channels) + + def __repr__(self): + return 'RELU({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert RELU.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('RELU('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return RELU( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ResBlock(PlainNetBasicBlockClass): + """ + ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels + """ + + def __init__(self, + block_list, + in_channels=None, + stride=None, + no_create=False, + **kwargs): + super(ResBlock, self).__init__(**kwargs) + self.block_list = block_list + self.stride = stride + self.no_create = no_create + if not no_create: + self.module_list = nn.ModuleList(block_list) + + if in_channels is None: + self.in_channels = block_list[0].in_channels + else: + self.in_channels = in_channels + self.out_channels = block_list[-1].out_channels + + if self.stride is None: + tmp_input_res = 1024 + tmp_output_res = self.get_output_resolution(tmp_input_res) + self.stride = tmp_input_res // tmp_output_res + + self.proj = None + if self.stride > 1 or self.in_channels != self.out_channels: + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + + def forward(self, x): + if len(self.block_list) == 0: + return x + + output = x + for inner_block in self.block_list: + output = inner_block(output) + + if self.proj is not None: + output = output + self.proj(x) + else: + output = output + x + + return output + + def __str__(self): + s = 'ResBlock({},{},'.format(self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def __repr__(self): + s = 'ResBlock({}|{},{},'.format(self.block_name, self.in_channels, + self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ResBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + the_stride = None + param_str = s[len('ResBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + first_comma_index = param_str.find(',') + if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit( + ): + in_channels = None + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + in_channels = int(param_str[0:first_comma_index]) + param_str = param_str[first_comma_index + 1:] + second_comma_index = param_str.find(',') + if second_comma_index < 0 or not param_str[ + 0:second_comma_index].isdigit(): + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + the_stride = int(param_str[0:second_comma_index]) + param_str = param_str[second_comma_index + 1:] + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + pass + pass + + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, s[idx + 1:] + return ResBlock( + block_list=the_block_list, + in_channels=in_channels, + stride=the_stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ResBlockProj(PlainNetBasicBlockClass): + """ + ResBlockProj(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels + """ + + def __init__(self, + block_list, + in_channels=None, + stride=None, + no_create=False, + **kwargs): + super(ResBlockProj, self).__init__(**kwargs) + self.block_list = block_list + self.stride = stride + self.no_create = no_create + if not no_create: + self.module_list = nn.ModuleList(block_list) + + if in_channels is None: + self.in_channels = block_list[0].in_channels + else: + self.in_channels = in_channels + self.out_channels = block_list[-1].out_channels + + if self.stride is None: + tmp_input_res = 1024 + tmp_output_res = self.get_output_resolution(tmp_input_res) + self.stride = tmp_input_res // tmp_output_res + + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + + def forward(self, x): + if len(self.block_list) == 0: + return x + + output = x + for inner_block in self.block_list: + output = inner_block(output) + output = output + self.proj(x) + return output + + def __str__(self): + s = 'ResBlockProj({},{},'.format(self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def __repr__(self): + s = 'ResBlockProj({}|{},{},'.format(self.block_name, self.in_channels, + self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ResBlockProj.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + the_stride = None + param_str = s[len('ResBlockProj('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + first_comma_index = param_str.find(',') + if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit( + ): + in_channels = None + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + in_channels = int(param_str[0:first_comma_index]) + param_str = param_str[first_comma_index + 1:] + second_comma_index = param_str.find(',') + if second_comma_index < 0 or not param_str[ + 0:second_comma_index].isdigit(): + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + the_stride = int(param_str[0:second_comma_index]) + param_str = param_str[second_comma_index + 1:] + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + pass + pass + + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, s[idx + 1:] + return ResBlockProj( + block_list=the_block_list, + in_channels=in_channels, + stride=the_stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class SE(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(SE, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + raise RuntimeError('Not implemented') + else: + self.in_channels = out_channels + self.out_channels = out_channels + self.se_ratio = 0.25 + self.se_channels = max( + 1, int(round(self.out_channels * self.se_ratio))) + if no_create or self.out_channels == 0: + return + else: + self.netblock = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.se_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), nn.BatchNorm2d(self.se_channels), + nn.ReLU(), + nn.Conv2d( + in_channels=self.se_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), nn.BatchNorm2d(self.out_channels), + nn.Sigmoid()) + + def forward(self, x): + se_x = self.netblock(x) + return se_x * x + + def __str__(self): + return 'SE({})'.format(self.out_channels) + + def __repr__(self): + return 'SE({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert SE.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('SE('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return SE( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class SwishImplementation(torch.autograd.Function): + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(Swish, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + raise RuntimeError('Not implemented') + else: + self.in_channels = out_channels + self.out_channels = out_channels + + def forward(self, x): + return SwishImplementation.apply(x) + + def __str__(self): + return 'Swish({})'.format(self.out_channels) + + def __repr__(self): + return 'Swish({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Swish.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Swish('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return Swish( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +bottom_basic_dict = { + 'AdaptiveAvgPool': AdaptiveAvgPool, + 'BN': BN, + 'ConvDW': ConvDW, + 'ConvKX': ConvKX, + 'ConvKXG2': ConvKXG2, + 'ConvKXG4': ConvKXG4, + 'ConvKXG8': ConvKXG8, + 'ConvKXG16': ConvKXG16, + 'ConvKXG32': ConvKXG32, + 'Flatten': Flatten, + 'Linear': Linear, + 'MaxPool': MaxPool, + 'PlainNetBasicBlockClass': PlainNetBasicBlockClass, + 'RELU': RELU, + 'SE': SE, + 'Swish': Swish, +} + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'MultiSumBlock': MultiSumBlock, + 'MultiCatBlock': MultiCatBlock, + 'ResBlock': ResBlock, + 'ResBlockProj': ResBlockProj, + 'Sequential': Sequential, + } + netblocks_dict.update(this_py_file_netblocks_dict) + netblocks_dict.update(bottom_basic_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/global_utils.py b/modelscope/models/cv/tinynas_classfication/global_utils.py new file mode 100644 index 00000000..022c61a0 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/global_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + + +def smart_round(x, base=None): + if base is None: + if x > 32 * 8: + round_base = 32 + elif x > 16 * 8: + round_base = 16 + else: + round_base = 8 + else: + round_base = base + + return max(round_base, round(x / float(round_base)) * round_base) + + +def get_right_parentheses_index(s): + left_paren_count = 0 + for index, x in enumerate(s): + + if x == '(': + left_paren_count += 1 + elif x == ')': + left_paren_count -= 1 + if left_paren_count == 0: + return index + else: + pass + return None + + +def create_netblock_list_from_str_inner(s, + no_create=False, + netblocks_dict=None, + **kwargs): + block_list = [] + while len(s) > 0: + is_found_block_class = False + for the_block_class_name in netblocks_dict.keys(): + tmp_idx = s.find('(') + if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: + is_found_block_class = True + the_block_class = netblocks_dict[the_block_class_name] + the_block, remaining_s = the_block_class.create_from_str( + s, no_create=no_create, **kwargs) + if the_block is not None: + block_list.append(the_block) + s = remaining_s + if len(s) > 0 and s[0] == ';': + return block_list, s[1:] + break + assert is_found_block_class + return block_list, '' + + +def create_netblock_list_from_str(s, + no_create=False, + netblocks_dict=None, + **kwargs): + the_list, remaining_s = create_netblock_list_from_str_inner( + s, no_create=no_create, netblocks_dict=netblocks_dict, **kwargs) + assert len(remaining_s) == 0 + return the_list diff --git a/modelscope/models/cv/tinynas_classfication/master_net.py b/modelscope/models/cv/tinynas_classfication/master_net.py new file mode 100644 index 00000000..e2bc47e0 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/master_net.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import torch +import torch.nn.functional as F +from torch import nn + +from . import basic_blocks, plain_net_utils + + +class PlainNet(plain_net_utils.PlainNet): + + def __init__(self, + argv=None, + opt=None, + num_classes=None, + plainnet_struct=None, + no_create=False, + no_reslink=None, + no_BN=None, + use_se=None, + dropout=None, + **kwargs): + + module_opt = None + + if no_BN is None: + if module_opt is not None: + no_BN = module_opt.no_BN + else: + no_BN = False + + if no_reslink is None: + if module_opt is not None: + no_reslink = module_opt.no_reslink + else: + no_reslink = False + + if use_se is None: + if module_opt is not None: + use_se = module_opt.use_se + else: + use_se = False + + if dropout is None: + if module_opt is not None: + self.dropout = module_opt.dropout + else: + self.dropout = None + else: + self.dropout = dropout + + super(PlainNet, self).__init__( + argv=argv, + opt=opt, + num_classes=num_classes, + plainnet_struct=plainnet_struct, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + use_se=use_se, + **kwargs) + self.last_channels = self.block_list[-1].out_channels + self.fc_linear = basic_blocks.Linear( + in_channels=self.last_channels, + out_channels=self.num_classes, + no_create=no_create) + + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + for layer in self.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.eps = 1e-3 + + def forward(self, x): + output = x + for block_id, the_block in enumerate(self.block_list): + output = the_block(output) + if self.dropout is not None: + dropout_p = float(block_id) / len( + self.block_list) * self.dropout + output = F.dropout( + output, dropout_p, training=self.training, inplace=True) + + output = F.adaptive_avg_pool2d(output, output_size=1) + if self.dropout is not None: + output = F.dropout( + output, self.dropout, training=self.training, inplace=True) + output = torch.flatten(output, 1) + output = self.fc_linear(output) + return output diff --git a/modelscope/models/cv/tinynas_classfication/model_zoo.py b/modelscope/models/cv/tinynas_classfication/model_zoo.py new file mode 100644 index 00000000..a49b053b --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/model_zoo.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from . import master_net + + +def get_zennet(): + model_plainnet_str = ( + 'SuperConvK3BNRELU(3,32,2,1)' + 'SuperResK1K5K1(32,80,2,32,1)SuperResK1K7K1(80,432,2,128,5)' + 'SuperResK1K7K1(432,640,2,192,3)SuperResK1K7K1(640,1008,1,160,5)' + 'SuperResK1K7K1(1008,976,1,160,4)SuperResK1K5K1(976,2304,2,384,5)' + 'SuperResK1K5K1(2304,2496,1,384,5)SuperConvK1BNRELU(2496,3072,1,1)') + use_SE = False + num_classes = 1000 + + model = master_net.PlainNet( + num_classes=num_classes, + plainnet_struct=model_plainnet_str, + use_se=use_SE) + + return model diff --git a/modelscope/models/cv/tinynas_classfication/plain_net_utils.py b/modelscope/models/cv/tinynas_classfication/plain_net_utils.py new file mode 100644 index 00000000..844535ed --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/plain_net_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from torch import nn + +from . import (basic_blocks, super_blocks, super_res_idwexkx, super_res_k1kxk1, + super_res_kxkx) +from .global_utils import create_netblock_list_from_str_inner + + +class PlainNet(nn.Module): + + def __init__(self, + argv=None, + opt=None, + num_classes=None, + plainnet_struct=None, + no_create=False, + **kwargs): + super(PlainNet, self).__init__() + self.argv = argv + self.opt = opt + self.num_classes = num_classes + self.plainnet_struct = plainnet_struct + + self.module_opt = None + + if self.num_classes is None: + self.num_classes = self.module_opt.num_classes + + if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None: + self.plainnet_struct = self.module_opt.plainnet_struct + + if self.plainnet_struct is None: + if hasattr(opt, 'plainnet_struct_txt' + ) and opt.plainnet_struct_txt is not None: + plainnet_struct_txt = opt.plainnet_struct_txt + else: + plainnet_struct_txt = self.module_opt.plainnet_struct_txt + + if plainnet_struct_txt is not None: + with open(plainnet_struct_txt, 'r') as fid: + the_line = fid.readlines()[0].strip() + self.plainnet_struct = the_line + pass + + if self.plainnet_struct is None: + return + + the_s = self.plainnet_struct + + block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, + netblocks_dict=_all_netblocks_dict_, + no_create=no_create, + **kwargs) + assert len(remaining_s) == 0 + + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + + def forward(self, x): + output = x + for the_block in self.block_list: + output = the_block(output) + return output + + def __str__(self): + s = '' + for the_block in self.block_list: + s += str(the_block) + return s + + def __repr__(self): + return str(self) + + +_all_netblocks_dict_ = {} +_all_netblocks_dict_ = basic_blocks.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_blocks.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_kxkx.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_k1kxk1.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_idwexkx.register_netblocks_dict( + _all_netblocks_dict_) diff --git a/modelscope/models/cv/tinynas_classfication/super_blocks.py b/modelscope/models/cv/tinynas_classfication/super_blocks.py new file mode 100644 index 00000000..25862255 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_blocks.py @@ -0,0 +1,228 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index + + +class PlainNetSuperBlockClass(basic_blocks.PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(PlainNetSuperBlockClass, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.no_create = no_create + self.block_list = None + self.module_list = None + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.sub_layers) + + def get_output_resolution(self, input_resolution): + resolution = input_resolution + for block in self.block_list: + resolution = block.get_output_resolution(resolution) + return resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + sub_layers = int(param_str_split[3]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + block_name=tmp_block_name, + no_create=no_create, + **kwargs), s[idx + 1:] + + +class SuperConvKXBNRELU(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + **kwargs): + super(SuperConvKXBNRELU, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + if not self.no_BN: + inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format( + last_channels, self.out_channels, self.kernel_size, + current_stride, self.out_channels, self.out_channels) + else: + inner_str = 'ConvKX({},{},{},{})RELU({})'.format( + last_channels, self.out_channels, self.kernel_size, + current_stride, self.out_channels) + full_str += inner_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + no_create=no_create, + netblocks_dict=netblocks_dict, + no_reslink=no_reslink, + no_BN=no_BN) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.sub_layers, self.kernel_size) + + +class SuperConvK1BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK1BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=1, + no_create=no_create, + **kwargs) + + +class SuperConvK3BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK3BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperConvK5BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK5BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperConvK7BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK7BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperConvK1BNRELU': SuperConvK1BNRELU, + 'SuperConvK3BNRELU': SuperConvK3BNRELU, + 'SuperConvK5BNRELU': SuperConvK5BNRELU, + 'SuperConvK7BNRELU': SuperConvK7BNRELU, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py b/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py new file mode 100644 index 00000000..7d005069 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py @@ -0,0 +1,451 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResIDWEXKX(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + expension=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResIDWEXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.expension = expension + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + dw_channels = global_utils.smart_round( + self.bottleneck_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, + bottleneck_channels, 1, + 1) + if not self.no_BN: + inner_str += 'BN({})'.format(bottleneck_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, self.out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, self.out_channels) + + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + + inner_str = '' + dw_channels = global_utils.smart_round( + self.out_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(bottleneck_channels, + dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, self.out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResIDWE1K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=6.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE1K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=6.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE1K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=6.0, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResIDWE1K3': SuperResIDWE1K3, + 'SuperResIDWE2K3': SuperResIDWE2K3, + 'SuperResIDWE4K3': SuperResIDWE4K3, + 'SuperResIDWE6K3': SuperResIDWE6K3, + 'SuperResIDWE1K5': SuperResIDWE1K5, + 'SuperResIDWE2K5': SuperResIDWE2K5, + 'SuperResIDWE4K5': SuperResIDWE4K5, + 'SuperResIDWE6K5': SuperResIDWE6K5, + 'SuperResIDWE1K7': SuperResIDWE1K7, + 'SuperResIDWE2K7': SuperResIDWE2K7, + 'SuperResIDWE4K7': SuperResIDWE4K7, + 'SuperResIDWE6K7': SuperResIDWE6K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py b/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py new file mode 100644 index 00000000..3ca68742 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResK1KXK1(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResK1KXK1, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + self.bottleneck_channels, + 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.bottleneck_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + inner_str = '' + inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, + self.bottleneck_channels, + 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.bottleneck_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResK1K3K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K3K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperResK1K5K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K5K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperResK1K7K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K7K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK1K3K1': SuperResK1K3K1, + 'SuperResK1K5K1': SuperResK1K5K1, + 'SuperResK1K7K1': SuperResK1K7K1, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py b/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py new file mode 100644 index 00000000..a694fdbe --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py @@ -0,0 +1,202 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResKXKX(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResKXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + self.bottleneck_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResK3K3(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK3K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperResK5K5(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK5K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperResK7K7(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK7K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK3K3': SuperResK3K3, + 'SuperResK5K5': SuperResK5K5, + 'SuperResK7K7': SuperResK7K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_detection/__init__.py b/modelscope/models/cv/tinynas_detection/__init__.py new file mode 100644 index 00000000..6d696ac4 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .tinynas_detector import Tinynas_detector + from .tinynas_damoyolo import DamoYolo + +else: + _import_structure = { + 'tinynas_detector': ['TinynasDetector'], + 'tinynas_damoyolo': ['DamoYolo'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/tinynas_detection/backbone/__init__.py b/modelscope/models/cv/tinynas_detection/backbone/__init__.py new file mode 100644 index 00000000..186d06a3 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import copy + +from .darknet import CSPDarknet +from .tinynas import load_tinynas_net + + +def build_backbone(cfg): + backbone_cfg = copy.deepcopy(cfg) + name = backbone_cfg.pop('name') + if name == 'CSPDarknet': + return CSPDarknet(**backbone_cfg) + elif name == 'TinyNAS': + return load_tinynas_net(backbone_cfg) diff --git a/modelscope/models/cv/tinynas_detection/backbone/darknet.py b/modelscope/models/cv/tinynas_detection/backbone/darknet.py new file mode 100644 index 00000000..d3294f0d --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/darknet.py @@ -0,0 +1,126 @@ +# Copyright (c) Megvii Inc. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import torch +from torch import nn + +from ..core.base_ops import (BaseConv, CSPLayer, DWConv, Focus, ResLayer, + SPPBottleneck) + + +class CSPDarknet(nn.Module): + + def __init__( + self, + dep_mul, + wid_mul, + out_features=('dark3', 'dark4', 'dark5'), + depthwise=False, + act='silu', + reparam=False, + ): + super(CSPDarknet, self).__init__() + assert out_features, 'please provide output features of Darknet' + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + # stem + # self.stem = Focus(3, base_channels, ksize=3, act=act) + self.stem = Focus(3, base_channels, 3, act=act) + + # dark2 + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer( + base_channels * 2, + base_channels * 2, + n=base_depth, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark3 + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer( + base_channels * 4, + base_channels * 4, + n=base_depth * 3, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark4 + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer( + base_channels * 8, + base_channels * 8, + n=base_depth * 3, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + # dark5 + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck( + base_channels * 16, base_channels * 16, activation=act), + CSPLayer( + base_channels * 16, + base_channels * 16, + n=base_depth, + shortcut=False, + depthwise=depthwise, + act=act, + reparam=reparam, + ), + ) + + def init_weights(self, pretrain=None): + + if pretrain is None: + return + else: + pretrained_dict = torch.load( + pretrain, map_location='cpu')['state_dict'] + new_params = self.state_dict().copy() + for k, v in pretrained_dict.items(): + ks = k.split('.') + if ks[0] == 'fc' or ks[-1] == 'total_ops' or ks[ + -1] == 'total_params': + continue + else: + new_params[k] = v + + self.load_state_dict(new_params) + print(f' load pretrain backbone from {pretrain}') + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs['stem'] = x + x = self.dark2(x) + outputs['dark2'] = x + x = self.dark3(x) + outputs['dark3'] = x + x = self.dark4(x) + outputs['dark4'] = x + x = self.dark5(x) + outputs['dark5'] = x + features_out = [ + outputs['stem'], outputs['dark2'], outputs['dark3'], + outputs['dark4'], outputs['dark5'] + ] + + return features_out diff --git a/modelscope/models/cv/tinynas_detection/backbone/tinynas.py b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py new file mode 100755 index 00000000..202bdd55 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/backbone/tinynas.py @@ -0,0 +1,359 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import torch +import torch.nn as nn + +from modelscope.utils.file_utils import read_file +from ..core.base_ops import Focus, SPPBottleneck, get_activation +from ..core.repvgg_block import RepVggBlock + + +class ConvKXBN(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride): + super(ConvKXBN, self).__init__() + self.conv1 = nn.Conv2d( + in_c, + out_c, + kernel_size, + stride, (kernel_size - 1) // 2, + groups=1, + bias=False) + self.bn1 = nn.BatchNorm2d(out_c) + + def forward(self, x): + return self.bn1(self.conv1(x)) + + +class ConvKXBNRELU(nn.Module): + + def __init__(self, in_c, out_c, kernel_size, stride, act='silu'): + super(ConvKXBNRELU, self).__init__() + self.conv = ConvKXBN(in_c, out_c, kernel_size, stride) + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + def forward(self, x): + output = self.conv(x) + return self.activation_function(output) + + +class ResConvK1KX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + force_resproj=False, + act='silu', + reparam=False): + super(ResConvK1KX, self).__init__() + self.stride = stride + self.conv1 = ConvKXBN(in_c, btn_c, 1, 1) + if not reparam: + self.conv2 = ConvKXBN(btn_c, out_c, 3, stride) + else: + self.conv2 = RepVggBlock( + btn_c, out_c, kernel_size, stride, act='identity') + + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + if stride == 2: + self.residual_downsample = nn.AvgPool2d(kernel_size=2, stride=2) + else: + self.residual_downsample = nn.Identity() + + if in_c != out_c or force_resproj: + self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) + else: + self.residual_proj = nn.Identity() + + def forward(self, x): + if self.stride != 2: + reslink = self.residual_downsample(x) + reslink = self.residual_proj(reslink) + + output = x + output = self.conv1(output) + output = self.activation_function(output) + output = self.conv2(output) + if self.stride != 2: + output = output + reslink + output = self.activation_function(output) + + return output + + +class SuperResConvK1KX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + num_blocks, + with_spp=False, + act='silu', + reparam=False): + super(SuperResConvK1KX, self).__init__() + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(num_blocks): + if block_id == 0: + in_channels = in_c + out_channels = out_c + this_stride = stride + force_resproj = False # as a part of CSPLayer, DO NOT need this flag + this_kernel_size = kernel_size + else: + in_channels = out_c + out_channels = out_c + this_stride = 1 + force_resproj = False + this_kernel_size = kernel_size + the_block = ResConvK1KX( + in_channels, + out_channels, + btn_c, + this_kernel_size, + this_stride, + force_resproj, + act=act, + reparam=reparam) + self.block_list.append(the_block) + if block_id == 0 and with_spp: + self.block_list.append( + SPPBottleneck(out_channels, out_channels)) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class ResConvKXKX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + force_resproj=False, + act='silu'): + super(ResConvKXKX, self).__init__() + self.stride = stride + if self.stride == 2: + self.downsampler = ConvKXBNRELU(in_c, out_c, 3, 2, act=act) + else: + self.conv1 = ConvKXBN(in_c, btn_c, kernel_size, 1) + self.conv2 = RepVggBlock( + btn_c, out_c, kernel_size, stride, act='identity') + + if act is None: + self.activation_function = torch.relu + else: + self.activation_function = get_activation(act) + + if stride == 2: + self.residual_downsample = nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.residual_downsample = nn.Identity() + + if in_c != out_c or force_resproj: + self.residual_proj = ConvKXBN(in_c, out_c, 1, 1) + else: + self.residual_proj = nn.Identity() + + def forward(self, x): + if self.stride == 2: + return self.downsampler(x) + reslink = self.residual_downsample(x) + reslink = self.residual_proj(reslink) + + output = x + output = self.conv1(output) + output = self.activation_function(output) + output = self.conv2(output) + + output = output + reslink + output = self.activation_function(output) + + return output + + +class SuperResConvKXKX(nn.Module): + + def __init__(self, + in_c, + out_c, + btn_c, + kernel_size, + stride, + num_blocks, + with_spp=False, + act='silu'): + super(SuperResConvKXKX, self).__init__() + if act is None: + self.act = torch.relu + else: + self.act = get_activation(act) + self.block_list = nn.ModuleList() + for block_id in range(num_blocks): + if block_id == 0: + in_channels = in_c + out_channels = out_c + this_stride = stride + force_resproj = False # as a part of CSPLayer, DO NOT need this flag + this_kernel_size = kernel_size + else: + in_channels = out_c + out_channels = out_c + this_stride = 1 + force_resproj = False + this_kernel_size = kernel_size + the_block = ResConvKXKX( + in_channels, + out_channels, + btn_c, + this_kernel_size, + this_stride, + force_resproj, + act=act) + self.block_list.append(the_block) + if block_id == 0 and with_spp: + self.block_list.append( + SPPBottleneck(out_channels, out_channels)) + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + +class TinyNAS(nn.Module): + + def __init__(self, + structure_info=None, + out_indices=[0, 1, 2, 4, 5], + out_channels=[None, None, 128, 256, 512], + with_spp=False, + use_focus=False, + need_conv1=True, + act='silu', + reparam=False): + super(TinyNAS, self).__init__() + assert len(out_indices) == len(out_channels) + self.out_indices = out_indices + self.need_conv1 = need_conv1 + + self.block_list = nn.ModuleList() + if need_conv1: + self.conv1_list = nn.ModuleList() + for idx, block_info in enumerate(structure_info): + the_block_class = block_info['class'] + if the_block_class == 'ConvKXBNRELU': + if use_focus: + the_block = Focus( + block_info['in'], + block_info['out'], + block_info['k'], + act=act) + else: + the_block = ConvKXBNRELU( + block_info['in'], + block_info['out'], + block_info['k'], + block_info['s'], + act=act) + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvK1KX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResConvK1KX( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act, + reparam=reparam) + self.block_list.append(the_block) + elif the_block_class == 'SuperResConvKXKX': + spp = with_spp if idx == len(structure_info) - 1 else False + the_block = SuperResConvKXKX( + block_info['in'], + block_info['out'], + block_info['btn'], + block_info['k'], + block_info['s'], + block_info['L'], + spp, + act=act) + self.block_list.append(the_block) + if need_conv1: + if idx in self.out_indices and out_channels[ + self.out_indices.index(idx)] is not None: + self.conv1_list.append( + nn.Conv2d(block_info['out'], + out_channels[self.out_indices.index(idx)], + 1)) + else: + self.conv1_list.append(None) + + def init_weights(self, pretrain=None): + pass + + def forward(self, x): + output = x + stage_feature_list = [] + for idx, block in enumerate(self.block_list): + output = block(output) + if idx in self.out_indices: + if self.need_conv1 and self.conv1_list[idx] is not None: + true_out = self.conv1_list[idx](output) + stage_feature_list.append(true_out) + else: + stage_feature_list.append(output) + return stage_feature_list + + +def load_tinynas_net(backbone_cfg): + # load masternet model to path + import ast + net_structure_str = read_file(backbone_cfg.structure_file) + struct_str = ''.join([x.strip() for x in net_structure_str]) + struct_info = ast.literal_eval(struct_str) + for layer in struct_info: + if 'nbitsA' in layer: + del layer['nbitsA'] + if 'nbitsW' in layer: + del layer['nbitsW'] + + model = TinyNAS( + structure_info=struct_info, + out_indices=backbone_cfg.out_indices, + out_channels=backbone_cfg.out_channels, + with_spp=backbone_cfg.with_spp, + use_focus=backbone_cfg.use_focus, + act=backbone_cfg.act, + need_conv1=backbone_cfg.need_conv1, + reparam=backbone_cfg.reparam) + + return model diff --git a/modelscope/models/cv/tinynas_detection/core/__init__.py b/modelscope/models/cv/tinynas_detection/core/__init__.py new file mode 100644 index 00000000..3dad5e72 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. diff --git a/modelscope/models/cv/tinynas_detection/core/base_ops.py b/modelscope/models/cv/tinynas_detection/core/base_ops.py new file mode 100644 index 00000000..62729ca2 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/base_ops.py @@ -0,0 +1,474 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .repvgg_block import RepVggBlock + + +class SiLU(nn.Module): + """export-friendly version of nn.SiLU()""" + + @staticmethod + def forward(x): + return x * torch.sigmoid(x) + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +def get_norm(name, out_channels, inplace=True): + if name == 'bn': + module = nn.BatchNorm2d(out_channels) + elif name == 'gn': + module = nn.GroupNorm(num_channels=out_channels, num_groups=32) + return module + + +class BaseConv(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__(self, + in_channels, + out_channels, + ksize, + stride=1, + groups=1, + bias=False, + act='silu', + norm='bn'): + super().__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + if norm is not None: + self.bn = get_norm(norm, out_channels, inplace=True) + if act is not None: + self.act = get_activation(act, inplace=True) + self.with_norm = norm is not None + self.with_act = act is not None + + def forward(self, x): + x = self.conv(x) + if self.with_norm: + # x = self.norm(x) + x = self.bn(x) + if self.with_act: + x = self.act(x) + return x + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class DepthWiseConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + ksize, + stride=1, + groups=None, + bias=False, + act='silu', + norm='bn'): + super().__init__() + padding = (ksize - 1) // 2 + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size=ksize, + stride=stride, + padding=padding, + groups=in_channels, + bias=bias, + ) + + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias) + if norm is not None: + self.dwnorm = get_norm(norm, in_channels, inplace=True) + self.pwnorm = get_norm(norm, out_channels, inplace=True) + if act is not None: + self.act = get_activation(act, inplace=True) + + self.with_norm = norm is not None + self.with_act = act is not None + self.order = ['depthwise', 'dwnorm', 'pointwise', 'act'] + + def forward(self, x): + + for layer_name in self.order: + layer = self.__getattr__(layer_name) + if layer is not None: + x = layer(x) + return x + + +class DWConv(nn.Module): + """Depthwise Conv + Conv""" + + def __init__(self, in_channels, out_channels, ksize, stride=1, act='silu'): + super().__init__() + self.dconv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + act=act, + ) + self.pconv = BaseConv( + in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + reparam=False, + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + k_conv1 = 3 if reparam else 1 + self.conv1 = BaseConv( + in_channels, hidden_channels, k_conv1, stride=1, act=act) + if reparam: + self.conv2 = RepVggBlock( + hidden_channels, out_channels, 3, stride=1, act=act) + else: + self.conv2 = Conv( + hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + 'Residual layer with `in_channels` inputs.' + + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv( + in_channels, mid_channels, ksize=1, stride=1, act='lrelu') + self.layer2 = BaseConv( + mid_channels, in_channels, ksize=3, stride=1, act='lrelu') + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, + in_channels, + out_channels, + kernel_sizes=(5, 9, 13), + activation='silu'): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv( + conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + reparam=False, + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv3 = BaseConv( + 2 * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck( + hidden_channels, + hidden_channels, + shortcut, + 1.0, + depthwise, + act=act, + reparam=reparam) for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, + in_channels, + out_channels, + ksize=1, + stride=1, + act='silu'): + super().__init__() + self.conv = BaseConv( + in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class fast_Focus(nn.Module): + + def __init__(self, + in_channels, + out_channels, + ksize=1, + stride=1, + act='silu'): + super(Focus, self).__init__() + self.conv1 = self.focus_conv(w1=1.0) + self.conv2 = self.focus_conv(w3=1.0) + self.conv3 = self.focus_conv(w2=1.0) + self.conv4 = self.focus_conv(w4=1.0) + + self.conv = BaseConv( + in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + return self.conv( + torch.cat( + [self.conv1(x), + self.conv2(x), + self.conv3(x), + self.conv4(x)], 1)) + + def focus_conv(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0): + conv = nn.Conv2d(3, 3, 2, 2, groups=3, bias=False) + conv.weight = self.init_weights_constant(w1, w2, w3, w4) + conv.weight.requires_grad = False + return conv + + def init_weights_constant(self, w1=0.0, w2=0.0, w3=0.0, w4=0.0): + return nn.Parameter( + torch.tensor([[[[w1, w2], [w3, w4]]], [[[w1, w2], [w3, w4]]], + [[[w1, w2], [w3, w4]]]])) + + +# shufflenet block +def channel_shuffle(x, groups=2): + bat_size, channels, w, h = x.shape + group_c = channels // groups + x = x.view(bat_size, groups, group_c, w, h) + x = torch.transpose(x, 1, 2).contiguous() + x = x.view(bat_size, -1, w, h) + return x + + +def conv_1x1_bn(in_c, out_c, stride=1): + return nn.Sequential( + nn.Conv2d(in_c, out_c, 1, stride, 0, bias=False), + nn.BatchNorm2d(out_c), nn.ReLU(True)) + + +def conv_bn(in_c, out_c, stride=2): + return nn.Sequential( + nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False), + nn.BatchNorm2d(out_c), nn.ReLU(True)) + + +class ShuffleBlock(nn.Module): + + def __init__(self, in_c, out_c, downsample=False): + super(ShuffleBlock, self).__init__() + self.downsample = downsample + half_c = out_c // 2 + if downsample: + self.branch1 = nn.Sequential( + # 3*3 dw conv, stride = 2 + # nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False), + nn.Conv2d(in_c, in_c, 3, 1, 1, groups=in_c, bias=False), + nn.BatchNorm2d(in_c), + # 1*1 pw conv + nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True)) + + self.branch2 = nn.Sequential( + # 1*1 pw conv + nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True), + # 3*3 dw conv, stride = 2 + # nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False), + nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), + nn.BatchNorm2d(half_c), + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True)) + else: + # in_c = out_c + assert in_c == out_c + + self.branch2 = nn.Sequential( + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True), + # 3*3 dw conv, stride = 1 + nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False), + nn.BatchNorm2d(half_c), + # 1*1 pw conv + nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False), + nn.BatchNorm2d(half_c), + nn.ReLU(True)) + + def forward(self, x): + out = None + if self.downsample: + # if it is downsampling, we don't need to do channel split + out = torch.cat((self.branch1(x), self.branch2(x)), 1) + else: + # channel split + channels = x.shape[1] + c = channels // 2 + x1 = x[:, :c, :, :] + x2 = x[:, c:, :, :] + out = torch.cat((x1, self.branch2(x2)), 1) + return channel_shuffle(out, 2) + + +class ShuffleCSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act='silu', + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv( + in_channels, hidden_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck( + hidden_channels, + hidden_channels, + shortcut, + 1.0, + depthwise, + act=act) for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + # add channel shuffle + return channel_shuffle(x, 2) diff --git a/modelscope/models/cv/tinynas_detection/core/neck_ops.py b/modelscope/models/cv/tinynas_detection/core/neck_ops.py new file mode 100644 index 00000000..7f481665 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/neck_ops.py @@ -0,0 +1,324 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Swish(nn.Module): + + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(F.sigmoid(x)) + return x + else: + return x * F.sigmoid(x) + + +def get_activation(name='silu', inplace=True): + if name is None: + return nn.Identity() + + if isinstance(name, str): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == 'swish': + module = Swish(inplace=inplace) + elif name == 'hardsigmoid': + module = nn.Hardsigmoid(inplace=inplace) + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + elif isinstance(name, nn.Module): + return name + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + + +class ConvBNLayer(nn.Module): + + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=1, + groups=1, + padding=0, + act=None): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias=False) + self.bn = nn.BatchNorm2d(ch_out, ) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + return x + + +class RepVGGBlock(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', deploy=False): + super(RepVGGBlock, self).__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.deploy = deploy + self.in_channels = ch_in + self.groups = 1 + if self.deploy is False: + self.rbr_dense = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=None) + self.rbr_1x1 = ConvBNLayer( + ch_in, ch_out, 1, stride=1, padding=0, act=None) + # self.rbr_identity = nn.BatchNorm2d(num_features=ch_in) if ch_out == ch_in else None + self.rbr_identity = None + else: + self.rbr_reparam = nn.Conv2d( + in_channels=self.ch_in, + out_channels=self.ch_out, + kernel_size=3, + stride=1, + padding=1, + groups=1) + self.act = get_activation(act) if act is None or isinstance( + act, (str, dict)) else act + + def forward(self, x): + if self.deploy: + print('----------deploy----------') + y = self.rbr_reparam(x) + else: + if self.rbr_identity is None: + y = self.rbr_dense(x) + self.rbr_1x1(x) + else: + y = self.rbr_dense(x) + self.rbr_1x1(x) + self.rbr_identity(x) + + y = self.act(y) + return y + + def switch_to_deploy(self): + print('switch') + if not hasattr(self, 'rbr_reparam'): + # return + self.rbr_reparam = nn.Conv2d( + in_channels=self.ch_in, + out_channels=self.ch_out, + kernel_size=3, + stride=1, + padding=1, + groups=1) + print('switch') + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + for para in self.parameters(): + para.detach_() + # self.__delattr__(self.rbr_dense) + # self.__delattr__(self.rbr_1x1) + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return kernel3x3 + self._pad_1x1_to_3x3_tensor( + kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + # if isinstance(branch, nn.Sequential): + if isinstance(branch, ConvBNLayer): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), + dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to( + branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class BasicBlock(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock, self).__init__() + assert ch_in == ch_out + # self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + # y = self.conv1(x) + y = self.conv2(x) + if self.shortcut: + return x + y + else: + return y + + +class BasicBlock_3x3(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock_3x3, self).__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv1(x) + y = self.conv2(y) + if self.shortcut: + return x + y + else: + return y + + +class BasicBlock_3x3_Reverse(nn.Module): + + def __init__(self, ch_in, ch_out, act='relu', shortcut=True): + super(BasicBlock_3x3_Reverse, self).__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer( + ch_in, ch_out, 3, stride=1, padding=1, act=act) + # self.conv1 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=act) + self.conv2 = RepVGGBlock(ch_in, ch_out, act=act) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv2(x) + y = self.conv1(y) + if self.shortcut: + return x + y + else: + return y + + +class SPP(nn.Module): + + def __init__( + self, + ch_in, + ch_out, + k, + pool_size, + act='swish', + ): + super(SPP, self).__init__() + self.pool = [] + for i, size in enumerate(pool_size): + pool = nn.MaxPool2d( + kernel_size=size, stride=1, padding=size // 2, ceil_mode=False) + self.add_module('pool{}'.format(i), pool) + self.pool.append(pool) + self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act) + + def forward(self, x): + outs = [x] + + for pool in self.pool: + outs.append(pool(x)) + y = torch.cat(outs, axis=1) + + y = self.conv(y) + return y + + +class CSPStage(nn.Module): + + def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False): + super(CSPStage, self).__init__() + + ch_mid = int(ch_out // 2) + self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act) + self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act) + # self.conv2 = ConvBNLayer(ch_in, ch_mid, 3, stride=1, padding=1, act=act) + self.convs = nn.Sequential() + + next_ch_in = ch_mid + for i in range(n): + if block_fn == 'BasicBlock': + self.convs.add_module( + str(i), + BasicBlock(next_ch_in, ch_mid, act=act, shortcut=False)) + elif block_fn == 'BasicBlock_3x3': + self.convs.add_module( + str(i), + BasicBlock_3x3(next_ch_in, ch_mid, act=act, shortcut=True)) + elif block_fn == 'BasicBlock_3x3_Reverse': + self.convs.add_module( + str(i), + BasicBlock_3x3_Reverse( + next_ch_in, ch_mid, act=act, shortcut=True)) + else: + raise NotImplementedError + if i == (n - 1) // 2 and spp: + self.convs.add_module( + 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act)) + next_ch_in = ch_mid + # self.convs = nn.Sequential(*convs) + self.conv3 = ConvBNLayer(ch_mid * (n + 1), ch_out, 1, act=act) + + def forward(self, x): + y1 = self.conv1(x) + y2 = self.conv2(x) + + mid_out = [y1] + for conv in self.convs: + y2 = conv(y2) + mid_out.append(y2) + y = torch.cat(mid_out, axis=1) + y = self.conv3(y) + return y diff --git a/modelscope/models/cv/tinynas_detection/core/repvgg_block.py b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py new file mode 100644 index 00000000..06966a4e --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/repvgg_block.py @@ -0,0 +1,205 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + + +def get_activation(name='silu', inplace=True): + if name == 'silu': + module = nn.SiLU(inplace=inplace) + elif name == 'relu': + module = nn.ReLU(inplace=inplace) + elif name == 'lrelu': + module = nn.LeakyReLU(0.1, inplace=inplace) + elif name == 'identity': + module = nn.Identity() + else: + raise AttributeError('Unsupported act type: {}'.format(name)) + return module + + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + '''Basic cell for rep-style block, including conv and bn''' + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + + +class RepVggBlock(nn.Module): + '''RepVggBlock is a basic rep-style block, including training and deploy status + This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + ''' + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + padding_mode='zeros', + deploy=False, + use_se=False, + act='relu', + norm=None): + super(RepVggBlock, self).__init__() + """ Initialization of the class. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 1 + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + padding_mode (string, optional): Default: 'zeros' + deploy: Whether to be deploy status or training status. Default: False + use_se: Whether to use se. Default: False + """ + self.deploy = deploy + self.groups = groups + self.in_channels = in_channels + self.out_channels = out_channels + + assert kernel_size == 3 + assert padding == 1 + + padding_11 = padding - kernel_size // 2 + + if isinstance(act, str): + self.nonlinearity = get_activation(act) + else: + self.nonlinearity = act + + if use_se: + raise NotImplementedError('se block not supported yet') + else: + self.se = nn.Identity() + + if deploy: + self.rbr_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode) + + else: + self.rbr_identity = None + self.rbr_dense = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups) + self.rbr_1x1 = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding_11, + groups=groups) + + def forward(self, inputs): + '''Forward process''' + if hasattr(self, 'rbr_reparam'): + return self.nonlinearity(self.se(self.rbr_reparam(inputs))) + + if self.rbr_identity is None: + id_out = 0 + else: + id_out = self.rbr_identity(inputs) + + return self.nonlinearity( + self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) + return kernel3x3 + self._pad_1x1_to_3x3_tensor( + kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + if branch is None: + return 0, 0 + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), + dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to( + branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def switch_to_deploy(self): + if hasattr(self, 'rbr_reparam'): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.rbr_reparam = nn.Conv2d( + in_channels=self.rbr_dense.conv.in_channels, + out_channels=self.rbr_dense.conv.out_channels, + kernel_size=self.rbr_dense.conv.kernel_size, + stride=self.rbr_dense.conv.stride, + padding=self.rbr_dense.conv.padding, + dilation=self.rbr_dense.conv.dilation, + groups=self.rbr_dense.conv.groups, + bias=True) + self.rbr_reparam.weight.data = kernel + self.rbr_reparam.bias.data = bias + for para in self.parameters(): + para.detach_() + self.__delattr__('rbr_dense') + self.__delattr__('rbr_1x1') + if hasattr(self, 'rbr_identity'): + self.__delattr__('rbr_identity') + if hasattr(self, 'id_tensor'): + self.__delattr__('id_tensor') + self.deploy = True diff --git a/modelscope/models/cv/tinynas_detection/core/utils.py b/modelscope/models/cv/tinynas_detection/core/utils.py new file mode 100644 index 00000000..482f12fb --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/core/utils.py @@ -0,0 +1,196 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import numpy as np +import torch +import torchvision + +__all__ = [ + 'filter_box', + 'postprocess_airdet', + 'bboxes_iou', + 'matrix_iou', + 'adjust_box_anns', + 'xyxy2xywh', + 'xyxy2cxcywh', +] + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + iou_thr, + max_num=100, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. + score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. + """ + num_classes = multi_scores.size(1) + # exclude background category + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores + # filter out boxes with low scores + valid_mask = scores > score_thr # 1000 * 80 bool + + # We use masked_select for ONNX exporting purpose, + # which is equivalent to bboxes = bboxes[valid_mask] + # (TODO): as ONNX does not support repeat now, + # we have to use this ugly code + # bboxes -> 1000, 4 + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + scores = multi_bboxes.new_zeros((0, )) + + return bboxes, scores, labels + + keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) + + if max_num > 0: + keep = keep[:max_num] + + return bboxes[keep], scores[keep], labels[keep] + + +def filter_box(output, scale_range): + """ + output: (N, 5+class) shape + """ + min_scale, max_scale = scale_range + w = output[:, 2] - output[:, 0] + h = output[:, 3] - output[:, 1] + keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale) + return output[keep] + + +def filter_results(boxlist, num_classes, nms_thre): + boxes = boxlist.bbox + scores = boxlist.get_field('scores') + cls = boxlist.get_field('labels') + nms_out_index = torchvision.ops.batched_nms( + boxes, + scores, + cls, + nms_thre, + ) + boxlist = boxlist[nms_out_index] + + return boxlist + + +def postprocess_airdet(prediction, + num_classes, + conf_thre=0.7, + nms_thre=0.45, + imgs=None): + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + output = [None for _ in range(len(prediction))] + for i, image_pred in enumerate(prediction): + # If none are remaining => process next image + if not image_pred.size(0): + continue + multi_bboxes = image_pred[:, :4] + multi_scores = image_pred[:, 5:] + detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores, + conf_thre, nms_thre, 500) + detections = torch.cat( + (detections, scores[:, None], scores[:, None], labels[:, None]), + dim=1) + + if output[i] is None: + output[i] = detections + else: + output[i] = torch.cat((output[i], detections)) + return output + + +def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError + + if xyxy: + tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) + return area_i / (area_a[:, None] + area_b - area_i) + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12) + + +def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max): + bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max) + bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max) + return bbox + + +def xyxy2xywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + return bboxes + + +def xyxy2cxcywh(bboxes): + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] + bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] + bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5 + bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5 + return bboxes diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py new file mode 100644 index 00000000..7aff2167 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/detector.py @@ -0,0 +1,192 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import os.path as osp +import pickle + +import cv2 +import torch +import torch.nn as nn +import torchvision + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import build_backbone +from .head import build_head +from .neck import build_neck +from .utils import parse_config + + +class SingleStageDetector(TorchModel): + """ + The base class of single stage detector. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + """ + init model by cfg + """ + super().__init__(model_dir, *args, **kwargs) + + config_path = osp.join(model_dir, self.config_name) + config = parse_config(config_path) + self.cfg = config + model_path = osp.join(model_dir, config.model.name) + label_map = osp.join(model_dir, config.model.class_map) + self.label_map = pickle.load(open(label_map, 'rb')) + self.size_divisible = config.dataset.size_divisibility + self.num_classes = config.model.head.num_classes + self.conf_thre = config.model.head.nms_conf_thre + self.nms_thre = config.model.head.nms_iou_thre + + if self.cfg.model.backbone.name == 'TinyNAS': + self.cfg.model.backbone.structure_file = osp.join( + model_dir, self.cfg.model.backbone.structure_file) + self.backbone = build_backbone(self.cfg.model.backbone) + self.neck = build_neck(self.cfg.model.neck) + self.head = build_head(self.cfg.model.head) + self.apply(self.init_bn) + + self.load_pretrain_model(model_path) + + def load_pretrain_model(self, pretrain_model): + + state_dict = torch.load(pretrain_model, map_location='cpu')['model'] + new_state_dict = {} + for k, v in state_dict.items(): + k = k.replace('module.', '') + new_state_dict[k] = v + self.load_state_dict(new_state_dict, strict=True) + + def init_bn(self, M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + def inference(self, x): + + if self.training: + return self.forward_train(x) + else: + return self.forward_eval(x) + + def forward_train(self, x): + + pass + + def forward_eval(self, x): + + x = self.backbone(x) + x = self.neck(x) + prediction = self.head(x) + + return prediction + + def preprocess(self, image): + image = torch.from_numpy(image).type(torch.float32) + image = image.permute(2, 0, 1) + shape = image.shape # c, h, w + if self.size_divisible > 0: + import math + stride = self.size_divisible + shape = list(shape) + shape[1] = int(math.ceil(shape[1] / stride) * stride) + shape[2] = int(math.ceil(shape[2] / stride) * stride) + shape = tuple(shape) + pad_img = image.new(*shape).zero_() + pad_img[:, :image.shape[1], :image.shape[2]].copy_(image) + pad_img = pad_img.unsqueeze(0) + + return pad_img + + def postprocess(self, preds): + bboxes, scores, labels_idx = postprocess_gfocal( + preds, self.num_classes, self.conf_thre, self.nms_thre) + bboxes = bboxes.cpu().numpy() + scores = scores.cpu().numpy() + labels_idx = labels_idx.cpu().numpy() + labels = [self.label_map[idx + 1][0]['name'] for idx in labels_idx] + + return (bboxes, scores, labels) + + +def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, + iou_thr, + max_num=100, + score_factors=None): + """NMS for multi-class bboxes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + nms_thr (float): NMS IoU threshold + max_num (int): if there are more than max_num bboxes after NMS, + only top max_num will be kept. + score_factors (Tensor): The factors multiplied to scores before + applying NMS + + Returns: + tuple: (bboxes, labels), tensors of shape (k, 5) and (k, 1). Labels \ + are 0-based. + """ + num_classes = multi_scores.size(1) + # exclude background category + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4) + scores = multi_scores + # filter out boxes with low scores + valid_mask = scores > score_thr # 1000 * 80 bool + + # We use masked_select for ONNX exporting purpose, + # which is equivalent to bboxes = bboxes[valid_mask] + # (TODO): as ONNX does not support repeat now, + # we have to use this ugly code + # bboxes -> 1000, 4 + bboxes = torch.masked_select( + bboxes, + torch.stack((valid_mask, valid_mask, valid_mask, valid_mask), + -1)).view(-1, 4) # mask-> 1000*80*4, 80000*4 + if score_factors is not None: + scores = scores * score_factors[:, None] + scores = torch.masked_select(scores, valid_mask) + labels = valid_mask.nonzero(as_tuple=False)[:, 1] + + if bboxes.numel() == 0: + bboxes = multi_bboxes.new_zeros((0, 5)) + labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) + scores = multi_bboxes.new_zeros((0, )) + + return bboxes, scores, labels + + keep = torchvision.ops.batched_nms(bboxes, scores, labels, iou_thr) + + if max_num > 0: + keep = keep[:max_num] + + return bboxes[keep], scores[keep], labels[keep] + + +def postprocess_gfocal(prediction, num_classes, conf_thre=0.05, nms_thre=0.7): + assert prediction.shape[0] == 1 + for i, image_pred in enumerate(prediction): + # If none are remaining => process next image + if not image_pred.size(0): + continue + multi_bboxes = image_pred[:, :4] + multi_scores = image_pred[:, 4:] + detections, scores, labels = multiclass_nms(multi_bboxes, multi_scores, + conf_thre, nms_thre, 500) + + return detections, scores, labels diff --git a/modelscope/models/cv/tinynas_detection/head/__init__.py b/modelscope/models/cv/tinynas_detection/head/__init__.py new file mode 100644 index 00000000..f870fae1 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/head/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import copy + +from .gfocal_v2_tiny import GFocalHead_Tiny + + +def build_head(cfg): + + head_cfg = copy.deepcopy(cfg) + name = head_cfg.pop('name') + if name == 'GFocalV2': + return GFocalHead_Tiny(**head_cfg) + else: + raise NotImplementedError diff --git a/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py new file mode 100644 index 00000000..66904ed1 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/head/gfocal_v2_tiny.py @@ -0,0 +1,374 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import functools +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..core.base_ops import BaseConv, DWConv + + +class Scale(nn.Module): + + def __init__(self, scale=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) + + def forward(self, x): + return x * self.scale + + +def multi_apply(func, *args, **kwargs): + + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def xyxy2CxCywh(xyxy, size=None): + x1 = xyxy[..., 0] + y1 = xyxy[..., 1] + x2 = xyxy[..., 2] + y2 = xyxy[..., 3] + + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + + w = x2 - x1 + h = y2 - y1 + if size is not None: + w = w.clamp(min=0, max=size[1]) + h = h.clamp(min=0, max=size[0]) + return torch.stack([cx, cy, w, h], axis=-1) + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + """ + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return torch.stack([x1, y1, x2, y2], -1) + + +def bbox2distance(points, bbox, max_dis=None, eps=0.1): + """Decode bounding box based on distances. + """ + left = points[:, 0] - bbox[:, 0] + top = points[:, 1] - bbox[:, 1] + right = bbox[:, 2] - points[:, 0] + bottom = bbox[:, 3] - points[:, 1] + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return torch.stack([left, top, right, bottom], -1) + + +class Integral(nn.Module): + """A fixed layer for calculating integral result from distribution. + """ + + def __init__(self, reg_max=16): + super(Integral, self).__init__() + self.reg_max = reg_max + self.register_buffer('project', + torch.linspace(0, self.reg_max, self.reg_max + 1)) + + def forward(self, x): + """Forward feature from the regression head to get integral result of + bounding box location. + """ + shape = x.size() + x = F.softmax(x.reshape(*shape[:-1], 4, self.reg_max + 1), dim=-1) + b, nb, ne, _ = x.size() + x = x.reshape(b * nb * ne, self.reg_max + 1) + y = self.project.type_as(x).unsqueeze(1) + x = torch.matmul(x, y).reshape(b, nb, 4) + return x + + +class GFocalHead_Tiny(nn.Module): + """Ref to Generalized Focal Loss V2: Learning Reliable Localization Quality + Estimation for Dense Object Detection. + """ + + def __init__( + self, + num_classes, + in_channels, + stacked_convs=4, # 4 + feat_channels=256, + reg_max=12, + reg_topk=4, + reg_channels=64, + strides=[8, 16, 32], + add_mean=True, + norm='gn', + act='relu', + start_kernel_size=3, + conv_groups=1, + conv_type='BaseConv', + simOTA_cls_weight=1.0, + simOTA_iou_weight=3.0, + octbase=8, + simlqe=False, + use_lqe=True, + **kwargs): + self.simlqe = simlqe + self.num_classes = num_classes + self.in_channels = in_channels + self.strides = strides + self.use_lqe = use_lqe + self.feat_channels = feat_channels if isinstance(feat_channels, list) \ + else [feat_channels] * len(self.strides) + + self.cls_out_channels = num_classes + 1 # add 1 for keep consistance with former models + # and will be deprecated in future. + self.stacked_convs = stacked_convs + self.conv_groups = conv_groups + self.reg_max = reg_max + self.reg_topk = reg_topk + self.reg_channels = reg_channels + self.add_mean = add_mean + self.total_dim = reg_topk + self.start_kernel_size = start_kernel_size + + self.norm = norm + self.act = act + self.conv_module = DWConv if conv_type == 'DWConv' else BaseConv + + if add_mean: + self.total_dim += 1 + + super(GFocalHead_Tiny, self).__init__() + self.integral = Integral(self.reg_max) + + self._init_layers() + + def _build_not_shared_convs(self, in_channel, feat_channels): + self.relu = nn.ReLU(inplace=True) + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + chn = feat_channels if i > 0 else in_channel + kernel_size = 3 if i > 0 else self.start_kernel_size + cls_convs.append( + self.conv_module( + chn, + feat_channels, + kernel_size, + stride=1, + groups=self.conv_groups, + norm=self.norm, + act=self.act)) + reg_convs.append( + self.conv_module( + chn, + feat_channels, + kernel_size, + stride=1, + groups=self.conv_groups, + norm=self.norm, + act=self.act)) + if self.use_lqe: + if not self.simlqe: + conf_vector = [ + nn.Conv2d(4 * self.total_dim, self.reg_channels, 1) + ] + else: + conf_vector = [ + nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1) + ] + conf_vector += [self.relu] + conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()] + reg_conf = nn.Sequential(*conf_vector) + else: + reg_conf = None + + return cls_convs, reg_convs, reg_conf + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.reg_confs = nn.ModuleList() + + for i in range(len(self.strides)): + cls_convs, reg_convs, reg_conf = self._build_not_shared_convs( + self.in_channels[i], self.feat_channels[i]) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + self.reg_confs.append(reg_conf) + + self.gfl_cls = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], self.cls_out_channels, 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.gfl_reg = nn.ModuleList([ + nn.Conv2d( + self.feat_channels[i], 4 * (self.reg_max + 1), 3, padding=1) + for i in range(len(self.strides)) + ]) + + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + + def forward(self, + xin, + labels=None, + imgs=None, + conf_thre=0.05, + nms_thre=0.7): + + # prepare labels during training + b, c, h, w = xin[0].shape + if labels is not None: + gt_bbox_list = [] + gt_cls_list = [] + for label in labels: + gt_bbox_list.append(label.bbox) + gt_cls_list.append((label.get_field('labels') + - 1).long()) # labels starts from 1 + + # prepare priors for label assignment and bbox decode + mlvl_priors_list = [ + self.get_single_level_center_priors( + xin[i].shape[0], + xin[i].shape[-2:], + stride, + dtype=torch.float32, + device=xin[0].device) for i, stride in enumerate(self.strides) + ] + mlvl_priors = torch.cat(mlvl_priors_list, dim=1) + + # forward for bboxes and classification prediction + cls_scores, bbox_preds = multi_apply( + self.forward_single, + xin, + self.cls_convs, + self.reg_convs, + self.gfl_cls, + self.gfl_reg, + self.reg_confs, + self.scales, + ) + flatten_cls_scores = torch.cat(cls_scores, dim=1) + flatten_bbox_preds = torch.cat(bbox_preds, dim=1) + + # calculating losses or bboxes decoded + if self.training: + loss = self.loss(flatten_cls_scores, flatten_bbox_preds, + gt_bbox_list, gt_cls_list, mlvl_priors) + return loss + else: + output = self.get_bboxes(flatten_cls_scores, flatten_bbox_preds, + mlvl_priors) + return output + + def forward_single(self, x, cls_convs, reg_convs, gfl_cls, gfl_reg, + reg_conf, scale): + """Forward feature of a single scale level. + + """ + cls_feat = x + reg_feat = x + + for cls_conv in cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in reg_convs: + reg_feat = reg_conv(reg_feat) + + bbox_pred = scale(gfl_reg(reg_feat)).float() + N, C, H, W = bbox_pred.size() + prob = F.softmax( + bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2) + if self.use_lqe: + if not self.simlqe: + prob_topk, _ = prob.topk(self.reg_topk, dim=2) + + if self.add_mean: + stat = torch.cat( + [prob_topk, + prob_topk.mean(dim=2, keepdim=True)], + dim=2) + else: + stat = prob_topk + + quality_score = reg_conf( + stat.reshape(N, 4 * self.total_dim, H, W)) + else: + quality_score = reg_conf( + bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W)) + + cls_score = gfl_cls(cls_feat).sigmoid() * quality_score + else: + cls_score = gfl_cls(cls_feat).sigmoid() + + flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2) + flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2) + return flatten_cls_score, flatten_bbox_pred + + def get_single_level_center_priors(self, batch_size, featmap_size, stride, + dtype, device): + + h, w = featmap_size + x_range = (torch.arange(0, int(w), dtype=dtype, + device=device)) * stride + y_range = (torch.arange(0, int(h), dtype=dtype, + device=device)) * stride + + x = x_range.repeat(h, 1) + y = y_range.unsqueeze(-1).repeat(1, w) + + y = y.flatten() + x = x.flatten() + strides = x.new_full((x.shape[0], ), stride) + priors = torch.stack([x, y, strides, strides], dim=-1) + + return priors.unsqueeze(0).repeat(batch_size, 1, 1) + + def sample(self, assign_result, gt_bboxes): + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert pos_assigned_gt_inds.numel() == 0 + pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] + + return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds + + def get_bboxes(self, + cls_preds, + reg_preds, + mlvl_center_priors, + img_meta=None): + + dis_preds = self.integral(reg_preds) * mlvl_center_priors[..., 2, None] + bboxes = distance2bbox(mlvl_center_priors[..., :2], dis_preds) + + res = torch.cat([bboxes, cls_preds[..., 0:self.num_classes]], dim=-1) + + return res diff --git a/modelscope/models/cv/tinynas_detection/neck/__init__.py b/modelscope/models/cv/tinynas_detection/neck/__init__.py new file mode 100644 index 00000000..3c418c29 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import copy + +from .giraffe_fpn import GiraffeNeck +from .giraffe_fpn_v2 import GiraffeNeckV2 + + +def build_neck(cfg): + neck_cfg = copy.deepcopy(cfg) + name = neck_cfg.pop('name') + if name == 'GiraffeNeck': + return GiraffeNeck(**neck_cfg) + elif name == 'GiraffeNeckV2': + return GiraffeNeckV2(**neck_cfg) diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py new file mode 100644 index 00000000..289fdfd2 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_config.py @@ -0,0 +1,235 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import collections +import itertools +import os + +import networkx as nx +from omegaconf import OmegaConf + +Node = collections.namedtuple('Node', ['id', 'inputs', 'type']) + + +def get_graph_info(graph): + input_nodes = [] + output_nodes = [] + Nodes = [] + for node in range(graph.number_of_nodes()): + tmp = list(graph.neighbors(node)) + tmp.sort() + type = -1 + if node < tmp[0]: + input_nodes.append(node) + type = 0 + if node > tmp[-1]: + output_nodes.append(node) + type = 1 + Nodes.append(Node(node, [n for n in tmp if n < node], type)) + return Nodes, input_nodes, output_nodes + + +def nodeid_trans(id, cur_level, num_levels): + if id % 2 == 1: + gap = int(((id + 1) // 2) * num_levels * 2) + else: + a = (num_levels - cur_level) * 2 - 1 + b = ((id + 1) // 2) * num_levels * 2 + gap = int(a + b) + return cur_level + gap + + +def gen_log2n_graph_file(log2n_graph_file, depth_multiplier): + f = open(log2n_graph_file, 'w') + for i in range(depth_multiplier): + for j in [1, 2, 4, 8, 16, 32]: + if i - j < 0: + break + else: + f.write('%d,%d\n' % (i - j, i)) + f.close() + + +def get_log2n_graph(depth_multiplier): + nodes = [] + connnections = [] + + for i in range(depth_multiplier): + nodes.append(i) + for j in [1, 2, 4, 8, 16, 32]: + if i - j < 0: + break + else: + connnections.append((i - j, i)) + return nodes, connnections + + +def get_dense_graph(depth_multiplier): + nodes = [] + connections = [] + + for i in range(depth_multiplier): + nodes.append(i) + for j in range(i): + connections.append((j, i)) + return nodes, connections + + +def giraffeneck_config(min_level, + max_level, + weight_method=None, + depth_multiplier=5, + with_backslash=False, + with_slash=False, + with_skip_connect=False, + skip_connect_type='dense'): + """Graph config with log2n merge and panet""" + if skip_connect_type == 'dense': + nodes, connections = get_dense_graph(depth_multiplier) + elif skip_connect_type == 'log2n': + nodes, connections = get_log2n_graph(depth_multiplier) + graph = nx.Graph() + graph.add_nodes_from(nodes) + graph.add_edges_from(connections) + + drop_node = [] + nodes, input_nodes, output_nodes = get_graph_info(graph) + + weight_method = weight_method or 'fastattn' + + num_levels = max_level - min_level + 1 + node_ids = {min_level + i: [i] for i in range(num_levels)} + node_ids_per_layer = {} + + pnodes = {} + + def update_drop_node(new_id, input_offsets): + if new_id not in drop_node: + new_id = new_id + else: + while new_id in drop_node: + if new_id in pnodes: + for n in pnodes[new_id]['inputs_offsets']: + if n not in input_offsets and n not in drop_node: + input_offsets.append(n) + new_id = new_id - 1 + if new_id not in input_offsets: + input_offsets.append(new_id) + + # top-down layer + for i in range(max_level, min_level - 1, -1): + node_ids_per_layer[i] = [] + for id, node in enumerate(nodes): + input_offsets = [] + if id in input_nodes: + input_offsets.append(node_ids[i][0]) + else: + if with_skip_connect: + for input_id in node.inputs: + new_id = nodeid_trans(input_id, i - min_level, + num_levels) + update_drop_node(new_id, input_offsets) + + # add top2down + new_id = nodeid_trans(id, i - min_level, num_levels) + + # add backslash node + def cal_backslash_node(id): + ind = id // num_levels + mod = id % num_levels + if ind % 2 == 0: # even + if mod == (num_levels - 1): + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod - 1) + else: # odd + if mod == 0: + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod + 1) + + return last + + # add slash node + def cal_slash_node(id): + ind = id // num_levels + mod = id % num_levels + if ind % 2 == 1: # odd + if mod == (num_levels - 1): + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod - 1) + else: # even + if mod == 0: + last = -1 + else: + last = (ind - 1) * num_levels + ( + num_levels - 1 - mod + 1) + + return last + + # add last node + last = new_id - 1 + update_drop_node(last, input_offsets) + + if with_backslash: + backslash = cal_backslash_node(new_id) + if backslash != -1 and backslash not in input_offsets: + input_offsets.append(backslash) + + if with_slash: + slash = cal_slash_node(new_id) + if slash != -1 and slash not in input_offsets: + input_offsets.append(slash) + + if new_id in drop_node: + input_offsets = [] + + pnodes[new_id] = { + 'reduction': 1 << i, + 'inputs_offsets': input_offsets, + 'weight_method': weight_method, + 'is_out': 0, + } + + input_offsets = [] + for out_id in output_nodes: + new_id = nodeid_trans(out_id, i - min_level, num_levels) + input_offsets.append(new_id) + + pnodes[node_ids[i][0] + num_levels * (len(nodes) + 1)] = { + 'reduction': 1 << i, + 'inputs_offsets': input_offsets, + 'weight_method': weight_method, + 'is_out': 1, + } + + pnodes = dict(sorted(pnodes.items(), key=lambda x: x[0])) + return pnodes + + +def get_graph_config(fpn_name, + min_level=3, + max_level=7, + weight_method='concat', + depth_multiplier=5, + with_backslash=False, + with_slash=False, + with_skip_connect=False, + skip_connect_type='dense'): + name_to_config = { + 'giraffeneck': + giraffeneck_config( + min_level=min_level, + max_level=max_level, + weight_method=weight_method, + depth_multiplier=depth_multiplier, + with_backslash=with_backslash, + with_slash=with_slash, + with_skip_connect=with_skip_connect, + skip_connect_type=skip_connect_type), + } + return name_to_config[fpn_name] diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py new file mode 100644 index 00000000..b7087779 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn.py @@ -0,0 +1,661 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm import create_model +from timm.models.layers import (Swish, create_conv2d, create_pool2d, + get_act_layer) + +from ..core.base_ops import CSPLayer, ShuffleBlock, ShuffleCSPLayer +from .giraffe_config import get_graph_config + +_ACT_LAYER = Swish + + +class SequentialList(nn.Sequential): + """ This module exists to work around torchscript typing issues list -> list""" + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +class ConvBnAct2d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding='', + bias=False, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER): + super(ConvBnAct2d, self).__init__() + + self.conv = create_conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias) + self.bn = None if norm_layer is None else norm_layer(out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + padding='', + bias=False, + channel_multiplier=1.0, + pw_kernel_size=1, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER): + super(SeparableConv2d, self).__init__() + self.conv_dw = create_conv2d( + in_channels, + int(in_channels * channel_multiplier), + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), + out_channels, + pw_kernel_size, + padding=padding, + bias=bias) + + self.bn = None if norm_layer is None else norm_layer(out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +def _init_weight( + m, + n='', +): + """ Weight initialization as per Tensorflow official implementations. + """ + + def _fan_in_out(w, groups=1): + dimensions = w.dim() + if dimensions < 2: + raise ValueError( + 'Fan in and fan out can not be computed for tensor with fewer than 2 dimensions' + ) + num_input_fmaps = w.size(1) + num_output_fmaps = w.size(0) + receptive_field_size = 1 + if w.dim() > 2: + receptive_field_size = w[0][0].numel() + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + fan_out //= groups + return fan_in, fan_out + + def _glorot_uniform(w, gain=1, groups=1): + fan_in, fan_out = _fan_in_out(w, groups) + gain /= max(1., (fan_in + fan_out) / 2.) # fan avg + limit = math.sqrt(3.0 * gain) + w.data.uniform_(-limit, limit) + + def _variance_scaling(w, gain=1, groups=1): + fan_in, fan_out = _fan_in_out(w, groups) + gain /= max(1., fan_in) # fan in + std = math.sqrt(gain) + w.data.normal_(std=std) + + if isinstance(m, SeparableConv2d): + if 'box_net' in n or 'class_net' in n: + _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups) + _variance_scaling(m.conv_pw.weight) + if m.conv_pw.bias is not None: + if 'class_net.predict' in n: + m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.conv_pw.bias.data.zero_() + else: + _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups) + _glorot_uniform(m.conv_pw.weight) + if m.conv_pw.bias is not None: + m.conv_pw.bias.data.zero_() + elif isinstance(m, ConvBnAct2d): + if 'box_net' in n or 'class_net' in n: + m.conv.weight.data.normal_(std=.01) + if m.conv.bias is not None: + if 'class_net.predict' in n: + m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.conv.bias.data.zero_() + else: + _glorot_uniform(m.conv.weight) + if m.conv.bias is not None: + m.conv.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + +def _init_weight_alt( + m, + n='', +): + """ Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition + NOTE: this will likely be removed after some experimentation + """ + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + if 'class_net.predict' in n: + m.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) + else: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + +class Interpolate2d(nn.Module): + r"""Resamples a 2d Image + + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor. + + The algorithms available for upsampling are nearest neighbor and linear, + bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor, + respectively. + + One can either give a :attr:`scale_factor` or the target output :attr:`size` to + calculate the output size. (You cannot give both, as it is ambiguous) + + Args: + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): + output spatial sizes + scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): + multiplier for spatial size. Has to match input size if it is a tuple. + mode (str, optional): the upsampling algorithm: one of ``'nearest'``, + ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. + Default: ``'nearest'`` + align_corners (bool, optional): if ``True``, the corner pixels of the input + and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is + ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False`` + """ + __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name'] + name: str + size: Optional[Union[int, Tuple[int, int]]] + scale_factor: Optional[Union[float, Tuple[float, float]]] + mode: str + align_corners: Optional[bool] + + def __init__(self, + size: Optional[Union[int, Tuple[int, int]]] = None, + scale_factor: Optional[Union[float, Tuple[float, + float]]] = None, + mode: str = 'nearest', + align_corners: bool = False) -> None: + super(Interpolate2d, self).__init__() + self.name = type(self).__name__ + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = None if mode == 'nearest' else align_corners + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=False) + + +class ResampleFeatureMap(nn.Sequential): + + def __init__(self, + in_channels, + out_channels, + reduction_ratio=1., + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + apply_bn=False, + conv_after_downsample=False, + redundant_bias=False): + super(ResampleFeatureMap, self).__init__() + downsample = downsample or 'max' + upsample = upsample or 'nearest' + self.in_channels = in_channels + self.out_channels = out_channels + self.reduction_ratio = reduction_ratio + self.conv_after_downsample = conv_after_downsample + + conv = None + if in_channels != out_channels: + conv = ConvBnAct2d( + in_channels, + out_channels, + kernel_size=1, + padding=pad_type, + norm_layer=norm_layer if apply_bn else None, + bias=not apply_bn or redundant_bias, + act_layer=None) + + if reduction_ratio > 1: + if conv is not None and not self.conv_after_downsample: + self.add_module('conv', conv) + if downsample in ('max', 'avg'): + stride_size = int(reduction_ratio) + downsample = create_pool2d( + downsample, + kernel_size=stride_size + 1, + stride=stride_size, + padding=pad_type) + else: + downsample = Interpolate2d( + scale_factor=1. / reduction_ratio, mode=downsample) + self.add_module('downsample', downsample) + if conv is not None and self.conv_after_downsample: + self.add_module('conv', conv) + else: + if conv is not None: + self.add_module('conv', conv) + if reduction_ratio < 1: + scale = int(1 // reduction_ratio) + self.add_module( + 'upsample', + Interpolate2d(scale_factor=scale, mode=upsample)) + + +class GiraffeCombine(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + fpn_channels, + inputs_offsets, + target_reduction, + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + apply_resample_bn=False, + conv_after_downsample=False, + redundant_bias=False, + weight_method='attn'): + super(GiraffeCombine, self).__init__() + self.inputs_offsets = inputs_offsets + self.weight_method = weight_method + + self.resample = nn.ModuleDict() + reduction_base = feature_info[0]['reduction'] + + target_channels_idx = int( + math.log(target_reduction // reduction_base, 2)) + for idx, offset in enumerate(inputs_offsets): + if offset < len(feature_info): + in_channels = feature_info[offset]['num_chs'] + input_reduction = feature_info[offset]['reduction'] + else: + node_idx = offset + input_reduction = fpn_config[node_idx]['reduction'] + # in_channels = fpn_config[node_idx]['num_chs'] + input_channels_idx = int( + math.log(input_reduction // reduction_base, 2)) + in_channels = feature_info[input_channels_idx]['num_chs'] + + reduction_ratio = target_reduction / input_reduction + if weight_method == 'concat': + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + in_channels, + reduction_ratio=reduction_ratio, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias) + else: + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + fpn_channels[target_channels_idx], + reduction_ratio=reduction_ratio, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias) + + if weight_method == 'attn' or weight_method == 'fastattn': + self.edge_weights = nn.Parameter( + torch.ones(len(inputs_offsets)), requires_grad=True) # WSM + else: + self.edge_weights = None + + def forward(self, x: List[torch.Tensor]): + dtype = x[0].dtype + nodes = [] + if len(self.inputs_offsets) == 0: + return None + for offset, resample in zip(self.inputs_offsets, + self.resample.values()): + input_node = x[offset] + input_node = resample(input_node) + nodes.append(input_node) + + if self.weight_method == 'attn': + normalized_weights = torch.softmax( + self.edge_weights.to(dtype=dtype), dim=0) + out = torch.stack(nodes, dim=-1) * normalized_weights + out = torch.sum(out, dim=-1) + elif self.weight_method == 'fastattn': + edge_weights = nn.functional.relu( + self.edge_weights.to(dtype=dtype)) + weights_sum = torch.sum(edge_weights) + weights_norm = weights_sum + 0.0001 + out = torch.stack([(nodes[i] * edge_weights[i]) / weights_norm + for i in range(len(nodes))], + dim=-1) + + out = torch.sum(out, dim=-1) + elif self.weight_method == 'sum': + out = torch.stack(nodes, dim=-1) + out = torch.sum(out, dim=-1) + elif self.weight_method == 'concat': + out = torch.cat(nodes, dim=1) + else: + raise ValueError('unknown weight_method {}'.format( + self.weight_method)) + return out + + +class GiraffeNode(nn.Module): + """ A simple wrapper used in place of nn.Sequential for torchscript typing + Handles input type List[Tensor] -> output type Tensor + """ + + def __init__(self, combine: nn.Module, after_combine: nn.Module): + super(GiraffeNode, self).__init__() + self.combine = combine + self.after_combine = after_combine + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + combine_feat = self.combine(x) + if combine_feat is None: + return None + else: + return self.after_combine(combine_feat) + + +class GiraffeLayer(nn.Module): + + def __init__(self, + feature_info, + fpn_config, + inner_fpn_channels, + outer_fpn_channels, + num_levels=5, + pad_type='', + downsample=None, + upsample=None, + norm_layer=nn.BatchNorm2d, + act_layer=_ACT_LAYER, + apply_resample_bn=False, + conv_after_downsample=True, + conv_bn_relu_pattern=False, + separable_conv=True, + redundant_bias=False, + merge_type='conv'): + super(GiraffeLayer, self).__init__() + self.num_levels = num_levels + self.conv_bn_relu_pattern = False + + self.feature_info = {} + for idx, feat in enumerate(feature_info): + self.feature_info[idx] = feat + + self.fnode = nn.ModuleList() + reduction_base = feature_info[0]['reduction'] + for i, fnode_cfg in fpn_config.items(): + logging.debug('fnode {} : {}'.format(i, fnode_cfg)) + + if fnode_cfg['is_out'] == 1: + fpn_channels = outer_fpn_channels + else: + fpn_channels = inner_fpn_channels + + reduction = fnode_cfg['reduction'] + fpn_channels_idx = int(math.log(reduction // reduction_base, 2)) + combine = GiraffeCombine( + self.feature_info, + fpn_config, + fpn_channels, + tuple(fnode_cfg['inputs_offsets']), + target_reduction=reduction, + pad_type=pad_type, + downsample=downsample, + upsample=upsample, + norm_layer=norm_layer, + apply_resample_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + weight_method=fnode_cfg['weight_method']) + + after_combine = nn.Sequential() + + in_channels = 0 + out_channels = 0 + for input_offset in fnode_cfg['inputs_offsets']: + in_channels += self.feature_info[input_offset]['num_chs'] + + out_channels = fpn_channels[fpn_channels_idx] + + if merge_type == 'csp': + after_combine.add_module( + 'CspLayer', + CSPLayer( + in_channels, + out_channels, + 2, + shortcut=True, + depthwise=False, + act='silu')) + elif merge_type == 'shuffle': + after_combine.add_module( + 'shuffleBlock', ShuffleBlock(in_channels, in_channels)) + after_combine.add_module( + 'conv1x1', + create_conv2d(in_channels, out_channels, kernel_size=1)) + elif merge_type == 'conv': + after_combine.add_module( + 'conv1x1', + create_conv2d(in_channels, out_channels, kernel_size=1)) + conv_kwargs = dict( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=pad_type, + bias=False, + norm_layer=norm_layer, + act_layer=act_layer) + if not conv_bn_relu_pattern: + conv_kwargs['bias'] = redundant_bias + conv_kwargs['act_layer'] = None + after_combine.add_module('act', act_layer(inplace=True)) + after_combine.add_module( + 'conv', + SeparableConv2d(**conv_kwargs) + if separable_conv else ConvBnAct2d(**conv_kwargs)) + + self.fnode.append( + GiraffeNode(combine=combine, after_combine=after_combine)) + self.feature_info[i] = dict( + num_chs=fpn_channels[fpn_channels_idx], reduction=reduction) + + self.out_feature_info = [] + out_node = list(self.feature_info.keys())[-num_levels::] + for i in out_node: + self.out_feature_info.append(self.feature_info[i]) + + self.feature_info = self.out_feature_info + + def forward(self, x: List[torch.Tensor]): + for fn in self.fnode: + x.append(fn(x)) + return x[-self.num_levels::] + + +class GiraffeNeck(nn.Module): + + def __init__(self, min_level, max_level, num_levels, norm_layer, + norm_kwargs, act_type, fpn_config, fpn_name, fpn_channels, + out_fpn_channels, weight_method, depth_multiplier, + width_multiplier, with_backslash, with_slash, + with_skip_connect, skip_connect_type, separable_conv, + feature_info, merge_type, pad_type, downsample_type, + upsample_type, apply_resample_bn, conv_after_downsample, + redundant_bias, conv_bn_relu_pattern, alternate_init): + super(GiraffeNeck, self).__init__() + + self.num_levels = num_levels + self.min_level = min_level + self.in_features = [0, 1, 2, 3, 4, 5, + 6][self.min_level - 1:self.min_level - 1 + + num_levels] + self.alternate_init = alternate_init + norm_layer = norm_layer or nn.BatchNorm2d + if norm_kwargs: + norm_layer = partial(norm_layer, **norm_kwargs) + act_layer = get_act_layer(act_type) or _ACT_LAYER + fpn_config = fpn_config or get_graph_config( + fpn_name, + min_level=min_level, + max_level=max_level, + weight_method=weight_method, + depth_multiplier=depth_multiplier, + with_backslash=with_backslash, + with_slash=with_slash, + with_skip_connect=with_skip_connect, + skip_connect_type=skip_connect_type) + + # width scale + for i in range(len(fpn_channels)): + fpn_channels[i] = int(fpn_channels[i] * width_multiplier) + + self.resample = nn.ModuleDict() + for level in range(num_levels): + if level < len(feature_info): + in_chs = feature_info[level]['num_chs'] + reduction = feature_info[level]['reduction'] + else: + # Adds a coarser level by downsampling the last feature map + reduction_ratio = 2 + self.resample[str(level)] = ResampleFeatureMap( + in_channels=in_chs, + out_channels=feature_info[level - 1]['num_chs'], + pad_type=pad_type, + downsample=downsample_type, + upsample=upsample_type, + norm_layer=norm_layer, + reduction_ratio=reduction_ratio, + apply_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + ) + in_chs = feature_info[level - 1]['num_chs'] + reduction = int(reduction * reduction_ratio) + feature_info.append(dict(num_chs=in_chs, reduction=reduction)) + + self.cell = SequentialList() + logging.debug('building giraffeNeck') + giraffe_layer = GiraffeLayer( + feature_info=feature_info, + fpn_config=fpn_config, + inner_fpn_channels=fpn_channels, + outer_fpn_channels=out_fpn_channels, + num_levels=num_levels, + pad_type=pad_type, + downsample=downsample_type, + upsample=upsample_type, + norm_layer=norm_layer, + act_layer=act_layer, + separable_conv=separable_conv, + apply_resample_bn=apply_resample_bn, + conv_after_downsample=conv_after_downsample, + conv_bn_relu_pattern=conv_bn_relu_pattern, + redundant_bias=redundant_bias, + merge_type=merge_type) + self.cell.add_module('giraffeNeck', giraffe_layer) + feature_info = giraffe_layer.feature_info + + def init_weights(self, pretrained=False): + for n, m in self.named_modules(): + if 'backbone' not in n: + if self.alternate_init: + _init_weight_alt(m, n) + else: + _init_weight(m, n) + + def forward(self, x: List[torch.Tensor]): + if type(x) is tuple: + x = list(x) + x = [x[f] for f in self.in_features] + for resample in self.resample.values(): + x.append(resample(x[-1])) + x = self.cell(x) + return x diff --git a/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py new file mode 100644 index 00000000..b88c39f2 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/neck/giraffe_fpn_v2.py @@ -0,0 +1,200 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import torch +import torch.nn as nn + +from ..core.base_ops import BaseConv, CSPLayer, DWConv +from ..core.neck_ops import CSPStage + + +class GiraffeNeckV2(nn.Module): + + def __init__( + self, + depth=1.0, + width=1.0, + in_channels=[256, 512, 1024], + out_channels=[256, 512, 1024], + depthwise=False, + act='silu', + spp=True, + reparam_mode=True, + block_name='BasicBlock', + ): + super().__init__() + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + reparam_mode = reparam_mode + + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + # node x3: input x0, x1 + self.bu_conv13 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_3 = CSPStage( + block_name, + int((in_channels[1] + in_channels[2]) * width), + int(in_channels[2] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_3 = CSPLayer( + int((in_channels[1] + in_channels[2]) * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x4: input x1, x2, x3 + self.bu_conv24 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_4 = CSPStage( + block_name, + int((in_channels[0] + in_channels[1] + in_channels[2]) + * width), + int(in_channels[1] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_4 = CSPLayer( + int((in_channels[0] + in_channels[1] + in_channels[2]) + * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x5: input x2, x4 + if reparam_mode: + self.merge_5 = CSPStage( + block_name, + int((in_channels[1] + in_channels[0]) * width), + int(out_channels[0] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_5 = CSPLayer( + int((in_channels[1] + in_channels[0]) * width), + int(out_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x7: input x4, x5 + self.bu_conv57 = Conv( + int(out_channels[0] * width), + int(out_channels[0] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_7 = CSPStage( + block_name, + int((out_channels[0] + in_channels[1]) * width), + int(out_channels[1] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_7 = CSPLayer( + int((out_channels[0] + in_channels[1]) * width), + int(out_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + # node x6: input x3, x4, x7 + self.bu_conv46 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.bu_conv76 = Conv( + int(out_channels[1] * width), + int(out_channels[1] * width), + 3, + 2, + act=act) + if reparam_mode: + self.merge_6 = CSPStage( + block_name, + int((in_channels[1] + out_channels[1] + in_channels[2]) + * width), + int(out_channels[2] * width), + round(3 * depth), + act=act, + spp=spp) + else: + self.merge_6 = CSPLayer( + int((in_channels[1] + out_channels[1] + in_channels[2]) + * width), + int(out_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act) + + def init_weights(self): + pass + + def forward(self, out_features): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. + """ + + # backbone + [x2, x1, x0] = out_features + + # node x3 + x13 = self.bu_conv13(x1) + x3 = torch.cat([x0, x13], 1) + x3 = self.merge_3(x3) + + # node x4 + x34 = self.upsample(x3) + x24 = self.bu_conv24(x2) + x4 = torch.cat([x1, x24, x34], 1) + x4 = self.merge_4(x4) + + # node x5 + x45 = self.upsample(x4) + x5 = torch.cat([x2, x45], 1) + x5 = self.merge_5(x5) + + # node x7 + x57 = self.bu_conv57(x5) + x7 = torch.cat([x4, x57], 1) + x7 = self.merge_7(x7) + + # node x6 + x46 = self.bu_conv46(x4) + x76 = self.bu_conv76(x7) + x6 = torch.cat([x3, x46, x76], 1) + x6 = self.merge_6(x6) + + outputs = (x5, x7, x6) + return outputs diff --git a/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py new file mode 100644 index 00000000..9effad3a --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py @@ -0,0 +1,15 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .detector import SingleStageDetector + + +@MODELS.register_module( + Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) +class DamoYolo(SingleStageDetector): + + def __init__(self, model_dir, *args, **kwargs): + self.config_name = 'damoyolo_s.py' + super(DamoYolo, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/tinynas_detection/tinynas_detector.py b/modelscope/models/cv/tinynas_detection/tinynas_detector.py new file mode 100644 index 00000000..92acf3fa --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/tinynas_detector.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .detector import SingleStageDetector + + +@MODELS.register_module( + Tasks.image_object_detection, module_name=Models.tinynas_detection) +class TinynasDetector(SingleStageDetector): + + def __init__(self, model_dir, *args, **kwargs): + self.config_name = 'airdet_s.py' + super(TinynasDetector, self).__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/cv/tinynas_detection/utils.py b/modelscope/models/cv/tinynas_detection/utils.py new file mode 100644 index 00000000..d67d3a36 --- /dev/null +++ b/modelscope/models/cv/tinynas_detection/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The AIRDet implementation is also open-sourced by the authors, and available at https://github.com/tinyvision/AIRDet. + +import importlib +import os +import sys +from os.path import dirname, join + + +def get_config_by_file(config_file): + try: + sys.path.append(os.path.dirname(config_file)) + current_config = importlib.import_module( + os.path.basename(config_file).split('.')[0]) + exp = current_config.Config() + except Exception: + raise ImportError( + "{} doesn't contains class named 'Config'".format(config_file)) + return exp + + +def parse_config(config_file): + """ + get config object by file. + Args: + config_file (str): file path of config. + """ + assert (config_file is not None), 'plz provide config file' + if config_file is not None: + return get_config_by_file(config_file) diff --git a/modelscope/models/cv/video_inpainting/__init__.py b/modelscope/models/cv/video_inpainting/__init__.py new file mode 100644 index 00000000..f5489da9 --- /dev/null +++ b/modelscope/models/cv/video_inpainting/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .inpainting_model import VideoInpainting + +else: + _import_structure = {'inpainting_model': ['VideoInpainting']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_inpainting/inpainting.py b/modelscope/models/cv/video_inpainting/inpainting.py new file mode 100644 index 00000000..e2af2ad0 --- /dev/null +++ b/modelscope/models/cv/video_inpainting/inpainting.py @@ -0,0 +1,299 @@ +""" VideoInpaintingProcess +The implementation here is modified based on STTN, +originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN +""" + +import os +import time + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +torch.backends.cudnn.enabled = False + +w, h = 192, 96 +ref_length = 300 +neighbor_stride = 20 +default_fps = 24 +MAX_frame = 300 + + +def video_process(video_input_path): + video_input = cv2.VideoCapture(video_input_path) + success, frame = video_input.read() + if success is False: + decode_error = 'decode_error' + w, h, fps = 0, 0, 0 + else: + decode_error = None + h, w = frame.shape[0:2] + fps = video_input.get(cv2.CAP_PROP_FPS) + video_input.release() + + return decode_error, fps, w, h + + +class Stack(object): + + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], + axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f'Image mode {mode}') + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +_to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()]) + + +def get_crop_mask_v1(mask): + orig_h, orig_w, _ = mask.shape + if (mask == 255).all(): + return mask, (0, int(orig_h), 0, + int(orig_w)), [0, int(orig_h), 0, + int(orig_w) + ], [0, int(orig_h), 0, + int(orig_w)] + + hs = np.min(np.where(mask == 0)[0]) + he = np.max(np.where(mask == 0)[0]) + ws = np.min(np.where(mask == 0)[1]) + we = np.max(np.where(mask == 0)[1]) + crop_box = [ws, hs, we, he] + + mask_h = round(int(orig_h / 2) / 4) * 4 + mask_w = round(int(orig_w / 2) / 4) * 4 + + if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[:mask_h, :mask_w, :] + res_pix = (0, mask_h, 0, mask_w) + elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :] + res_pix = (0, mask_h, orig_w - mask_w, int(orig_w)) + elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :] + res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w) + elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :] + res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w)) + + elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w): + crop_mask = mask[:mask_h, :, :] + res_pix = (0, mask_h, 0, int(orig_w)) + elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[:, :mask_w, :] + res_pix = (0, int(orig_h), 0, mask_w) + elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, :, :] + res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w)) + elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[:, orig_w - mask_w:orig_w, :] + res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w)) + else: + crop_mask = mask + res_pix = (0, int(orig_h), 0, int(orig_w)) + a = ws - res_pix[2] + b = hs - res_pix[0] + c = we - res_pix[2] + d = he - res_pix[0] + return crop_mask, res_pix, crop_box, [a, b, c, d] + + +def get_ref_index(neighbor_ids, length): + ref_index = [] + for i in range(0, length, ref_length): + if i not in neighbor_ids: + ref_index.append(i) + return ref_index + + +def read_mask_oneImage(mpath): + masks = [] + print('mask_path: {}'.format(mpath)) + start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0]) + end = int( + mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0]) + m = Image.open(mpath) + m = np.array(m.convert('L')) + m = np.array(m > 0).astype(np.uint8) + m = 1 - m + for i in range(start - 1, end + 1): + masks.append(Image.fromarray(m * 255)) + return masks + + +def check_size(h, w): + is_resize = False + if h != 240: + h = 240 + is_resize = True + if w != 432: + w = 432 + is_resize = True + return is_resize + + +def get_mask_list(mask_path): + mask_names = os.listdir(mask_path) + mask_names.sort() + + abs_mask_path = [] + mask_list = [] + begin_list = [] + end_list = [] + + for mask_name in mask_names: + mask_name_tmp = mask_name.split('mask_')[1] + begin_list.append(int(mask_name_tmp.split('_')[0])) + end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0])) + abs_mask_path.append(os.path.join(mask_path, mask_name)) + mask = cv2.imread(os.path.join(mask_path, mask_name)) + mask_list.append(mask) + return mask_list, begin_list, end_list, abs_mask_path + + +def inpainting_by_model_balance(model, video_inputPath, mask_path, + video_savePath, fps, w_ori, h_ori): + + video_ori = cv2.VideoCapture(video_inputPath) + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori)) + + mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path) + + img_npy = [] + + for index, mask in enumerate(mask_list): + + masks = read_mask_oneImage(abs_mask_path[index]) + + mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1( + mask) + mask_h, mask_w = mask.shape[0:2] + is_resize = check_size(mask.shape[0], mask.shape[1]) + + begin = begin_list[index] + end = end_list[index] + print('begin: {}'.format(begin)) + print('end: {}'.format(end)) + + for i in range(begin, end + 1, MAX_frame): + begin_time = time.time() + if i + MAX_frame <= end: + video_length = MAX_frame + else: + video_length = end - i + 1 + + for frame_count in range(video_length): + _, frame = video_ori.read() + img_npy.append(frame) + frames_temp = [] + for f in img_npy: + f = Image.fromarray(f) + i_temp = f.crop( + (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) + a = i_temp.resize((w, h), Image.NEAREST) + frames_temp.append(a) + feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1 + frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp] + masks_temp = [] + for m in masks[i - begin:i + video_length - begin]: + + m_temp = m.crop( + (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) + b = m_temp.resize((w, h), Image.NEAREST) + masks_temp.append(b) + binary_masks_temp = [ + np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) + for m in masks_temp + ] + masks_temp = _to_tensors(masks_temp).unsqueeze(0) + if torch.cuda.is_available(): + feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda() + comp_frames = [None] * video_length + model.eval() + with torch.no_grad(): + feats_out = feats_temp * (1 - masks_temp).float() + feats_out = feats_out.view(video_length, 3, h, w) + feats_out = model.model.encoder(feats_out) + _, c, feat_h, feat_w = feats_out.size() + feats_out = feats_out.view(1, video_length, c, feat_h, feat_w) + + for f in range(0, video_length, neighbor_stride): + neighbor_ids = [ + i for i in range( + max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(neighbor_ids, video_length) + with torch.no_grad(): + pred_feat = model.model.infer( + feats_out[0, neighbor_ids + ref_ids, :, :, :], + masks_temp[0, neighbor_ids + ref_ids, :, :, :]) + pred_img = torch.tanh( + model.model.decoder( + pred_feat[:len(neighbor_ids), :, :, :])).detach() + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + for j in range(len(neighbor_ids)): + idx = neighbor_ids[j] + img = np.array(pred_img[j]).astype( + np.uint8) * binary_masks_temp[idx] + frames_temp[ + idx] * (1 - binary_masks_temp[idx]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype( + np.float32) * 0.5 + img.astype( + np.float32) * 0.5 + print('inpainting time:', time.time() - begin_time) + for f in range(video_length): + comp = np.array(comp_frames[f]).astype( + np.uint8) * binary_masks_temp[f] + frames_temp[f] * ( + 1 - binary_masks_temp[f]) + if is_resize: + comp = cv2.resize(comp, (mask_w, mask_h)) + complete_frame = img_npy[f] + a1, b1, c1, d1 = crop_for_oriimg + a2, b2, c2, d2 = crop_for_inpimg + complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2] + video_save.write(complete_frame) + + img_npy = [] + + video_ori.release() diff --git a/modelscope/models/cv/video_inpainting/inpainting_model.py b/modelscope/models/cv/video_inpainting/inpainting_model.py new file mode 100644 index 00000000..ffecde67 --- /dev/null +++ b/modelscope/models/cv/video_inpainting/inpainting_model.py @@ -0,0 +1,381 @@ +""" VideoInpaintingProcess +The implementation here is modified based on STTN, + originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN +""" + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class BaseNetwork(nn.Module): + + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' + % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +@MODELS.register_module( + Tasks.video_inpainting, module_name=Models.video_inpainting) +class VideoInpainting(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = InpaintGenerator() + if torch.cuda.is_available(): + device = 'cuda' + else: + device = 'cpu' + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + self.model.load_state_dict(pretrained_params['netG']) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + +class InpaintGenerator(BaseNetwork): + + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + stack_num = 6 + patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)] + blocks = [] + for _ in range(stack_num): + blocks.append(TransformerBlock(patchsize, hidden=channel)) + self.transformer = nn.Sequential(*blocks) + + self.encoder = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + if init_weights: + self.init_weights() + + def forward(self, masked_frames, masks): + b, t, c, h, w = masked_frames.size() + masks = masks.view(b * t, 1, h, w) + enc_feat = self.encoder(masked_frames.view(b * t, c, h, w)) + _, c, h, w = enc_feat.size() + masks = F.interpolate(masks, scale_factor=1.0 / 4) + enc_feat = self.transformer({ + 'x': enc_feat, + 'm': masks, + 'b': b, + 'c': c + })['x'] + output = self.decoder(enc_feat) + output = torch.tanh(output) + return output + + def infer(self, feat, masks): + t, c, h, w = masks.size() + masks = masks.view(t, c, h, w) + masks = F.interpolate(masks, scale_factor=1.0 / 4) + t, c, _, _ = feat.size() + enc_feat = self.transformer({ + 'x': feat, + 'm': masks, + 'b': 1, + 'c': c + })['x'] + return enc_feat + + +class deconv(nn.Module): + + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d( + input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate( + x, scale_factor=2, mode='bilinear', align_corners=True) + x = self.conv(x) + return x + + +class Attention(nn.Module): + """ + Compute 'Scaled Dot Product Attention + """ + + def forward(self, query, key, value, m): + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( + query.size(-1)) + scores.masked_fill(m, -1e9) + p_attn = F.softmax(scores, dim=-1) + p_val = torch.matmul(p_attn, value) + return p_val, p_attn + + +class MultiHeadedAttention(nn.Module): + """ + Take in model size and number of heads. + """ + + def __init__(self, patchsize, d_model): + super().__init__() + self.patchsize = patchsize + self.query_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.value_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.key_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.output_linear = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + self.attention = Attention() + + def forward(self, x, m, b, c): + bt, _, h, w = x.size() + t = bt // b + d_k = c // len(self.patchsize) + output = [] + _query = self.query_embedding(x) + _key = self.key_embedding(x) + _value = self.value_embedding(x) + for (width, height), query, key, value in zip( + self.patchsize, + torch.chunk(_query, len(self.patchsize), dim=1), + torch.chunk(_key, len(self.patchsize), dim=1), + torch.chunk(_value, len(self.patchsize), dim=1)): + out_w, out_h = w // width, h // height + mm = m.view(b, t, 1, out_h, height, out_w, width) + mm = mm.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + height * width) + mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat( + 1, t * out_h * out_w, 1) + query = query.view(b, t, d_k, out_h, height, out_w, width) + query = query.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + key = key.view(b, t, d_k, out_h, height, out_w, width) + key = key.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + value = value.view(b, t, d_k, out_h, height, out_w, width) + value = value.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + y, _ = self.attention(query, key, value, mm) + y = y.view(b, t, out_h, out_w, d_k, height, width) + y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w) + output.append(y) + output = torch.cat(output, 1) + x = self.output_linear(output) + return x + + +class FeedForward(nn.Module): + + def __init__(self, d_model): + super(FeedForward, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer = MultiHead_Attention + Feed_Forward with sublayer connection + """ + + def __init__(self, patchsize, hidden=128): # hidden=128 + super().__init__() + self.attention = MultiHeadedAttention(patchsize, d_model=hidden) + self.feed_forward = FeedForward(hidden) + + def forward(self, x): + x, m, b, c = x['x'], x['m'], x['b'], x['c'] + x = x + self.attention(x, m, b, c) + x = x + self.feed_forward(x) + return {'x': x, 'm': m, 'b': b, 'c': c} + + +class Discriminator(BaseNetwork): + + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 64 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d( + in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + xs_t = torch.transpose(xs, 0, 1) + xs_t = xs_t.unsqueeze(0) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/modelscope/models/cv/video_single_object_tracking/__init__.py b/modelscope/models/cv/video_single_object_tracking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/config/__init__.py b/modelscope/models/cv/video_single_object_tracking/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/config/ostrack.py b/modelscope/models/cv/video_single_object_tracking/config/ostrack.py new file mode 100644 index 00000000..6805c503 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/config/ostrack.py @@ -0,0 +1,39 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +from easydict import EasyDict as edict + +cfg = edict() + +# MODEL +cfg.MODEL = edict() + +# MODEL.BACKBONE +cfg.MODEL.BACKBONE = edict() +cfg.MODEL.BACKBONE.TYPE = 'vit_base_patch16_224_ce' +cfg.MODEL.BACKBONE.STRIDE = 16 +cfg.MODEL.BACKBONE.CAT_MODE = 'direct' +cfg.MODEL.BACKBONE.DROP_PATH_RATE = 0.1 +cfg.MODEL.BACKBONE.CE_LOC = [3, 6, 9] +cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [0.7, 0.7, 0.7] +cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'CTR_POINT' + +# MODEL.HEAD +cfg.MODEL.HEAD = edict() +cfg.MODEL.HEAD.TYPE = 'CENTER' +cfg.MODEL.HEAD.NUM_CHANNELS = 256 + +# DATA +cfg.DATA = edict() +cfg.DATA.MEAN = [0.485, 0.456, 0.406] +cfg.DATA.STD = [0.229, 0.224, 0.225] +cfg.DATA.SEARCH = edict() +cfg.DATA.SEARCH.SIZE = 384 +cfg.DATA.TEMPLATE = edict() +cfg.DATA.TEMPLATE.SIZE = 192 + +# TEST +cfg.TEST = edict() +cfg.TEST.TEMPLATE_FACTOR = 2.0 +cfg.TEST.TEMPLATE_SIZE = 192 +cfg.TEST.SEARCH_FACTOR = 5.0 +cfg.TEST.SEARCH_SIZE = 384 diff --git a/modelscope/models/cv/video_single_object_tracking/models/__init__.py b/modelscope/models/cv/video_single_object_tracking/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/__init__.py b/modelscope/models/cv/video_single_object_tracking/models/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/attn.py b/modelscope/models/cv/video_single_object_tracking/models/layers/attn.py new file mode 100644 index 00000000..e245c821 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/attn.py @@ -0,0 +1,54 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch.nn as nn + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0., + rpe=False, + z_size=7, + x_size=14): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, mask=None, return_attention=False): + # x: B, N, C + # mask: [B, N, ] torch.bool + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind( + 0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if mask is not None: + attn = attn.masked_fill( + mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + if return_attention: + return x, attn + else: + return x diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py new file mode 100644 index 00000000..702c84f1 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py @@ -0,0 +1,129 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import math + +import torch +import torch.nn as nn +from timm.models.layers import DropPath, Mlp + +from .attn import Attention + + +def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, + lens_t: int, keep_ratio: float, + global_index: torch.Tensor, + box_mask_z: torch.Tensor): + """ + Eliminate potential background candidates for computation reduction and noise cancellation. + Args: + attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights + tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens + lens_t (int): length of template + keep_ratio (float): keep ratio of search region tokens (candidates) + global_index (torch.Tensor): global index of search region tokens + box_mask_z (torch.Tensor): template mask used to accumulate attention weights + + Returns: + tokens_new (torch.Tensor): tokens after candidate elimination + keep_index (torch.Tensor): indices of kept search region tokens + removed_index (torch.Tensor): indices of removed search region tokens + """ + lens_s = attn.shape[-1] - lens_t + bs, hn, _, _ = attn.shape + + lens_keep = math.ceil(keep_ratio * lens_s) + if lens_keep == lens_s: + return tokens, global_index, None + + attn_t = attn[:, :, :lens_t, lens_t:] + + if box_mask_z is not None: + box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand( + -1, attn_t.shape[1], -1, attn_t.shape[-1]) + attn_t = attn_t[box_mask_z] + attn_t = attn_t.view(bs, hn, -1, lens_s) + attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s + else: + attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s + + # use sort instead of topk, due to the speed issue + # https://github.com/pytorch/pytorch/issues/22812 + sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True) + + _, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep] + _, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:] + keep_index = global_index.gather(dim=1, index=topk_idx) + removed_index = global_index.gather(dim=1, index=non_topk_idx) + + # separate template and search tokens + tokens_t = tokens[:, :lens_t] + tokens_s = tokens[:, lens_t:] + + # obtain the attentive and inattentive tokens + B, L, C = tokens_s.shape + attentive_tokens = tokens_s.gather( + dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C)) + + # concatenate these tokens + tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) + + return tokens_new, keep_index, removed_index + + +class CEBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + keep_ratio_search=1.0, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.keep_ratio_search = keep_ratio_search + + def forward(self, + x, + global_index_template, + global_index_search, + mask=None, + ce_template_mask=None, + keep_ratio_search=None): + x_attn, attn = self.attn(self.norm1(x), mask, True) + x = x + self.drop_path(x_attn) + lens_t = global_index_template.shape[1] + + removed_index_search = None + if self.keep_ratio_search < 1 and (keep_ratio_search is None + or keep_ratio_search < 1): + keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search + x, global_index_search, removed_index_search = candidate_elimination( + attn, x, lens_t, keep_ratio_search, global_index_search, + ce_template_mask) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x, global_index_template, global_index_search, removed_index_search, attn diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/head.py b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py new file mode 100644 index 00000000..e0dc7b59 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py @@ -0,0 +1,141 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch +import torch.nn as nn + + +def conv(in_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True), nn.BatchNorm2d(out_planes), nn.ReLU(inplace=True)) + + +class CenterPredictor( + nn.Module, ): + + def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16): + super(CenterPredictor, self).__init__() + self.feat_sz = feat_sz + self.stride = stride + self.img_sz = self.feat_sz * self.stride + + # corner predict + self.conv1_ctr = conv(inplanes, channel) + self.conv2_ctr = conv(channel, channel // 2) + self.conv3_ctr = conv(channel // 2, channel // 4) + self.conv4_ctr = conv(channel // 4, channel // 8) + self.conv5_ctr = nn.Conv2d(channel // 8, 1, kernel_size=1) + + # offset regress + self.conv1_offset = conv(inplanes, channel) + self.conv2_offset = conv(channel, channel // 2) + self.conv3_offset = conv(channel // 2, channel // 4) + self.conv4_offset = conv(channel // 4, channel // 8) + self.conv5_offset = nn.Conv2d(channel // 8, 2, kernel_size=1) + + # size regress + self.conv1_size = conv(inplanes, channel) + self.conv2_size = conv(channel, channel // 2) + self.conv3_size = conv(channel // 2, channel // 4) + self.conv4_size = conv(channel // 4, channel // 8) + self.conv5_size = nn.Conv2d(channel // 8, 2, kernel_size=1) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, x, gt_score_map=None): + """ Forward pass with input x. """ + score_map_ctr, size_map, offset_map = self.get_score_map(x) + + # assert gt_score_map is None + if gt_score_map is None: + bbox = self.cal_bbox(score_map_ctr, size_map, offset_map) + else: + bbox = self.cal_bbox( + gt_score_map.unsqueeze(1), size_map, offset_map) + + return score_map_ctr, bbox, size_map, offset_map + + def cal_bbox(self, + score_map_ctr, + size_map, + offset_map, + return_score=False): + max_score, idx = torch.max( + score_map_ctr.flatten(1), dim=1, keepdim=True) + idx_y = idx // self.feat_sz + idx_x = idx % self.feat_sz + + idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) + size = size_map.flatten(2).gather(dim=2, index=idx) + offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) + + # cx, cy, w, h + bbox = torch.cat( + [(idx_x.to(torch.float) + offset[:, :1]) / self.feat_sz, + (idx_y.to(torch.float) + offset[:, 1:]) / self.feat_sz, + size.squeeze(-1)], + dim=1) + + if return_score: + return bbox, max_score + return bbox + + def get_score_map(self, x): + + def _sigmoid(x): + y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) + return y + + # ctr branch + x_ctr1 = self.conv1_ctr(x) + x_ctr2 = self.conv2_ctr(x_ctr1) + x_ctr3 = self.conv3_ctr(x_ctr2) + x_ctr4 = self.conv4_ctr(x_ctr3) + score_map_ctr = self.conv5_ctr(x_ctr4) + + # offset branch + x_offset1 = self.conv1_offset(x) + x_offset2 = self.conv2_offset(x_offset1) + x_offset3 = self.conv3_offset(x_offset2) + x_offset4 = self.conv4_offset(x_offset3) + score_map_offset = self.conv5_offset(x_offset4) + + # size branch + x_size1 = self.conv1_size(x) + x_size2 = self.conv2_size(x_size1) + x_size3 = self.conv3_size(x_size2) + x_size4 = self.conv4_size(x_size3) + score_map_size = self.conv5_size(x_size4) + return _sigmoid(score_map_ctr), _sigmoid( + score_map_size), score_map_offset + + +def build_box_head(cfg, hidden_dim): + stride = cfg.MODEL.BACKBONE.STRIDE + + if cfg.MODEL.HEAD.TYPE == 'CENTER': + in_channel = hidden_dim + out_channel = cfg.MODEL.HEAD.NUM_CHANNELS + feat_sz = int(cfg.DATA.SEARCH.SIZE / stride) + center_head = CenterPredictor( + inplanes=in_channel, + channel=out_channel, + feat_sz=feat_sz, + stride=stride) + return center_head + else: + raise ValueError('HEAD TYPE %s is not supported.' + % cfg.MODEL.HEAD_TYPE) diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/patch_embed.py b/modelscope/models/cv/video_single_object_tracking/models/layers/patch_embed.py new file mode 100644 index 00000000..c001663f --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/patch_embed.py @@ -0,0 +1,37 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch.nn as nn +from timm.models.layers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/__init__.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/base_backbone.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/base_backbone.py new file mode 100644 index 00000000..20d73422 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/base_backbone.py @@ -0,0 +1,93 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch.nn as nn +from timm.models.layers import to_2tuple + +from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \ + PatchEmbed + + +class BaseBackbone(nn.Module): + + def __init__(self): + super().__init__() + + # for original ViT + self.pos_embed = None + self.img_size = [224, 224] + self.patch_size = 16 + self.embed_dim = 384 + + self.cat_mode = 'direct' + + self.pos_embed_z = None + self.pos_embed_x = None + + self.template_segment_pos_embed = None + self.search_segment_pos_embed = None + + self.return_stage = [2, 5, 8, 11] + + def finetune_track(self, cfg, patch_start_index=1): + + search_size = to_2tuple(cfg.DATA.SEARCH.SIZE) + template_size = to_2tuple(cfg.DATA.TEMPLATE.SIZE) + new_patch_size = cfg.MODEL.BACKBONE.STRIDE + + self.cat_mode = cfg.MODEL.BACKBONE.CAT_MODE + + # resize patch embedding + if new_patch_size != self.patch_size: + print( + 'Inconsistent Patch Size With The Pretrained Weights, Interpolate The Weight!' + ) + old_patch_embed = {} + for name, param in self.patch_embed.named_parameters(): + if 'weight' in name: + param = nn.functional.interpolate( + param, + size=(new_patch_size, new_patch_size), + mode='bicubic', + align_corners=False) + param = nn.Parameter(param) + old_patch_embed[name] = param + self.patch_embed = PatchEmbed( + img_size=self.img_size, + patch_size=new_patch_size, + in_chans=3, + embed_dim=self.embed_dim) + self.patch_embed.proj.bias = old_patch_embed['proj.bias'] + self.patch_embed.proj.weight = old_patch_embed['proj.weight'] + + # for patch embedding + patch_pos_embed = self.pos_embed[:, patch_start_index:, :] + patch_pos_embed = patch_pos_embed.transpose(1, 2) + B, E, Q = patch_pos_embed.shape + P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[ + 1] // self.patch_size + patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W) + + # for search region + H, W = search_size + new_P_H, new_P_W = H // new_patch_size, W // new_patch_size + search_patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_P_H, new_P_W), + mode='bicubic', + align_corners=False) + search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose( + 1, 2) + + # for template region + H, W = template_size + new_P_H, new_P_W = H // new_patch_size, W // new_patch_size + template_patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_P_H, new_P_W), + mode='bicubic', + align_corners=False) + template_patch_pos_embed = template_patch_pos_embed.flatten( + 2).transpose(1, 2) + + self.pos_embed_z = nn.Parameter(template_patch_pos_embed) + self.pos_embed_x = nn.Parameter(search_patch_pos_embed) diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py new file mode 100644 index 00000000..52704a6c --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py @@ -0,0 +1,109 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch +from torch import nn + +from modelscope.models.cv.video_single_object_tracking.models.layers.head import \ + build_box_head +from .vit_ce import vit_base_patch16_224_ce + + +class OSTrack(nn.Module): + """ This is the base class for OSTrack """ + + def __init__(self, + transformer, + box_head, + aux_loss=False, + head_type='CORNER'): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.backbone = transformer + self.box_head = box_head + + self.aux_loss = aux_loss + self.head_type = head_type + if head_type == 'CORNER' or head_type == 'CENTER': + self.feat_sz_s = int(box_head.feat_sz) + self.feat_len_s = int(box_head.feat_sz**2) + + def forward( + self, + template: torch.Tensor, + search: torch.Tensor, + ce_template_mask=None, + ce_keep_rate=None, + ): + x, aux_dict = self.backbone( + z=template, + x=search, + ce_template_mask=ce_template_mask, + ce_keep_rate=ce_keep_rate, + ) + + # Forward head + feat_last = x + if isinstance(x, list): + feat_last = x[-1] + out = self.forward_head(feat_last, None) + + out.update(aux_dict) + out['backbone_feat'] = x + return out + + def forward_head(self, cat_feature, gt_score_map=None): + """ + cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) + """ + enc_opt = cat_feature[:, -self. + feat_len_s:] # encoder output for the search region (B, HW, C) + opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() + bs, Nq, C, HW = opt.size() + opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) + + if self.head_type == 'CENTER': + # run the center head + score_map_ctr, bbox, size_map, offset_map = self.box_head( + opt_feat, gt_score_map) + outputs_coord = bbox + outputs_coord_new = outputs_coord.view(bs, Nq, 4) + out = { + 'pred_boxes': outputs_coord_new, + 'score_map': score_map_ctr, + 'size_map': size_map, + 'offset_map': offset_map + } + return out + else: + raise NotImplementedError + + +def build_ostrack(cfg): + if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce': + backbone = vit_base_patch16_224_ce( + False, + drop_path_rate=cfg.MODEL.BACKBONE.DROP_PATH_RATE, + ce_loc=cfg.MODEL.BACKBONE.CE_LOC, + ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, + ) + hidden_dim = backbone.embed_dim + patch_start_index = 1 + else: + raise NotImplementedError + + backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) + + box_head = build_box_head(cfg, hidden_dim) + + model = OSTrack( + backbone, + box_head, + aux_loss=False, + head_type=cfg.MODEL.HEAD.TYPE, + ) + + return model diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/utils.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/utils.py new file mode 100644 index 00000000..46e7c18a --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/utils.py @@ -0,0 +1,24 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch + + +def combine_tokens(template_tokens, + search_tokens, + mode='direct', + return_res=False): + if mode == 'direct': + merged_feature = torch.cat((template_tokens, search_tokens), dim=1) + else: + raise NotImplementedError + + return merged_feature + + +def recover_tokens(merged_tokens, mode='direct'): + if mode == 'direct': + recovered_tokens = merged_tokens + else: + raise NotImplementedError + + return recovered_tokens diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/vit_ce.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/vit_ce.py new file mode 100644 index 00000000..f186cf89 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/vit_ce.py @@ -0,0 +1,343 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +from functools import partial + +import torch +import torch.nn as nn +from timm.models.layers import DropPath, Mlp, to_2tuple + +from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \ + CEBlock +from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \ + PatchEmbed +from .base_backbone import BaseBackbone +from .utils import combine_tokens, recover_tokens + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + +class VisionTransformer(BaseBackbone): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + distilled=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = None + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer) for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + + +class VisionTransformerCE(VisionTransformer): + """ Vision Transformer with candidate elimination (CE) module + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + distilled=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + ce_loc=None, + ce_keep_ratio=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + if isinstance(img_size, tuple): + self.img_size = img_size + else: + self.img_size = to_2tuple(img_size) + self.patch_size = patch_size + self.in_chans = in_chans + + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros( + 1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + blocks = [] + ce_index = 0 + self.ce_loc = ce_loc + for i in range(depth): + ce_keep_ratio_i = 1.0 + if ce_loc is not None and i in ce_loc: + ce_keep_ratio_i = ce_keep_ratio[ce_index] + ce_index += 1 + + blocks.append( + CEBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + keep_ratio_search=ce_keep_ratio_i)) + + self.blocks = nn.Sequential(*blocks) + self.norm = norm_layer(embed_dim) + + def forward_features( + self, + z, + x, + mask_x=None, + ce_template_mask=None, + ce_keep_rate=None, + ): + B = x.shape[0] + + x = self.patch_embed(x) + z = self.patch_embed(z) + + z += self.pos_embed_z + x += self.pos_embed_x + + x = combine_tokens(z, x, mode=self.cat_mode) + + x = self.pos_drop(x) + + lens_z = self.pos_embed_z.shape[1] + lens_x = self.pos_embed_x.shape[1] + + global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device) + global_index_t = global_index_t.repeat(B, 1) + + global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device) + global_index_s = global_index_s.repeat(B, 1) + removed_indexes_s = [] + for i, blk in enumerate(self.blocks): + x, global_index_t, global_index_s, removed_index_s, attn = \ + blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate) + + if self.ce_loc is not None and i in self.ce_loc: + removed_indexes_s.append(removed_index_s) + + x = self.norm(x) + lens_x_new = global_index_s.shape[1] + lens_z_new = global_index_t.shape[1] + + z = x[:, :lens_z_new] + x = x[:, lens_z_new:] + + if removed_indexes_s and removed_indexes_s[0] is not None: + removed_indexes_cat = torch.cat(removed_indexes_s, dim=1) + + pruned_lens_x = lens_x - lens_x_new + pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], + device=x.device) + x = torch.cat([x, pad_x], dim=1) + index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1) + # recover original token order + C = x.shape[-1] + x = torch.zeros_like(x).scatter_( + dim=1, + index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64), + src=x) + + x = recover_tokens(x, mode=self.cat_mode) + + # re-concatenate with the template, which may be further used by other modules + x = torch.cat([z, x], dim=1) + + aux_dict = { + 'attn': attn, + 'removed_indexes_s': removed_indexes_s, # used for visualization + } + + return x, aux_dict + + def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None): + + x, aux_dict = self.forward_features( + z, + x, + ce_template_mask=ce_template_mask, + ce_keep_rate=ce_keep_rate, + ) + + return x, aux_dict + + +def _create_vision_transformer(pretrained=False, **kwargs): + model = VisionTransformerCE(**kwargs) + return model + + +def vit_base_patch16_224_ce(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) + return model diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py b/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/ostrack.py b/modelscope/models/cv/video_single_object_tracking/tracker/ostrack.py new file mode 100644 index 00000000..5093a72d --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/tracker/ostrack.py @@ -0,0 +1,139 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import torch + +from modelscope.models.cv.video_single_object_tracking.config.ostrack import \ + cfg +from modelscope.models.cv.video_single_object_tracking.models.ostrack.ostrack import \ + build_ostrack +from modelscope.models.cv.video_single_object_tracking.utils.utils import ( + Preprocessor, clip_box, generate_mask_cond, hann2d, sample_target, + transform_image_to_crop) + + +class OSTrack(): + + def __init__(self, ckpt_path, device): + network = build_ostrack(cfg) + network.load_state_dict( + torch.load(ckpt_path, map_location='cpu')['net'], strict=True) + self.cfg = cfg + if device.type == 'cuda': + self.network = network.to(device) + else: + self.network = network + self.network.eval() + self.preprocessor = Preprocessor(device) + self.state = None + + self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE + # motion constrain + if device.type == 'cuda': + self.output_window = hann2d( + torch.tensor([self.feat_sz, self.feat_sz]).long(), + centered=True).to(device) + else: + self.output_window = hann2d( + torch.tensor([self.feat_sz, self.feat_sz]).long(), + centered=True) + self.frame_id = 0 + # for save boxes from all queries + self.z_dict1 = {} + + def initialize(self, image, info: dict): + # forward the template once + z_patch_arr, resize_factor, z_amask_arr = sample_target( + image, + info['init_bbox'], + self.cfg.TEST.TEMPLATE_FACTOR, + output_sz=self.cfg.TEST.TEMPLATE_SIZE) + self.z_patch_arr = z_patch_arr + template = self.preprocessor.process(z_patch_arr, z_amask_arr) + with torch.no_grad(): + self.z_dict1 = template + + self.box_mask_z = None + if self.cfg.MODEL.BACKBONE.CE_LOC: + template_bbox = self.transform_bbox_to_crop( + info['init_bbox'], resize_factor, + template.tensors.device).squeeze(1) + self.box_mask_z = generate_mask_cond(self.cfg, 1, + template.tensors.device, + template_bbox) + + # save states + self.state = info['init_bbox'] + self.frame_id = 0 + + def track(self, image, info: dict = None): + H, W, _ = image.shape + self.frame_id += 1 + x_patch_arr, resize_factor, x_amask_arr = sample_target( + image, + self.state, + self.cfg.TEST.SEARCH_FACTOR, + output_sz=self.cfg.TEST.SEARCH_SIZE) # (x1, y1, w, h) + search = self.preprocessor.process(x_patch_arr, x_amask_arr) + + with torch.no_grad(): + x_dict = search + # merge the template and the search + # run the transformer + out_dict = self.network.forward( + template=self.z_dict1.tensors, + search=x_dict.tensors, + ce_template_mask=self.box_mask_z) + + # add hann windows + pred_score_map = out_dict['score_map'] + response = self.output_window * pred_score_map + pred_boxes = self.network.box_head.cal_bbox(response, + out_dict['size_map'], + out_dict['offset_map']) + pred_boxes = pred_boxes.view(-1, 4) + # Baseline: Take the mean of all pred boxes as the final result + pred_box = (pred_boxes.mean(dim=0) * self.cfg.TEST.SEARCH_SIZE + / resize_factor).tolist() # (cx, cy, w, h) [0,1] + # get the final box result + self.state = clip_box( + self.map_box_back(pred_box, resize_factor), H, W, margin=10) + + x1, y1, w, h = self.state + x2 = x1 + w + y2 = y1 + h + return {'target_bbox': [x1, y1, x2, y2]} + + def map_box_back(self, pred_box: list, resize_factor: float): + cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[ + 1] + 0.5 * self.state[3] + cx, cy, w, h = pred_box + half_side = 0.5 * self.cfg.TEST.SEARCH_SIZE / resize_factor + cx_real = cx + (cx_prev - half_side) + cy_real = cy + (cy_prev - half_side) + return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h] + + def transform_bbox_to_crop(self, + box_in, + resize_factor, + device, + box_extract=None, + crop_type='template'): + if crop_type == 'template': + crop_sz = torch.Tensor( + [self.cfg.TEST.TEMPLATE_SIZE, self.cfg.TEST.TEMPLATE_SIZE]) + elif crop_type == 'search': + crop_sz = torch.Tensor( + [self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]) + else: + raise NotImplementedError + + box_in = torch.tensor(box_in) + if box_extract is None: + box_extract = box_in + else: + box_extract = torch.tensor(box_extract) + template_bbox = transform_image_to_crop( + box_in, box_extract, resize_factor, crop_sz, normalize=True) + template_bbox = template_bbox.view(1, 1, 4).to(device) + + return template_bbox diff --git a/modelscope/models/cv/video_single_object_tracking/utils/__init__.py b/modelscope/models/cv/video_single_object_tracking/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_single_object_tracking/utils/utils.py b/modelscope/models/cv/video_single_object_tracking/utils/utils.py new file mode 100644 index 00000000..90513a2a --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/utils/utils.py @@ -0,0 +1,247 @@ +# The implementation is adopted from OSTrack, +# made publicly available under the MIT License at https://github.com/botaoye/OSTrack/ +import math +from typing import Optional + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def hann1d(sz: int, centered=True) -> torch.Tensor: + """1D cosine window.""" + if centered: + return 0.5 * (1 - torch.cos( + (2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float())) + w = 0.5 * (1 + torch.cos( + (2 * math.pi / (sz + 2)) * torch.arange(0, sz // 2 + 1).float())) + return torch.cat([w, w[1:sz - sz // 2].flip((0, ))]) + + +def hann2d(sz: torch.Tensor, centered=True) -> torch.Tensor: + """2D cosine window.""" + return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d( + sz[1].item(), centered).reshape(1, 1, 1, -1) + + +class NestedTensor(object): + + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + +class Preprocessor(object): + + def __init__(self, device: str): + self.device = device + self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)) + self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)) + if 'cuda' == self.device.type: + self.mean = self.mean.to(self.device) + self.std = self.std.to(self.device) + + def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): + # Deal with the image patch + if 'cuda' == self.device.type: + img_tensor = torch.tensor(img_arr).to(self.device).float().permute( + (2, 0, 1)).unsqueeze(dim=0) + else: + img_tensor = torch.tensor(img_arr).float().permute( + (2, 0, 1)).unsqueeze(dim=0) + img_tensor_norm = ( + (img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) + + # Deal with the attention mask + if 'cuda' == self.device.type: + amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).to( + self.device).unsqueeze(dim=0) # (1,H,W) + else: + amask_tensor = torch.from_numpy(amask_arr).to( + torch.bool).unsqueeze(dim=0) # (1,H,W) + return NestedTensor(img_tensor_norm, amask_tensor) + + +def clip_box(box: list, H, W, margin=0): + x1, y1, w, h = box + x2, y2 = x1 + w, y1 + h + x1 = min(max(0, x1), W - margin) + x2 = min(max(margin, x2), W) + y1 = min(max(0, y1), H - margin) + y2 = min(max(margin, y2), H) + w = max(margin, x2 - x1) + h = max(margin, y2 - y1) + if isinstance(x1, torch.Tensor): + x1 = x1.item() + y1 = y1.item() + w = w.item() + h = h.item() + return [x1, y1, w, h] + + +def generate_mask_cond(cfg, bs, device, gt_bbox): + template_size = cfg.DATA.TEMPLATE.SIZE + stride = cfg.MODEL.BACKBONE.STRIDE + template_feat_size = template_size // stride + + if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': + if template_feat_size == 8: + index = slice(3, 4) + elif template_feat_size == 12: + index = slice(5, 6) + elif template_feat_size == 7: + index = slice(3, 4) + elif template_feat_size == 14: + index = slice(6, 7) + else: + raise NotImplementedError + box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], + device=device) + box_mask_z[:, index, index] = 1 + box_mask_z = box_mask_z.flatten(1).to(torch.bool) + else: + raise NotImplementedError + + return box_mask_z + + +def sample_target(im, + target_bb, + search_area_factor, + output_sz=None, + mask=None): + """ Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area + + args: + im - cv image + target_bb - target box [x, y, w, h] + search_area_factor - Ratio of crop size to target size + output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done. + + returns: + cv image - extracted crop + float - the factor by which the crop has been resized to make the crop size equal output_size + """ + if not isinstance(target_bb, list): + x, y, w, h = target_bb.tolist() + else: + x, y, w, h = target_bb + # Crop image + crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor) + + if crop_sz < 1: + raise Exception('Too small bounding box.') + + x1 = round(x + 0.5 * w - crop_sz * 0.5) + x2 = x1 + crop_sz + + y1 = round(y + 0.5 * h - crop_sz * 0.5) + y2 = y1 + crop_sz + + x1_pad = max(0, -x1) + x2_pad = max(x2 - im.shape[1] + 1, 0) + + y1_pad = max(0, -y1) + y2_pad = max(y2 - im.shape[0] + 1, 0) + + # Crop target + im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :] + if mask is not None: + mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad] + + # Pad + im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, + x2_pad, cv2.BORDER_CONSTANT) + # deal with attention mask + H, W, _ = im_crop_padded.shape + att_mask = np.ones((H, W)) + end_x, end_y = -x2_pad, -y2_pad + if y2_pad == 0: + end_y = None + if x2_pad == 0: + end_x = None + att_mask[y1_pad:end_y, x1_pad:end_x] = 0 + if mask is not None: + mask_crop_padded = F.pad( + mask_crop, + pad=(x1_pad, x2_pad, y1_pad, y2_pad), + mode='constant', + value=0) + + if output_sz is not None: + resize_factor = output_sz / crop_sz + im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz)) + att_mask = cv2.resize(att_mask, + (output_sz, output_sz)).astype(np.bool_) + if mask is None: + return im_crop_padded, resize_factor, att_mask + mask_crop_padded = \ + F.interpolate(mask_crop_padded[None, None], (output_sz, output_sz), + mode='bilinear', align_corners=False)[0, 0] + return im_crop_padded, resize_factor, att_mask, mask_crop_padded + + else: + if mask is None: + return im_crop_padded, att_mask.astype(np.bool_), 1.0 + return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded + + +def transform_image_to_crop(box_in: torch.Tensor, + box_extract: torch.Tensor, + resize_factor: float, + crop_sz: torch.Tensor, + normalize=False) -> torch.Tensor: + """ Transform the box co-ordinates from the original image co-ordinates to the co-ordinates of the cropped image + args: + box_in - the box for which the co-ordinates are to be transformed + box_extract - the box about which the image crop has been extracted. + resize_factor - the ratio between the original image scale and the scale of the image crop + crop_sz - size of the cropped image + + returns: + torch.Tensor - transformed co-ordinates of box_in + """ + box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4] + + box_in_center = box_in[0:2] + 0.5 * box_in[2:4] + + box_out_center = (crop_sz - 1) / 2 + (box_in_center + - box_extract_center) * resize_factor + box_out_wh = box_in[2:4] * resize_factor + + box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh)) + if normalize: + return box_out / crop_sz[0] + else: + return box_out + + +def check_box(box: list, image_height, image_width) -> bool: + """ To check whether the box is within the image range or not + args: + box - the bounding box in the form of [x1, y1, x2, y2] + image_height - the height of the image + image_width - the width of the image + + returns: + bool - if box is valid, return True. Otherwise, return False + """ + assert len(box) == 4, 'box must be in the form of: [x1, y1, x2, y2]' + if box[0] < 0 or box[0] >= image_width: + return False + if box[2] < 0 or box[2] >= image_width: + return False + if box[1] < 0 or box[1] >= image_height: + return False + if box[3] < 0 or box[3] >= image_height: + return False + return True + + +def timestamp_format(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time diff --git a/modelscope/models/cv/video_summarization/__init__.py b/modelscope/models/cv/video_summarization/__init__.py new file mode 100644 index 00000000..15ad61b4 --- /dev/null +++ b/modelscope/models/cv/video_summarization/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .summarizer import (PGLVideoSummarization, summary_format) + +else: + _import_structure = { + 'summarizer': ['PGLVideoSummarization', 'summary_format'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_summarization/base_model.py b/modelscope/models/cv/video_summarization/base_model.py new file mode 100644 index 00000000..912ba68d --- /dev/null +++ b/modelscope/models/cv/video_summarization/base_model.py @@ -0,0 +1,119 @@ +# Part of the implementation is borrowed and modified from pytorch-caffe-models, +# publicly available at https://github.com/crowsonkb/pytorch-caffe-models + +import cv2 +import numpy as np +import torch +import torch.nn as nn + + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, + pool_proj): + super().__init__() + self.conv_1x1 = nn.Conv2d(in_channels, ch1x1, 1) + self.relu_1x1 = nn.ReLU(inplace=True) + self.conv_3x3_reduce = nn.Conv2d(in_channels, ch3x3red, 1) + self.relu_3x3_reduce = nn.ReLU(inplace=True) + self.conv_3x3 = nn.Conv2d(ch3x3red, ch3x3, 3, padding=1) + self.relu_3x3 = nn.ReLU(inplace=True) + self.conv_5x5_reduce = nn.Conv2d(in_channels, ch5x5red, 1) + self.relu_5x5_reduce = nn.ReLU(inplace=True) + self.conv_5x5 = nn.Conv2d(ch5x5red, ch5x5, 5, padding=2) + self.relu_5x5 = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(3, stride=1, padding=1) + self.pool_proj = nn.Conv2d(in_channels, pool_proj, 1) + self.relu_pool_proj = nn.ReLU(inplace=True) + + def forward(self, x): + branch_1 = self.relu_1x1(self.conv_1x1(x)) + branch_2 = self.relu_3x3_reduce(self.conv_3x3_reduce(x)) + branch_2 = self.relu_3x3(self.conv_3x3(branch_2)) + branch_3 = self.relu_5x5_reduce(self.conv_5x5_reduce(x)) + branch_3 = self.relu_5x5(self.conv_5x5(branch_3)) + branch_4 = self.pool(x) + branch_4 = self.relu_pool_proj(self.pool_proj(branch_4)) + return torch.cat([branch_1, branch_2, branch_3, branch_4], dim=1) + + +class GoogLeNet(nn.Sequential): + + def __init__(self, num_classes=1000): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.norm1 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.conv2_reduce = nn.Conv2d(64, 64, kernel_size=1) + self.relu2_reduce = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1) + self.relu2 = nn.ReLU(inplace=True) + self.norm2 = nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75) + self.pool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.pool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.pool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128) + self.pool5 = nn.AdaptiveAvgPool2d((1, 1)) + self.loss3_classifier = nn.Linear(1024, num_classes) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.pool1(x) + x = self.norm1(x) + x = self.relu2_reduce(self.conv2_reduce(x)) + x = self.relu2(self.conv2(x)) + x = self.norm2(x) + x = self.pool2(x) + x = self.inception_3a(x) + x = self.inception_3b(x) + x = self.pool3(x) + x = self.inception_4a(x) + x = self.inception_4b(x) + x = self.inception_4c(x) + x = self.inception_4d(x) + x = self.inception_4e(x) + x = self.pool4(x) + x = self.inception_5a(x) + x = self.inception_5b(x) + x = self.pool5(x).flatten(1) + return x + + +class bvlc_googlenet(nn.Module): + + def __init__(self, input_size=224): + """model for the BVLC GoogLeNet, trained on ImageNet. + URL: https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet""" + super(bvlc_googlenet, self).__init__() + + self.model = GoogLeNet(num_classes=1000) + + self.input_size = input_size + self.input_mean = (104.0, 117.0, 123.0) + + def forward(self, frame): + x = cv2.resize(frame, + (self.input_size, self.input_size)).astype(np.float32) + x = (x - self.input_mean).astype(np.float32) + x = np.transpose(x, [2, 0, 1]) + + x = np.expand_dims(x, 0) + x = torch.from_numpy(x) + if not next(self.model.parameters()).device.type == 'cpu': + x = x.cuda() + with torch.no_grad(): + frame_feat = self.model(x) + if not frame_feat.device.type == 'cpu': + frame_feat = frame_feat.cpu() + frame_feat = frame_feat.numpy() + frame_feat = frame_feat / np.linalg.norm(frame_feat) + return frame_feat.reshape(-1) diff --git a/modelscope/models/cv/video_summarization/kts/__init__.py b/modelscope/models/cv/video_summarization/kts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_summarization/kts/cpd_auto.py b/modelscope/models/cv/video_summarization/kts/cpd_auto.py new file mode 100644 index 00000000..58281df8 --- /dev/null +++ b/modelscope/models/cv/video_summarization/kts/cpd_auto.py @@ -0,0 +1,36 @@ +# Part of the implementation is borrowed and modified from KTS, +# publicly available at https://github.com/TatsuyaShirakawa/KTS + +import numpy as np + +from .cpd_nonlin import cpd_nonlin + + +def cpd_auto(K, ncp, vmax, desc_rate=1, **kwargs): + """Detect change points automatically selecting their number + + :param K: Kernel between each pair of frames in video + :param ncp: Maximum number of change points + :param vmax: Special parameter + :param desc_rate: Rate of descriptor sampling, vmax always corresponds to 1x + :param kwargs: Extra parameters for ``cpd_nonlin`` + :return: Tuple (cps, costs) + - cps - best selected change-points + - costs - costs for 0,1,2,...,m change-points + """ + m = ncp + _, scores = cpd_nonlin(K, m, backtrack=False, **kwargs) + + N = K.shape[0] + N2 = N * desc_rate # length of the video before down-sampling + + penalties = np.zeros(m + 1) + # Prevent division by zero (in case of 0 changes) + ncp = np.arange(1, m + 1) + penalties[1:] = (vmax * ncp / (2.0 * N2)) * (np.log(float(N2) / ncp) + 1) + + costs = scores / float(N) + penalties + m_best = np.argmin(costs) + cps, scores2 = cpd_nonlin(K, m_best, **kwargs) + + return cps, scores2 diff --git a/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py new file mode 100644 index 00000000..55e279e9 --- /dev/null +++ b/modelscope/models/cv/video_summarization/kts/cpd_nonlin.py @@ -0,0 +1,103 @@ +# Part of the implementation is borrowed and modified from KTS, +# publicly available at https://github.com/TatsuyaShirakawa/KTS + +import numpy as np + + +def calc_scatters(K): + """Calculate scatter matrix: scatters[i,j] = {scatter of the sequence with + starting frame i and ending frame j} + """ + n = K.shape[0] + K1 = np.cumsum([0] + list(np.diag(K))) + K2 = np.zeros((n + 1, n + 1)) + # TODO: use the fact that K - symmetric + K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1) + + diagK2 = np.diag(K2) + + i = np.arange(n).reshape((-1, 1)) + j = np.arange(n).reshape((1, -1)) + + ij_f32 = ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32)) + diagK2_K2 = ( + diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape( + (-1, 1)) - K2[1:, :-1].T - K2[:-1, 1:]) + scatters = ( + K1[1:].reshape((1, -1)) - K1[:-1].reshape( + (-1, 1)) - diagK2_K2 / ij_f32) + + scatters[j < i] = 0 + + return scatters + + +def cpd_nonlin(K, + ncp, + lmin=1, + lmax=100000, + backtrack=True, + verbose=True, + out_scatters=None): + """Change point detection with dynamic programming + + :param K: Square kernel matrix + :param ncp: Number of change points to detect (ncp >= 0) + :param lmin: Minimal length of a segment + :param lmax: Maximal length of a segment + :param backtrack: If False - only evaluate objective scores (to save memory) + :param verbose: If true, print verbose message + :param out_scatters: Output scatters + :return: Tuple (cps, obj_vals) + - cps - detected array of change points: mean is thought to be constant + on [ cps[i], cps[i+1] ) + - obj_vals - values of the objective function for 0..m changepoints + """ + m = int(ncp) # prevent numpy.int64 + + n, n1 = K.shape + assert n == n1, 'Kernel matrix awaited.' + assert (m + 1) * lmin <= n <= (m + 1) * lmax + assert 1 <= lmin <= lmax + + if verbose: + print('Precomputing scatters...') + J = calc_scatters(K) + + if out_scatters is not None: + out_scatters[0] = J + + if verbose: + print('Inferring best change points...') + # Iden[k, l] - value of the objective for k change-points and l first frames + Iden = 1e101 * np.ones((m + 1, n + 1)) + Iden[0, lmin:lmax] = J[0, lmin - 1:lmax - 1] + + if backtrack: + # p[k, l] --- 'previous change' --- best t[k] when t[k+1] equals l + p = np.zeros((m + 1, n + 1), dtype=int) + else: + p = np.zeros((1, 1), dtype=int) + + for k in range(1, m + 1): + for l_frame in range((k + 1) * lmin, n + 1): + tmin = max(k * lmin, l_frame - lmax) + tmax = l_frame - lmin + 1 + c = J[tmin:tmax, l_frame - 1].reshape(-1) + \ + Iden[k - 1, tmin:tmax].reshape(-1) + Iden[k, l_frame] = np.min(c) + if backtrack: + p[k, l_frame] = np.argmin(c) + tmin + + # Collect change points + cps = np.zeros(m, dtype=int) + + if backtrack: + cur = n + for k in range(m, 0, -1): + cps[k - 1] = p[k, cur] + cur = cps[k - 1] + + scores = Iden[:, n].copy() + scores[scores > 1e99] = np.inf + return cps, scores diff --git a/modelscope/models/cv/video_summarization/pgl_sum.py b/modelscope/models/cv/video_summarization/pgl_sum.py new file mode 100644 index 00000000..2d27501d --- /dev/null +++ b/modelscope/models/cv/video_summarization/pgl_sum.py @@ -0,0 +1,312 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SelfAttention(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + heads=1, + pos_enc=None): + """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V + + :param int input_size: Feature input size of Q, K, V. + :param int output_size: Feature -hidden- size of Q, K, V. + :param int freq: The frequency of the sinusoidal positional encoding. + :param int heads: Number of heads for the attention module. + :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative]. + """ + super(SelfAttention, self).__init__() + + self.permitted_encodings = ['absolute', 'relative'] + if pos_enc is not None: + pos_enc = pos_enc.lower() + assert pos_enc in self.permitted_encodings, f'Supported encodings: {*self.permitted_encodings,}' + + self.input_size = input_size + self.output_size = output_size + self.heads = heads + self.pos_enc = pos_enc + self.freq = freq + self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList( + ), nn.ModuleList() + for _ in range(self.heads): + self.Wk.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.Wq.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.Wv.append( + nn.Linear( + in_features=input_size, + out_features=output_size // heads, + bias=False)) + self.out = nn.Linear( + in_features=output_size, out_features=input_size, bias=False) + + self.softmax = nn.Softmax(dim=-1) + self.drop = nn.Dropout(p=0.5) + + def getAbsolutePosition(self, T): + """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame. + Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762) + + :param int T: Number of frames contained in Q, K and V + :return: Tensor with shape [T, T] + """ + freq = self.freq + d = self.input_size + + pos = torch.tensor([k for k in range(T)], + device=self.out.weight.device) + i = torch.tensor([k for k in range(T // 2)], + device=self.out.weight.device) + + # Reshape tensors each pos_k for each i indices + pos = pos.reshape(pos.shape[0], 1) + pos = pos.repeat_interleave(i.shape[0], dim=1) + i = i.repeat(pos.shape[0], 1) + + AP = torch.zeros(T, T, device=self.out.weight.device) + AP[pos, 2 * i] = torch.sin(pos / freq**((2 * i) / d)) + AP[pos, 2 * i + 1] = torch.cos(pos / freq**((2 * i) / d)) + return AP + + def getRelativePosition(self, T): + """Calculate the sinusoidal positional encoding based on the relative position of each considered frame. + r_pos calculations as here: https://theaisummer.com/positional-embeddings/ + + :param int T: Number of frames contained in Q, K and V + :return: Tensor with shape [T, T] + """ + freq = self.freq + d = 2 * T + min_rpos = -(T - 1) + + i = torch.tensor([k for k in range(T)], device=self.out.weight.device) + j = torch.tensor([k for k in range(T)], device=self.out.weight.device) + + # Reshape tensors each i for each j indices + i = i.reshape(i.shape[0], 1) + i = i.repeat_interleave(i.shape[0], dim=1) + j = j.repeat(i.shape[0], 1) + + # Calculate the relative positions + r_pos = j - i - min_rpos + + RP = torch.zeros(T, T, device=self.out.weight.device) + idx = torch.tensor([k for k in range(T // 2)], + device=self.out.weight.device) + RP[:, 2 * idx] = torch.sin( + r_pos[:, 2 * idx] / freq**((i[:, 2 * idx] + j[:, 2 * idx]) / d)) + RP[:, 2 * idx + 1] = torch.cos( + r_pos[:, 2 * idx + 1] + / freq**((i[:, 2 * idx + 1] + j[:, 2 * idx + 1]) / d)) + return RP + + def forward(self, x): + """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism. + + :param torch.tensor x: Frame features with shape [T, input_size] + :return: A tuple of: + y: Weighted features based on the attention weights, with shape [T, input_size] + att_weights : The attention weights (before dropout), with shape [T, T] + """ + outputs = [] + for head in range(self.heads): + K = self.Wk[head](x) + Q = self.Wq[head](x) + V = self.Wv[head](x) + + # Q *= 0.06 # scale factor VASNet + # Q /= np.sqrt(self.output_size) # scale factor (i.e 1 / sqrt(d_k) ) + energies = torch.matmul(Q, K.transpose(1, 0)) + if self.pos_enc is not None: + if self.pos_enc == 'absolute': + AP = self.getAbsolutePosition(T=energies.shape[0]) + energies = energies + AP + elif self.pos_enc == 'relative': + RP = self.getRelativePosition(T=energies.shape[0]) + energies = energies + RP + + att_weights = self.softmax(energies) + _att_weights = self.drop(att_weights) + y = torch.matmul(_att_weights, V) + + # Save the current head output + outputs.append(y) + y = self.out(torch.cat(outputs, dim=1)) + return y, att_weights.clone( + ) # for now we don't deal with the weights (probably max or avg pooling) + + +class MultiAttention(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + pos_enc=None, + num_segments=None, + heads=1, + fusion=None): + """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters. + + :param int input_size: The expected input feature size. + :param int output_size: The hidden feature size of the attention mechanisms. + :param int freq: The frequency of the sinusoidal positional encoding. + :param None | str pos_enc: The selected positional encoding [absolute, relative]. + :param None | int num_segments: The selected number of segments to split the videos. + :param int heads: The selected number of global heads. + :param None | str fusion: The selected type of feature fusion. + """ + super(MultiAttention, self).__init__() + + # Global Attention, considering differences among all frames + self.attention = SelfAttention( + input_size=input_size, + output_size=output_size, + freq=freq, + pos_enc=pos_enc, + heads=heads) + + self.num_segments = num_segments + if self.num_segments is not None: + assert self.num_segments >= 2, 'num_segments must be None or 2+' + self.local_attention = nn.ModuleList() + for _ in range(self.num_segments): + # Local Attention, considering differences among the same segment with reduce hidden size + self.local_attention.append( + SelfAttention( + input_size=input_size, + output_size=output_size // num_segments, + freq=freq, + pos_enc=pos_enc, + heads=4)) + self.permitted_fusions = ['add', 'mult', 'avg', 'max'] + self.fusion = fusion + if self.fusion is not None: + self.fusion = self.fusion.lower() + assert self.fusion in self.permitted_fusions, f'Fusion method must be: {*self.permitted_fusions,}' + + def forward(self, x): + """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms. + + :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features. + :return: A tuple of: + weighted_value: Tensor with shape [T, input_size] containing the weighted frame features. + attn_weights: Tensor with shape [T, T] containing the attention weights. + """ + weighted_value, attn_weights = self.attention(x) # global attention + + if self.num_segments is not None and self.fusion is not None: + segment_size = math.ceil(x.shape[0] / self.num_segments) + for segment in range(self.num_segments): + left_pos = segment * segment_size + right_pos = (segment + 1) * segment_size + local_x = x[left_pos:right_pos] + weighted_local_value, attn_local_weights = self.local_attention[ + segment](local_x) # local attentions + + # Normalize the features vectors + weighted_value[left_pos:right_pos] = F.normalize( + weighted_value[left_pos:right_pos].clone(), p=2, dim=1) + weighted_local_value = F.normalize( + weighted_local_value, p=2, dim=1) + if self.fusion == 'add': + weighted_value[left_pos:right_pos] += weighted_local_value + elif self.fusion == 'mult': + weighted_value[left_pos:right_pos] *= weighted_local_value + elif self.fusion == 'avg': + weighted_value[left_pos:right_pos] += weighted_local_value + weighted_value[left_pos:right_pos] /= 2 + elif self.fusion == 'max': + weighted_value[left_pos:right_pos] = torch.max( + weighted_value[left_pos:right_pos].clone(), + weighted_local_value) + + return weighted_value, attn_weights + + +class PGL_SUM(nn.Module): + + def __init__(self, + input_size=1024, + output_size=1024, + freq=10000, + pos_enc=None, + num_segments=None, + heads=1, + fusion=None): + """ Class wrapping the PGL-SUM model; its key modules and parameters. + + :param int input_size: The expected input feature size. + :param int output_size: The hidden feature size of the attention mechanisms. + :param int freq: The frequency of the sinusoidal positional encoding. + :param None | str pos_enc: The selected positional encoding [absolute, relative]. + :param None | int num_segments: The selected number of segments to split the videos. + :param int heads: The selected number of global heads. + :param None | str fusion: The selected type of feature fusion. + """ + super(PGL_SUM, self).__init__() + + self.attention = MultiAttention( + input_size=input_size, + output_size=output_size, + freq=freq, + pos_enc=pos_enc, + num_segments=num_segments, + heads=heads, + fusion=fusion) + self.linear_1 = nn.Linear( + in_features=input_size, out_features=input_size) + self.linear_2 = nn.Linear( + in_features=self.linear_1.out_features, out_features=1) + + self.drop = nn.Dropout(p=0.5) + self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6) + self.norm_linear = nn.LayerNorm( + normalized_shape=self.linear_1.out_features, eps=1e-6) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, frame_features): + """ Produce frames importance scores from the frame features, using the PGL-SUM model. + + :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by + using the pool5 layer of GoogleNet. + :return: A tuple of: + y: Tensor with shape [1, T] containing the frames importance scores in [0, 1]. + attn_weights: Tensor with shape [T, T] containing the attention weights. + """ + frame_features = frame_features.reshape(-1, frame_features.shape[-1]) + residual = frame_features + weighted_value, attn_weights = self.attention(frame_features) + y = weighted_value + residual + y = self.drop(y) + y = self.norm_y(y) + + # 2-layer NN (Regressor Network) + y = self.linear_1(y) + y = self.relu(y) + y = self.drop(y) + y = self.norm_linear(y) + + y = self.linear_2(y) + y = self.sigmoid(y) + y = y.view(1, -1) + + return y, attn_weights diff --git a/modelscope/models/cv/video_summarization/summarizer.py b/modelscope/models/cv/video_summarization/summarizer.py new file mode 100644 index 00000000..c9987670 --- /dev/null +++ b/modelscope/models/cv/video_summarization/summarizer.py @@ -0,0 +1,266 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +import os.path as osp +from copy import deepcopy +from typing import Dict, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.video_summarization.kts.cpd_auto import cpd_auto +from modelscope.models.cv.video_summarization.pgl_sum import PGL_SUM +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def get_change_points(video_feat, n_frame): + video_feat = np.array(video_feat, np.float32) + K = np.dot(video_feat, video_feat.T) + change_points, _ = cpd_auto( + K, ncp=min(K.shape[0] - 1, 120), vmax=2.2 / 4.0, lmin=1) + change_points = change_points * 15 + change_points = np.concatenate(([0], change_points, [n_frame - 1])) + + temp_change_points = [] + for idx in range(len(change_points) - 1): + segment = [change_points[idx], change_points[idx + 1] - 1] + if idx == len(change_points) - 2: + segment = [change_points[idx], change_points[idx + 1]] + + temp_change_points.append(segment) + change_points = np.array(list(temp_change_points)) + + temp_n_frame_per_seg = [] + for change_points_idx in range(len(change_points)): + n_frame = change_points[change_points_idx][1] - change_points[ + change_points_idx][0] + temp_n_frame_per_seg.append(n_frame) + n_frame_per_seg = np.array(list(temp_n_frame_per_seg)) + + return change_points, n_frame_per_seg + + +def knap_sack(W, wt, val, n): + """ Maximize the value that a knapsack of capacity W can hold. You can either put the item or discard it, there is + no concept of putting some part of item in the knapsack. + + :param int W: Maximum capacity -in frames- of the knapsack. + :param list[int] wt: The weights (lengths -in frames-) of each video shot. + :param list[float] val: The values (importance scores) of each video shot. + :param int n: The number of the shots. + :return: A list containing the indices of the selected shots. + """ + K = [[0 for _ in range(W + 1)] for _ in range(n + 1)] + + # Build table K[][] in bottom up manner + for i in range(n + 1): + for w in range(W + 1): + if i == 0 or w == 0: + K[i][w] = 0 + elif wt[i - 1] <= w: + K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], + K[i - 1][w]) + else: + K[i][w] = K[i - 1][w] + + selected = [] + w = W + for i in range(n, 0, -1): + if K[i][w] != K[i - 1][w]: + selected.insert(0, i - 1) + w -= wt[i - 1] + + return selected + + +def generate_summary(all_shot_bound, all_scores, all_nframes, all_positions): + """ Generate the automatic machine summary, based on the video shots; the frame importance scores; the number of + frames in the original video and the position of the sub-sampled frames of the original video. + + :param list[np.ndarray] all_shot_bound: The video shots for all the -original- testing videos. + :param list[np.ndarray] all_scores: The calculated frame importance scores for all the sub-sampled testing videos. + :param list[np.ndarray] all_nframes: The number of frames for all the -original- testing videos. + :param list[np.ndarray] all_positions: The position of the sub-sampled frames for all the -original- testing videos. + :return: A list containing the indices of the selected frames for all the -original- testing videos. + """ + all_summaries = [] + for video_index in range(len(all_scores)): + # Get shots' boundaries + shot_bound = all_shot_bound[video_index] # [number_of_shots, 2] + frame_init_scores = all_scores[video_index] + n_frames = all_nframes[video_index] + positions = all_positions[video_index] + + # Compute the importance scores for the initial frame sequence (not the sub-sampled one) + frame_scores = np.zeros(n_frames, dtype=np.float32) + if positions.dtype != int: + positions = positions.astype(np.int32) + if positions[-1] != n_frames: + positions = np.concatenate([positions, [n_frames]]) + for i in range(len(positions) - 1): + pos_left, pos_right = positions[i], positions[i + 1] + if i == len(frame_init_scores): + frame_scores[pos_left:pos_right] = 0 + else: + frame_scores[pos_left:pos_right] = frame_init_scores[i] + + # Compute shot-level importance scores by taking the average importance scores of all frames in the shot + shot_imp_scores = [] + shot_lengths = [] + for shot in shot_bound: + shot_lengths.append(shot[1] - shot[0] + 1) + shot_imp_scores.append( + (frame_scores[shot[0]:shot[1] + 1].mean()).item()) + + # Select the best shots using the knapsack implementation + final_shot = shot_bound[-1] + final_max_length = int((final_shot[1] + 1) * 0.15) + + selected = knap_sack(final_max_length, shot_lengths, shot_imp_scores, + len(shot_lengths)) + + # Select all frames from each selected shot (by setting their value in the summary vector to 1) + summary = np.zeros(final_shot[1] + 1, dtype=np.int8) + for shot in selected: + summary[shot_bound[shot][0]:shot_bound[shot][1] + 1] = 1 + + all_summaries.append(summary) + + return all_summaries + + +def transform_time(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time + + +def summary_format(summary, fps): + frames_list = [] + start_frame = -1 + end_frame = -1 + is_summary_frame = False + for i, idx in enumerate(summary): + if idx: + if is_summary_frame is False: + start_frame = i + is_summary_frame = True + else: + if is_summary_frame: + end_frame = i - 1 + frames_list.append([start_frame, end_frame]) + is_summary_frame = False + + if is_summary_frame and summary[-1] == 1: + end_frame = len(summary) - 1 + frames_list.append([start_frame, end_frame]) + + output = [] + for seg in frames_list: + output.append({ + 'frame': + seg, + 'timestamps': [ + transform_time(seg[0] / float(fps)), + transform_time(seg[1] / float(fps)) + ] + }) + return output + + +@MODELS.register_module( + Tasks.video_summarization, module_name=Models.video_summarization) +class PGLVideoSummarization(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the video summarization model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.loss = nn.MSELoss() + self.model = PGL_SUM( + input_size=1024, + output_size=1024, + num_segments=4, + heads=8, + fusion='add', + pos_enc='absolute') + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device) + + self.model = self.load_pretrained(self.model, model_path) + + if self.training: + self.model.train() + else: + self.model.eval() + + def load_pretrained(self, net, load_path, strict=True, param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _train_forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + gtscore = input['gtscore'] + preds, attn_weights = self.model(frame_features) + return {'loss': self.loss(preds, gtscore)} + + def _inference_forward(self, input: Dict[str, + Tensor]) -> Dict[str, Tensor]: + frame_features = input['frame_features'] + y, attn_weights = self.model(frame_features) + return {'scores': y} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + if self.training: + return self._train_forward(input) + else: + return self._inference_forward(input) diff --git a/modelscope/models/cv/virual_tryon/__init__.py b/modelscope/models/cv/virual_tryon/__init__.py new file mode 100644 index 00000000..10def17a --- /dev/null +++ b/modelscope/models/cv/virual_tryon/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sdafnet import SDAFNet_Tryon + +else: + _import_structure = {'sdafnet': ['SDAFNet_Tryon']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/virual_tryon/sdafnet.py b/modelscope/models/cv/virual_tryon/sdafnet.py new file mode 100644 index 00000000..f98a5e7d --- /dev/null +++ b/modelscope/models/cv/virual_tryon/sdafnet.py @@ -0,0 +1,442 @@ +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import MODELS +from modelscope.utils.constant import ModelFile, Tasks + + +def apply_offset(offset): + sizes = list(offset.size()[2:]) + grid_list = torch.meshgrid( + [torch.arange(size, device=offset.device) for size in sizes]) + grid_list = reversed(grid_list) + # apply offset + grid_list = [ + grid.float().unsqueeze(0) + offset[:, dim, ...] + for dim, grid in enumerate(grid_list) + ] + # normalize + grid_list = [ + grid / ((size - 1.0) / 2.0) - 1.0 + for grid, size in zip(grid_list, reversed(sizes)) + ] + + return torch.stack(grid_list, dim=-1) + + +# backbone +class ResBlock(nn.Module): + + def __init__(self, in_channels): + super(ResBlock, self).__init__() + self.block = nn.Sequential( + nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, in_channels, kernel_size=3, + padding=1, bias=False), nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, in_channels, kernel_size=3, padding=1, + bias=False)) + + def forward(self, x): + return self.block(x) + x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, out_channels): + super(Downsample, self).__init__() + self.block = nn.Sequential( + nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False)) + + def forward(self, x): + return self.block(x) + + +class FeatureEncoder(nn.Module): + + def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]): + # in_channels = 3 for images, and is larger (e.g., 17+1+1) for agnositc representation + super(FeatureEncoder, self).__init__() + self.encoders = [] + for i, out_chns in enumerate(chns): + if i == 0: + encoder = nn.Sequential( + Downsample(in_channels, out_chns), ResBlock(out_chns), + ResBlock(out_chns)) + else: + encoder = nn.Sequential( + Downsample(chns[i - 1], out_chns), ResBlock(out_chns), + ResBlock(out_chns)) + + self.encoders.append(encoder) + + self.encoders = nn.ModuleList(self.encoders) + + def forward(self, x): + encoder_features = [] + for encoder in self.encoders: + x = encoder(x) + encoder_features.append(x) + return encoder_features + + +class RefinePyramid(nn.Module): + + def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256): + super(RefinePyramid, self).__init__() + self.chns = chns + + # adaptive + self.adaptive = [] + for in_chns in list(reversed(chns)): + adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1) + self.adaptive.append(adaptive_layer) + self.adaptive = nn.ModuleList(self.adaptive) + # output conv + self.smooth = [] + for i in range(len(chns)): + smooth_layer = nn.Conv2d( + fpn_dim, fpn_dim, kernel_size=3, padding=1) + self.smooth.append(smooth_layer) + self.smooth = nn.ModuleList(self.smooth) + + def forward(self, x): + conv_ftr_list = x + + feature_list = [] + last_feature = None + for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))): + # adaptive + feature = self.adaptive[i](conv_ftr) + # fuse + if last_feature is not None: + feature = feature + F.interpolate( + last_feature, scale_factor=2, mode='nearest') + # smooth + feature = self.smooth[i](feature) + last_feature = feature + feature_list.append(feature) + + return tuple(reversed(feature_list)) + + +def DAWarp(feat, offsets, att_maps, sample_k, out_ch): + att_maps = torch.repeat_interleave(att_maps, out_ch, 1) + B, C, H, W = feat.size() + multi_feat = torch.repeat_interleave(feat, sample_k, 0) + multi_warp_feat = F.grid_sample( + multi_feat, + offsets.detach().permute(0, 2, 3, 1), + mode='bilinear', + padding_mode='border') + multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps + att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1)) + return att_warp_feat + + +class MFEBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + num_filters=[128, 64, 32]): + super(MFEBlock, self).__init__() + layers = [] + for i in range(len(num_filters)): + if i == 0: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=num_filters[i], + kernel_size=3, + stride=1, + padding=1)) + else: + layers.append( + torch.nn.Conv2d( + in_channels=num_filters[i - 1], + out_channels=num_filters[i], + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2)) + layers.append( + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)) + layers.append( + torch.nn.Conv2d( + in_channels=num_filters[-1], + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2)) + self.layers = torch.nn.Sequential(*layers) + + def forward(self, input): + return self.layers(input) + + +class DAFlowNet(nn.Module): + + def __init__(self, num_pyramid, fpn_dim=256, head_nums=1): + super(DAFlowNet, self).__init__() + self.Self_MFEs = [] + + self.Cross_MFEs = [] + self.Refine_MFEs = [] + self.k = head_nums + self.out_ch = fpn_dim + for i in range(num_pyramid): + # self-MFE for model img 2k:flow 1k:att_map + Self_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, + out_channels=self.k * 3, + kernel_size=7) + # cross-MFE for cloth img + Cross_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, out_channels=self.k * 3) + # refine-MFE for cloth and model imgs + Refine_MFE_layer = MFEBlock( + in_channels=2 * fpn_dim, out_channels=self.k * 6) + self.Self_MFEs.append(Self_MFE_layer) + self.Cross_MFEs.append(Cross_MFE_layer) + self.Refine_MFEs.append(Refine_MFE_layer) + + self.Self_MFEs = nn.ModuleList(self.Self_MFEs) + self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs) + self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs) + + self.lights_decoder = torch.nn.Sequential( + torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=3, + kernel_size=3, + stride=1, + padding=1)) + self.lights_encoder = torch.nn.Sequential( + torch.nn.Conv2d( + 3, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, out_channels=64, kernel_size=1, stride=1)) + + def forward(self, + source_image, + reference_image, + source_feats, + reference_feats, + return_all=False, + warp_feature=True, + use_light_en_de=True): + r""" + Args: + source_image: cloth rgb image for tryon + reference_image: model rgb image for try on + source_feats: cloth FPN features + reference_feats: model and pose features + return_all: bool return all intermediate try-on results in training phase + warp_feature: use DAFlow for both features and images + use_light_en_de: use shallow encoder and decoder to project the images from RGB to high dimensional space + + """ + + # reference branch inputs model img using self-DAFlow + last_multi_self_offsets = None + # source branch inputs cloth img using cross-DAFlow + last_multi_cross_offsets = None + + if return_all: + results_all = [] + + for i in range(len(source_feats)): + + feat_source = source_feats[len(source_feats) - 1 - i] + feat_ref = reference_feats[len(reference_feats) - 1 - i] + B, C, H, W = feat_source.size() + + # Pre-DAWarp for Pyramid feature + if last_multi_cross_offsets is not None and warp_feature: + att_source_feat = DAWarp(feat_source, last_multi_cross_offsets, + cross_att_maps, self.k, self.out_ch) + att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets, + self_att_maps, self.k, self.out_ch) + else: + att_source_feat = feat_source + att_reference_feat = feat_ref + # Cross-MFE + input_feat = torch.cat([att_source_feat, feat_ref], 1) + offsets_att = self.Cross_MFEs[i](input_feat) + cross_att_maps = F.softmax( + offsets_att[:, self.k * 2:, :, :], dim=1) + offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape( + -1, 2, H, W)) + if last_multi_cross_offsets is not None: + offsets = F.grid_sample( + last_multi_cross_offsets, + offsets, + mode='bilinear', + padding_mode='border') + else: + offsets = offsets.permute(0, 3, 1, 2) + last_multi_cross_offsets = offsets + att_source_feat = DAWarp(feat_source, last_multi_cross_offsets, + cross_att_maps, self.k, self.out_ch) + + # Self-MFE + input_feat = torch.cat([att_source_feat, att_reference_feat], 1) + offsets_att = self.Self_MFEs[i](input_feat) + self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1) + offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape( + -1, 2, H, W)) + if last_multi_self_offsets is not None: + offsets = F.grid_sample( + last_multi_self_offsets, + offsets, + mode='bilinear', + padding_mode='border') + else: + offsets = offsets.permute(0, 3, 1, 2) + last_multi_self_offsets = offsets + att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets, + self_att_maps, self.k, self.out_ch) + + # Refine-MFE + input_feat = torch.cat([att_source_feat, att_reference_feat], 1) + offsets_att = self.Refine_MFEs[i](input_feat) + att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1) + cross_offsets = apply_offset( + offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W)) + self_offsets = apply_offset( + offsets_att[:, + self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W)) + last_multi_cross_offsets = F.grid_sample( + last_multi_cross_offsets, + cross_offsets, + mode='bilinear', + padding_mode='border') + last_multi_self_offsets = F.grid_sample( + last_multi_self_offsets, + self_offsets, + mode='bilinear', + padding_mode='border') + + # Upsampling + last_multi_cross_offsets = F.interpolate( + last_multi_cross_offsets, scale_factor=2, mode='bilinear') + last_multi_self_offsets = F.interpolate( + last_multi_self_offsets, scale_factor=2, mode='bilinear') + self_att_maps = F.interpolate( + att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear') + cross_att_maps = F.interpolate( + att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear') + + # Post-DAWarp for source and reference images + if return_all: + cur_source_image = F.interpolate( + source_image, (H * 2, W * 2), mode='bilinear') + cur_reference_image = F.interpolate( + reference_image, (H * 2, W * 2), mode='bilinear') + if use_light_en_de: + cur_source_image = self.lights_encoder(cur_source_image) + cur_reference_image = self.lights_encoder( + cur_reference_image) + # the feat dim in light encoder is 64 + warp_att_source_image = DAWarp(cur_source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 64) + warp_att_reference_image = DAWarp(cur_reference_image, + last_multi_self_offsets, + self_att_maps, self.k, + 64) + result_tryon = self.lights_decoder( + warp_att_source_image + warp_att_reference_image) + else: + warp_att_source_image = DAWarp(cur_source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 3) + warp_att_reference_image = DAWarp(cur_reference_image, + last_multi_self_offsets, + self_att_maps, self.k, 3) + result_tryon = warp_att_source_image + warp_att_reference_image + results_all.append(result_tryon) + + last_multi_self_offsets = F.interpolate( + last_multi_self_offsets, + reference_image.size()[2:], + mode='bilinear') + last_multi_cross_offsets = F.interpolate( + last_multi_cross_offsets, source_image.size()[2:], mode='bilinear') + self_att_maps = F.interpolate( + self_att_maps, reference_image.size()[2:], mode='bilinear') + cross_att_maps = F.interpolate( + cross_att_maps, source_image.size()[2:], mode='bilinear') + if use_light_en_de: + source_image = self.lights_encoder(source_image) + reference_image = self.lights_encoder(reference_image) + warp_att_source_image = DAWarp(source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 64) + warp_att_reference_image = DAWarp(reference_image, + last_multi_self_offsets, + self_att_maps, self.k, 64) + result_tryon = self.lights_decoder(warp_att_source_image + + warp_att_reference_image) + else: + warp_att_source_image = DAWarp(source_image, + last_multi_cross_offsets, + cross_att_maps, self.k, 3) + warp_att_reference_image = DAWarp(reference_image, + last_multi_self_offsets, + self_att_maps, self.k, 3) + result_tryon = warp_att_source_image + warp_att_reference_image + + if return_all: + return result_tryon, return_all + return result_tryon + + +class SDAFNet_Tryon(nn.Module): + + def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6): + super(SDAFNet_Tryon, self).__init__() + num_filters = [64, 128, 256, 256, 256] + self.source_features = FeatureEncoder(source_in_channel, num_filters) + self.reference_features = FeatureEncoder(ref_in_channel, num_filters) + self.source_FPN = RefinePyramid(num_filters) + self.reference_FPN = RefinePyramid(num_filters) + self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums) + + def forward(self, + ref_input, + source_image, + ref_image, + use_light_en_de=True, + return_all=False, + warp_feature=True): + reference_feats = self.reference_FPN( + self.reference_features(ref_input)) + source_feats = self.source_FPN(self.source_features(source_image)) + result = self.dafnet( + source_image, + ref_image, + source_feats, + reference_feats, + use_light_en_de=use_light_en_de, + return_all=return_all, + warp_feature=warp_feature) + return result diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py new file mode 100644 index 00000000..0053da43 --- /dev/null +++ b/modelscope/models/multi_modal/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .clip import CLIPForMultiModalEmbedding + from .gemm import GEMMForMultiModalEmbedding + from .team import TEAMForMultiModalSimilarity + from .diffusion import DiffusionForTextToImageSynthesis + from .mmr import VideoCLIPForMultiModalEmbedding + from .mplug_for_all_tasks import MPlugForAllTasks + from .ofa_for_all_tasks import OfaForAllTasks + from .ofa_for_text_to_image_synthesis_model import \ + OfaForTextToImageSynthesis + from .multi_stage_diffusion import \ + MultiStageDiffusionForTextToImageSynthesis + +else: + _import_structure = { + 'clip': ['CLIPForMultiModalEmbedding'], + 'diffusion': ['DiffusionForTextToImageSynthesis'], + 'gemm': ['GEMMForMultiModalEmbedding'], + 'team': ['TEAMForMultiModalSimilarity'], + 'mmr': ['VideoCLIPForMultiModalEmbedding'], + 'mplug_for_all_tasks': ['MPlugForAllTasks'], + 'ofa_for_all_tasks': ['OfaForAllTasks'], + 'ofa_for_text_to_image_synthesis_model': + ['OfaForTextToImageSynthesis'], + 'multi_stage_diffusion': + ['MultiStageDiffusionForTextToImageSynthesis'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/clip/__init__.py b/modelscope/models/multi_modal/clip/__init__.py new file mode 100644 index 00000000..e2e925ce --- /dev/null +++ b/modelscope/models/multi_modal/clip/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .model import CLIPForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/clip/bert_tokenizer.py b/modelscope/models/multi_modal/clip/bert_tokenizer.py new file mode 100644 index 00000000..8d356f42 --- /dev/null +++ b/modelscope/models/multi_modal/clip/bert_tokenizer.py @@ -0,0 +1,422 @@ +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function +import collections +import os +import re +import unicodedata + +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match('^.*?([A-Za-z0-9_-]+)/bert_model.ckpt', init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + 'uncased_L-24_H-1024_A-16', 'uncased_L-12_H-768_A-12', + 'multilingual_L-12_H-768_A-12', 'chinese_L-12_H-768_A-12' + ] + + cased_models = [ + 'cased_L-12_H-768_A-12', 'cased_L-24_H-1024_A-16', + 'multi_cased_L-12_H-768_A-12' + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = 'False' + case_name = 'lowercased' + opposite_flag = 'True' + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = 'True' + case_name = 'cased' + opposite_flag = 'False' + + if is_bad_config: + raise ValueError( + 'You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. ' + 'However, `%s` seems to be a %s model, so you ' + 'should pass in `--do_lower_case=%s` so that the fine-tuning matches ' + 'how the model was pre-training. If this error is wrong, please ' + 'just comment out this check.' % + (actual_flag, init_checkpoint, model_name, case_name, + opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode('utf-8', 'ignore') + elif isinstance(text, unicode): + return text + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode('utf-8') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r') as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abreviated forms. + """ + out_string = ( + out_string.replace(' .', '.').replace(' ?', '?').replace( + ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace( + " n't", "n't").replace(" 'm", "'m").replace( + " 's", "'s").replace(" 've", + "'ve").replace(" 're", "'re")) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F)): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat in ('Cc', 'Cf'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False diff --git a/modelscope/models/multi_modal/clip/configuration_bert.py b/modelscope/models/multi_modal/clip/configuration_bert.py new file mode 100644 index 00000000..b75f5db8 --- /dev/null +++ b/modelscope/models/multi_modal/clip/configuration_bert.py @@ -0,0 +1,82 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BERT model configuration """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging + +logger = logging.getLogger(__name__) + + +class BertConfig(object): + r""" + :class:`~transformers.BertConfig` is the configuration class to store the configuration of a + `BertModel`. + + + Arguments: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + + def __init__(self, + vocab_size_or_config_json_file=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + output_attentions=False, + output_hidden_states=False): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states diff --git a/modelscope/models/multi_modal/clip/model.py b/modelscope/models/multi_modal/clip/model.py new file mode 100644 index 00000000..9b82e4a1 --- /dev/null +++ b/modelscope/models/multi_modal/clip/model.py @@ -0,0 +1,606 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import OrderedDict +from typing import Any, Dict, Tuple, Union + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer +from modelscope.models.multi_modal.clip.configuration_bert import BertConfig +from modelscope.models.multi_modal.clip.modeling_bert import BertModel +from modelscope.utils.constant import ModeKeys, ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['CLIPForMultiModalEmbedding'] + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ # noqa + self.class_embedding.to(x.dtype) + torch.zeros( # noqa + x.shape[0], + 1, + x.shape[-1], + dtype=x.dtype, + device=x.device), + x # noqa + ], + dim=1) # noqa shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + vocab_size: int, + text_attention_probs_dropout_prob: float, + text_hidden_act: str, + text_hidden_dropout_prob: float, + text_hidden_size: int, + text_initializer_range: float, + text_intermediate_size: int, + text_max_position_embeddings: int, + text_num_attention_heads: int, + text_num_hidden_layers: int, + text_type_vocab_size: int, + tokenizer: FullTokenizer, + # vision_head_width, added this param for ViT-H + vision_head_width: int = 64, + ): + super().__init__() + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // vision_head_width + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // vision_head_width + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.bert_config = BertConfig( + vocab_size_or_config_json_file=vocab_size, + hidden_size=text_hidden_size, + num_hidden_layers=text_num_hidden_layers, + num_attention_heads=text_num_attention_heads, + intermediate_size=text_intermediate_size, + hidden_act=text_hidden_act, + hidden_dropout_prob=text_hidden_dropout_prob, + attention_probs_dropout_prob=text_attention_probs_dropout_prob, + max_position_embeddings=text_max_position_embeddings, + type_vocab_size=text_type_vocab_size, + initializer_range=text_initializer_range, + layer_norm_eps=1e-12, + ) + self.bert = BertModel(self.bert_config) + + self.text_projection = nn.Parameter( + torch.empty(text_hidden_size, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.tokenizer = tokenizer + + self.initialize_parameters() + + def initialize_parameters(self): + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.bert_config.hidden_size**-0.5) + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + pad_index = self.tokenizer.vocab['[PAD]'] + attn_mask = text.ne(pad_index).type(self.dtype) + x = self.bert( + text, attention_mask=attn_mask)[0].type( + self.dtype) # [batch_size, seq_length, hidden_size] + return x[:, 0, :] @ self.text_projection + + def forward(self, image, text): + assert image is not None or text is not None, 'text and image cannot both be None!' + + if image is None: + return self.encode_text(text) + elif text is None: + return self.encode_image(image) + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features, self.logit_scale.exp() + + def get_similarity(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_models_to_fp32(model): + for p in model.parameters(): + p.data = p.data.float() + if p.grad: + p.grad.data = p.grad.data.float() + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(module): + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)): + module.weight.data = module.weight.data.half() + if module.bias is not None: + module.bias.data = module.bias.data.half() + + if isinstance(module, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(module, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + if isinstance(module, BertModel): + module.to(torch.half) + + for name in ['text_projection', 'proj']: + if hasattr(module, name): + attr = getattr(module, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip) +class CLIPForMultiModalEmbedding(TorchModel): + + def __init__(self, model_dir, device_id=-1): + super().__init__(model_dir=model_dir, device_id=device_id) + + # Initialize the model. + vision_model_config_file = '{}/vision_model_config.json'.format( + model_dir) + logger.info( + f'Loading vision model config from {vision_model_config_file}') + assert os.path.exists(vision_model_config_file) + + text_model_config_file = '{}/text_model_config.json'.format(model_dir) + logger.info(f'Loading text model config from {text_model_config_file}') + assert os.path.exists(text_model_config_file) + + with open(vision_model_config_file, + 'r') as fv, open(text_model_config_file, 'r') as ft: + self.model_info = json.load(fv) + for k, v in json.load(ft).items(): + self.model_info[k] = v + + vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' + self.tokenizer = FullTokenizer(vocab_file=vocab_file) + + # initialize the model + self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer) + convert_weights(self.clip_model) + + # restore the pretrained weight + checkpoint = torch.load( + f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu') + sd = checkpoint[ + 'state_dict'] if 'state_dict' in checkpoint else checkpoint + if next(iter(sd.items()))[0].startswith('module'): + sd = {k[len('module.'):]: v for k, v in sd.items()} + # support the finetuned model + if next(iter(sd.items()))[0].startswith('clip_model'): + sd = {k[len('clip_model.'):]: v for k, v in sd.items()} + self.clip_model.load_state_dict(sd) + self.clip_model.eval() + + # place the model + self.device = 'cuda:{}'.format(int(os.environ.get( + 'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu' + if torch.cuda.is_available(): + self.clip_model.to(self.device) + logger.info('Use GPU {} for finetuning & inference'.format( + int(os.environ.get('LOCAL_RANK', 0)))) + else: + self.clip_model.float() + logger.info('Use CPU for finetuning & inference') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + from modelscope.outputs import OutputKeys + output = { + OutputKeys.IMG_EMBEDDING: None, + OutputKeys.TEXT_EMBEDDING: None + } + mode = input.get('mode', ModeKeys.INFERENCE) + + # encode the image + if 'img' in input and isinstance(input['img'], torch.Tensor): + image_tensor = input['img'].to(self.device) + if image_tensor.dim() == 5 and image_tensor.shape[1] == 1: + image_tensor = image_tensor.squeeze(1) + + with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): + image_features = self.clip_model.encode_image(image_tensor) + image_features /= image_features.norm( + dim=-1, keepdim=True) # l2-normalize + + output[OutputKeys.IMG_EMBEDDING] = image_features + + if 'text' in input and isinstance(input['text'], torch.Tensor): + text_tensor = input['text'].to(self.device) + if text_tensor.dim() == 3 and text_tensor.shape[1] == 1: + text_tensor = text_tensor.squeeze(1) + + with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN): + text_features = self.clip_model.encode_text(text_tensor) + text_features /= text_features.norm( + dim=-1, keepdim=True) # l2-normalize + output[OutputKeys.TEXT_EMBEDDING] = text_features + + if mode == ModeKeys.TRAIN: + output['logit_scale'] = (self.clip_model.logit_scale + * 1.0).exp().mean() + + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + @property + def temperature(self): + return 1.0 / self.clip_model.logit_scale.exp() diff --git a/modelscope/models/multi_modal/clip/modeling_bert.py b/modelscope/models/multi_modal/clip/modeling_bert.py new file mode 100644 index 00000000..b5f104ce --- /dev/null +++ b/modelscope/models/multi_modal/clip/modeling_bert.py @@ -0,0 +1,507 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging +import math +import os +import sys +from io import open + +import json +import torch +from torch import nn + +from .configuration_bert import BertConfig + +logger = logging.getLogger(__name__) + + +def gelu(x): + """ Original Implementation of the gelu activation function in Google Bert repo when initially created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu_new(x): + """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). + Also see https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = { + 'gelu': gelu, + 'relu': torch.nn.functional.relu, + 'swish': swish, + 'gelu_new': gelu_new +} + +BertLayerNorm = torch.nn.LayerNorm + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if self.output_attentions else ( + context_layer, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def forward(self, input_tensor, attention_mask=None, head_mask=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, + str) or (sys.version_info[0] == 2 + and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output, ) + attention_outputs[ + 1:] # add attentions if we output them + return outputs + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_outputs = layer_module(hidden_states, attention_mask, + head_mask[i]) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + outputs = (hidden_states, ) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if self.output_attentions: + outputs = outputs + (all_attentions, ) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, + str) or (sys.version_info[0] == 2 + and isinstance(config.hidden_act, unicode)): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + config_class = BertConfig + base_model_prefix = 'bert' + + def __init__(self, config): + super(BertPreTrainedModel, self).__init__() + self.config = config + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) + of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, + used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained('bert-base-uncased') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + + def __init__(self, config): + super(BertModel, self).__init__(config) + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + self.apply(self._init_weights) + + def forward(self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze( + -1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, + -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters( + )).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) + encoder_outputs = self.encoder( + embedding_output, extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = ( + sequence_output, + pooled_output, + ) + encoder_outputs[ + 1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) diff --git a/modelscope/models/multi_modal/diffusion/__init__.py b/modelscope/models/multi_modal/diffusion/__init__.py new file mode 100644 index 00000000..e7e374b6 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .model import DiffusionForTextToImageSynthesis diff --git a/modelscope/models/multi_modal/diffusion/diffusion.py b/modelscope/models/multi_modal/diffusion/diffusion.py new file mode 100644 index 00000000..bfe7baf7 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/diffusion.py @@ -0,0 +1,598 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + a = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + b = ((mu1 - mu2)**2) * torch.exp(-logvar2) + return 0.5 * (a + b) + + +def standard_normal_cdf(x): + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs + + +def _i(tensor, t, x): + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def cosine_fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + betas.append(min(1.0 - cosine_fn(t2) / cosine_fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + shape = (-1, ) + ((1, ) * (xt.ndim - 1)) + mask = t.ne(0).float().view(*shape) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + a = u_out[:, :3] + b = guide_scale * (y_out[:, :3] - u_out[:, :3]) + c = y_out[:, 3:] + out = torch.cat([a + b, c], dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + a = (1 - alphas_prev) / (1 - alphas) + b = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(a * b) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/diffusion/model.py b/modelscope/models/multi_modal/diffusion/model.py new file mode 100644 index 00000000..4229391f --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/model.py @@ -0,0 +1,256 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.diffusion.diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.diffusion.structbert import (BertConfig, + BertModel) +from modelscope.models.multi_modal.diffusion.tokenizer import FullTokenizer +from modelscope.models.multi_modal.diffusion.unet_generator import \ + DiffusionGenerator +from modelscope.models.multi_modal.diffusion.unet_upsampler_256 import \ + SuperResUNet256 +from modelscope.models.multi_modal.diffusion.unet_upsampler_1024 import \ + SuperResUNet1024 +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['DiffusionForTextToImageSynthesis'] + + +def make_diffusion(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None, + var_type='fixed_small'): + betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta) + diffusion = GaussianDiffusion(betas, var_type=var_type) + return diffusion + + +class Tokenizer(object): + + def __init__(self, vocab_file, seq_len=64): + self.vocab_file = vocab_file + self.seq_len = seq_len + self.tokenizer = FullTokenizer( + vocab_file=vocab_file, do_lower_case=True) + + def __call__(self, text): + # tokenization + tokens = self.tokenizer.tokenize(text) + tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]'] + input_ids = self.tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) + segment_ids = [0] * len(input_ids) + + # padding + input_ids += [0] * (self.seq_len - len(input_ids)) + input_mask += [0] * (self.seq_len - len(input_mask)) + segment_ids += [0] * (self.seq_len - len(segment_ids)) + assert len(input_ids) == len(input_mask) == len( + segment_ids) == self.seq_len + + # convert to tensors + input_ids = torch.LongTensor(input_ids) + input_mask = torch.LongTensor(input_mask) + segment_ids = torch.LongTensor(segment_ids) + return input_ids, segment_ids, input_mask + + +class DiffusionModel(nn.Module): + + def __init__(self, model_dir): + super(DiffusionModel, self).__init__() + # including text and generator config + model_config = json.load( + open('{}/model_config.json'.format(model_dir))) + + # text encoder + text_config = model_config['text_config'] + self.text_encoder = BertModel(BertConfig.from_dict(text_config)) + + # generator (64x64) + generator_config = model_config['generator_config'] + self.unet_generator = DiffusionGenerator(**generator_config) + + # upsampler (256x256) + upsampler_256_config = model_config['upsampler_256_config'] + self.unet_upsampler_256 = SuperResUNet256(**upsampler_256_config) + + # upsampler (1024x1024) + upsampler_1024_config = model_config['upsampler_1024_config'] + self.unet_upsampler_1024 = SuperResUNet1024(**upsampler_1024_config) + + def forward(self, noise, timesteps, input_ids, token_type_ids, + attention_mask): + context, y = self.text_encoder( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + context = context[-1] + x = self.unet_generator(noise, timesteps, y, context, attention_mask) + x = self.unet_upsampler_256(noise, timesteps, x, + torch.zeros_like(timesteps), y, context, + attention_mask) + x = self.unet_upsampler_1024(x, t, x) + return x + + +@MODELS.register_module( + Tasks.text_to_image_synthesis, module_name=Models.diffusion) +class DiffusionForTextToImageSynthesis(Model): + + def __init__(self, model_dir, device_id=-1): + super().__init__(model_dir=model_dir, device_id=device_id) + diffusion_model = DiffusionModel(model_dir=model_dir) + pretrained_params = torch.load( + osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu') + diffusion_model.load_state_dict(pretrained_params) + diffusion_model.eval() + + self.device_id = device_id + if self.device_id >= 0: + self.device = torch.device(f'cuda:{self.device_id}') + diffusion_model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device = torch.device('cpu') + logger.info('Use CPU for inference') + + # modules + self.text_encoder = diffusion_model.text_encoder + self.unet_generator = diffusion_model.unet_generator + self.unet_upsampler_256 = diffusion_model.unet_upsampler_256 + self.unet_upsampler_1024 = diffusion_model.unet_upsampler_1024 + + # text tokenizer + vocab_path = f'{model_dir}/{ModelFile.VOCAB_FILE}' + self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64) + + # diffusion process + diffusion_params = json.load( + open('{}/diffusion_config.json'.format(model_dir))) + self.diffusion_generator = make_diffusion( + **diffusion_params['generator_config']) + self.diffusion_upsampler_256 = make_diffusion( + **diffusion_params['upsampler_256_config']) + self.diffusion_upsampler_1024 = make_diffusion( + **diffusion_params['upsampler_1024_config']) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if not all([key in input for key in ('text', 'noise', 'timesteps')]): + raise ValueError( + f'input should contains "text", "noise", and "timesteps", but got {input.keys()}' + ) + input_ids, token_type_ids, attention_mask = self.tokenizer( + input['text']) + input_ids = input_ids.to(self.device).unsqueeze(0) + token_type_ids = token_type_ids.to(self.device).unsqueeze(0) + attention_mask = attention_mask.to(self.device).unsqueeze(0) + context, y = self.text_encoder( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + context = context[-1] + x = self.unet_generator(noise, timesteps, y, context, attention_mask) + x = self.unet_upsampler_256(noise, timesteps, x, + torch.zeros_like(timesteps), y, context, + attention_mask) + x = self.unet_upsampler_1024(x, t, x) + img = x.clamp(-1, 1).add(1).mul(127.5) + img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8) + return img + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + @torch.no_grad() + def generate(self, input: Dict[str, Any]) -> Dict[str, Any]: + if 'text' not in input: + raise ValueError( + f'input should contain "text", but got {input.keys()}') + + # encode text + input_ids, token_type_ids, attention_mask = self.tokenizer( + input['text']) + input_ids = input_ids.to(self.device).unsqueeze(0) + token_type_ids = token_type_ids.to(self.device).unsqueeze(0) + attention_mask = attention_mask.to(self.device).unsqueeze(0) + context, y = self.text_encoder( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + context = context[-1] + + # generation + img = self.diffusion_generator.ddim_sample_loop( + noise=torch.randn(1, 3, 64, 64).to(self.device), + model=self.unet_generator, + model_kwargs=[{ + 'y': y, + 'context': context, + 'mask': attention_mask + }, { + 'y': torch.zeros_like(y), + 'context': torch.zeros_like(context), + 'mask': attention_mask + }], + percentile=input.get('generator_percentile', 0.995), + guide_scale=input.get('generator_guide_scale', 5.0), + ddim_timesteps=input.get('generator_ddim_timesteps', 250), + eta=input.get('generator_ddim_eta', 0.0)) + + # upsampling (64->256) + if not input.get('debug', False): + img = F.interpolate( + img, scale_factor=4.0, mode='bilinear', align_corners=False) + img = self.diffusion_upsampler_256.ddim_sample_loop( + noise=torch.randn_like(img), + model=self.unet_upsampler_256, + model_kwargs=[{ + 'lx': img, + 'lt': torch.zeros(1).to(self.device), + 'y': y, + 'context': context, + 'mask': attention_mask + }, { + 'lx': img, + 'lt': torch.zeros(1).to(self.device), + 'y': torch.zeros_like(y), + 'context': torch.zeros_like(context), + 'mask': torch.zeros_like(attention_mask) + }], + percentile=input.get('upsampler_256_percentile', 0.995), + guide_scale=input.get('upsampler_256_guide_scale', 5.0), + ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50), + eta=input.get('upsampler_256_ddim_eta', 0.0)) + + # upsampling (256->1024) + if not input.get('debug', False): + img = F.interpolate( + img, scale_factor=4.0, mode='bilinear', align_corners=False) + img = self.diffusion_upsampler_1024.ddim_sample_loop( + noise=torch.randn_like(img), + model=self.unet_upsampler_1024, + model_kwargs={'concat': img}, + percentile=input.get('upsampler_1024_percentile', 0.995), + ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20), + eta=input.get('upsampler_1024_ddim_eta', 0.0)) + + # output + img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute( + 1, 2, 0).cpu().numpy().astype(np.uint8) + return img diff --git a/modelscope/models/multi_modal/diffusion/structbert.py b/modelscope/models/multi_modal/diffusion/structbert.py new file mode 100644 index 00000000..d5d678ed --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/structbert.py @@ -0,0 +1,936 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team and Alibaba inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function +import copy +import math + +import json +import numpy as np +import six +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + + def __init__(self, + vocab_size, + hidden_size=768, + emb_size=-1, + num_hidden_layers=12, + transformer_type='original', + transition_function='linear', + weighted_transformer=0, + num_rolled_layers=3, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + attention_type='self', + rezero=False, + pre_ln=False, + squeeze_excitation=False, + transfer_matrix=False, + dim_dropout=False, + roberta_style=False, + set_mask_zero=False, + init_scale=False, + safer_fp16=False, + grad_checkpoint=False): + """Constructs BertConfig. + + Args: + vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.emb_size = emb_size + self.num_hidden_layers = num_hidden_layers + self.transformer_type = transformer_type + self.transition_function = transition_function + self.weighted_transformer = weighted_transformer + self.num_rolled_layers = num_rolled_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.attention_type = attention_type + self.rezero = rezero + self.pre_ln = pre_ln + self.squeeze_excitation = squeeze_excitation + self.transfer_matrix = transfer_matrix + self.dim_dropout = dim_dropout + self.set_mask_zero = set_mask_zero + self.roberta_style = roberta_style + self.init_scale = init_scale + self.safer_fp16 = safer_fp16 + self.grad_checkpoint = grad_checkpoint + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size=None) + for (key, value) in six.iteritems(json_object): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + +class BERTLayerNorm(nn.Module): + + def __init__(self, config, variance_epsilon=1e-12, special_size=None): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BERTLayerNorm, self).__init__() + self.config = config + hidden_size = special_size if special_size is not None else config.hidden_size + self.gamma = nn.Parameter(torch.ones(hidden_size)) + self.beta = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = variance_epsilon if not config.roberta_style else 1e-5 + + def forward(self, x): + previous_type = x.type() + if self.config.safer_fp16: + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + if self.config.safer_fp16: + return (self.gamma * x + self.beta).type(previous_type) + else: + return self.gamma * x + self.beta + + +class BERTEmbeddings(nn.Module): + + def __init__(self, config): + super(BERTEmbeddings, self).__init__() + """Construct the embedding module from word, position and token_type embeddings. + """ + hidden_size = config.hidden_size if config.emb_size < 0 else config.emb_size + self.word_embeddings = nn.Embedding( + config.vocab_size, + hidden_size, + padding_idx=1 if config.roberta_style else None) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + hidden_size, + padding_idx=1 if config.roberta_style else None) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + hidden_size) + self.config = config + self.proj = None if config.emb_size < 0 else nn.Linear( + config.emb_size, config.hidden_size) + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BERTLayerNorm(config, special_size=hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, adv_embedding=None): + seq_length = input_ids.size(1) + if not self.config.roberta_style: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + else: + mask = input_ids.ne(1).int() + position_ids = (torch.cumsum(mask, dim=1).type_as(mask) + * mask).long() + 1 + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings( + input_ids) if adv_embedding is None else adv_embedding + if self.config.set_mask_zero: + words_embeddings[input_ids == 103] = 0. + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + if not self.config.roberta_style: + embeddings = words_embeddings + position_embeddings + token_type_embeddings + else: + embeddings = words_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + if self.proj is not None: + embeddings = self.proj(embeddings) + embeddings = self.dropout(embeddings) + else: + return embeddings, words_embeddings + + +class BERTFactorizedAttention(nn.Module): + + def __init__(self, config): + super(BERTFactorizedAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, *size): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(size) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer, 0, 2, 3, 1) + key_layer = self.transpose_for_scores(mixed_key_layer, 0, 2, 1, 3) + value_layer = self.transpose_for_scores(mixed_value_layer, 0, 2, 1, 3) + + s_attention_scores = query_layer + attention_mask + s_attention_probs = nn.Softmax(dim=-1)(s_attention_scores) + s_attention_probs = self.dropout(s_attention_probs) + + c_attention_probs = nn.Softmax(dim=-1)(key_layer) + s_context_layer = torch.matmul(s_attention_probs, value_layer) + context_layer = torch.matmul(c_attention_probs, s_context_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +def dim_dropout(x, p=0, dim=-1, training=False): + if not training or p == 0: + return x + a = (1 - p) + b = (x.data.new(x.size()).zero_() + 1) + dropout_mask = torch.bernoulli(a * b) + return dropout_mask * (dropout_mask.size(dim) / torch.sum( + dropout_mask, dim=dim, keepdim=True)) * x + + +class BERTSelfAttention(nn.Module): + + def __init__(self, config): + super(BERTSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.config = config + if config.pre_ln: + self.LayerNorm = BERTLayerNorm(config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None): + if self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + if head_mask is not None and not self.training: + for i, mask in enumerate(head_mask): + if head_mask[i] == 1: + attention_scores[:, i, :, :] = 0. + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if not self.config.dim_dropout: + attention_probs = self.dropout(attention_probs) + else: + attention_probs = dim_dropout( + attention_probs, + p=self.config.attention_probs_dropout_prob, + dim=-1, + training=self.training) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BERTSelfOutput(nn.Module): + + def __init__(self, config): + super(BERTSelfOutput, self).__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if not config.pre_ln and not config.rezero: + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.rezero: + self.res_factor = nn.Parameter( + torch.Tensor(1).fill_(0.99).to( + dtype=next(self.parameters()).dtype)) + self.factor = nn.Parameter( + torch.ones(1).to(dtype=next(self.parameters()).dtype)) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if not self.config.rezero and not self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + elif self.config.rezero: + hidden_states = hidden_states + self.factor * input_tensor + else: + pass + return hidden_states + + +class BERTAttention(nn.Module): + + def __init__(self, config): + super(BERTAttention, self).__init__() + if config.attention_type.lower() == 'self': + self.self = BERTSelfAttention(config) + elif config.attention_type.lower() == 'factorized': + self.self = BERTFactorizedAttention(config) + else: + raise ValueError( + 'Attention type must in [self, factorized], but got {}'.format( + config.attention_type)) + self.output = BERTSelfOutput(config) + + def forward(self, input_tensor, attention_mask, head_mask=None): + self_output = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class DepthwiseSeparableConv1d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=False): + super(DepthwiseSeparableConv1d, self).__init__() + padding = (kernel_size - 1) // 2 + self.depthwise = nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + groups=in_channels, + bias=bias) + self.pointwise = nn.Conv1d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + return x + + +class BERTIntermediate(nn.Module): + + def __init__(self, config): + super(BERTIntermediate, self).__init__() + self.config = config + if self.config.pre_ln: + self.LayerNorm = BERTLayerNorm(config) + self.intermediate_act_fn = gelu + if config.transition_function.lower() == 'linear': + self.dense = nn.Linear(config.hidden_size, + config.intermediate_size) + elif config.transition_function.lower() == 'cnn': + self.cnn = DepthwiseSeparableConv1d( + config.hidden_size, 4 * config.hidden_size, kernel_size=7) + elif config.config.hidden_size.lower() == 'rnn': + raise NotImplementedError( + 'rnn transition function is not implemented yet') + else: + raise ValueError('Only support linear/cnn/rnn') + + def forward(self, hidden_states): + if self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states) + if self.config.transition_function.lower() == 'linear': + hidden_states = self.dense(hidden_states) + elif self.config.transition_function.lower() == 'cnn': + hidden_states = self.cnn(hidden_states.transpose(-1, + -2)).transpose( + -1, -2) + else: + pass + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SqueezeExcitationBlock(nn.Module): + + def __init__(self, config): + super(SqueezeExcitationBlock, self).__init__() + self.down_sampling = nn.Linear(config.hidden_size, + config.hidden_size // 4) + self.up_sampling = nn.Linear(config.hidden_size // 4, + config.hidden_size) + + def forward(self, hidden_states): + squeeze = torch.mean(hidden_states, 1, keepdim=True) + excitation = torch.sigmoid( + self.up_sampling(gelu(self.down_sampling(squeeze)))) + return hidden_states * excitation + + +class BERTOutput(nn.Module): + + def __init__(self, config): + super(BERTOutput, self).__init__() + self.config = config + if config.transition_function.lower() == 'linear': + self.dense = nn.Linear(config.intermediate_size, + config.hidden_size) + elif config.transition_function.lower() == 'cnn': + self.cnn = DepthwiseSeparableConv1d( + 4 * config.hidden_size, config.hidden_size, kernel_size=7) + elif config.config.hidden_size.lower() == 'rnn': + raise NotImplementedError( + 'rnn transition function is not implemented yet') + else: + raise ValueError('Only support linear/cnn/rnn') + if not config.pre_ln and not config.rezero: + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.squeeze_excitation: + self.SEblock = SqueezeExcitationBlock(config) + if config.rezero: + self.res_factor = nn.Parameter( + torch.Tensor(1).fill_(0.99).to( + dtype=next(self.parameters()).dtype)) + self.factor = nn.Parameter( + torch.ones(1).to(dtype=next(self.parameters()).dtype)) + + def forward(self, hidden_states, input_tensor): + if self.config.transition_function.lower() == 'linear': + hidden_states = self.dense(hidden_states) + elif self.config.transition_function.lower() == 'cnn': + hidden_states = self.cnn(hidden_states.transpose(-1, + -2)).transpose( + -1, -2) + else: + pass + hidden_states = self.dropout(hidden_states) + if self.config.squeeze_excitation: + hidden_states = self.SEblock(hidden_states) + if not self.config.rezero and not self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + elif self.config.rezero: + hidden_states = hidden_states + self.factor * input_tensor + else: + pass + return hidden_states + + +class BERTLayer(nn.Module): + + def __init__(self, config): + super(BERTLayer, self).__init__() + self.attention = BERTAttention(config) + self.intermediate = BERTIntermediate(config) + self.output = BERTOutput(config) + + def forward(self, hidden_states, attention_mask, head_mask=None): + attention_output = self.attention(hidden_states, attention_mask, + head_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return attention_output, layer_output + + +class BERTWeightedLayer(nn.Module): + + def __init__(self, config): + super(BERTWeightedLayer, self).__init__() + self.config = config + self.self = BERTSelfAttention(config) + self.attention_head_size = self.self.attention_head_size + + self.w_o = nn.ModuleList([ + nn.Linear(self.attention_head_size, config.hidden_size) + for _ in range(config.num_attention_heads) + ]) + self.w_kp = torch.rand(config.num_attention_heads) + self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum()) + self.w_a = torch.rand(config.num_attention_heads) + self.w_a = nn.Parameter(self.w_a / self.w_a.sum()) + + self.intermediate = BERTIntermediate(config) + self.output = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + self_output = self.self(hidden_states, attention_mask) + self_outputs = self_output.split(self.self.attention_head_size, dim=-1) + self_outputs = [ + self.w_o[i](self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + self.dropout(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + kappa * output for kappa, output in zip(self.w_kp, self_outputs) + ] + self_outputs = [ + self.intermediate(self_outputs[i]) + for i in range(len(self_outputs)) + ] + self_outputs = [ + self.output(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + self.dropout(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + alpha * output for alpha, output in zip(self.w_a, self_outputs) + ] + output = sum(self_outputs) + return self.LayerNorm(hidden_states + output) + + +class BERTEncoder(nn.Module): + + def __init__(self, config): + super(BERTEncoder, self).__init__() + self.layer = nn.ModuleList() + for _ in range(config.num_hidden_layers): + if config.weighted_transformer: + self.layer.append(BERTWeightedLayer(config)) + else: + self.layer.append(BERTLayer(config)) + if config.rezero: + for index, layer in enumerate(self.layer): + layer.output.res_factor = nn.Parameter( + torch.Tensor(1).fill_(1.).to( + dtype=next(self.parameters()).dtype)) + layer.output.factor = nn.Parameter( + torch.Tensor(1).fill_(1).to( + dtype=next(self.parameters()).dtype)) + layer.attention.output.res_factor = layer.output.res_factor + layer.attention.output.factor = layer.output.factor + self.config = config + + def forward(self, + hidden_states, + attention_mask, + epoch_id=-1, + head_masks=None): + all_encoder_layers = [hidden_states] + if epoch_id != -1: + detach_index = int(len(self.layer) / 3) * (2 - epoch_id) - 1 + else: + detach_index = -1 + for index, layer_module in enumerate(self.layer): + if head_masks is None: + if not self.config.grad_checkpoint: + self_out, hidden_states = layer_module( + hidden_states, attention_mask, None) + else: + self_out, hidden_states = torch.utils.checkpoint.checkpoint( + layer_module, hidden_states, attention_mask, None) + else: + self_out, hidden_states = layer_module(hidden_states, + attention_mask, + head_masks[index]) + if detach_index == index: + hidden_states.detach_() + all_encoder_layers.append(self_out) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BERTEncoderRolled(nn.Module): + + def __init__(self, config): + super(BERTEncoderRolled, self).__init__() + layer = BERTLayer(config) + self.config = config + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(config.num_rolled_layers)]) + + def forward(self, + hidden_states, + attention_mask, + epoch_id=-1, + head_masks=None): + all_encoder_layers = [hidden_states] + for i in range(self.config.num_hidden_layers): + if self.config.transformer_type.lower() == 'universal': + hidden_states = self.layer[i % self.config.num_rolled_layers]( + hidden_states, attention_mask) + elif self.config.transformer_type.lower() == 'albert': + a = i // ( + self.config.num_hidden_layers + // self.config.num_rolled_layers) + hidden_states = self.layer[a](hidden_states, attention_mask) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BERTEncoderACT(nn.Module): + + def __init__(self, config): + super(BERTEncoderACT, self).__init__() + self.layer = BERTLayer(config) + p = nn.Linear(config.hidden_size, 1) + self.p = nn.ModuleList( + [copy.deepcopy(p) for _ in range(config.num_hidden_layers)]) + # Following act paper, set bias init ones + for module in self.p: + module.bias.data.fill_(1.) + self.config = config + self.act_max_steps = config.num_hidden_layers + self.threshold = 0.99 + + def should_continue(self, halting_probability, n_updates): + return (halting_probability.lt(self.threshold).__and__( + n_updates.lt(self.act_max_steps))).any() + + def forward(self, hidden_states, attention_mask): + all_encoder_layers = [hidden_states] + batch_size, seq_len, hdim = hidden_states.size() + halting_probability = torch.zeros(batch_size, seq_len).cuda() + remainders = torch.zeros(batch_size, seq_len).cuda() + n_updates = torch.zeros(batch_size, seq_len).cuda() + for i in range(self.act_max_steps): + p = torch.sigmoid(self.p[i](hidden_states).squeeze(2)) + still_running = halting_probability.lt(1.0).float() + new_halted = (halting_probability + p * still_running).gt( + self.threshold).float() * still_running + still_running = (halting_probability + p * still_running).le( + self.threshold).float() * still_running + halting_probability = halting_probability + p * still_running + remainders = remainders + new_halted * (1 - halting_probability) + halting_probability = halting_probability + new_halted * remainders + n_updates = n_updates + still_running + new_halted + update_weights = (p * still_running + + new_halted * remainders).unsqueeze(2) + transformed_states = self.layer(hidden_states, attention_mask) + hidden_states = transformed_states * update_weights + hidden_states * ( + 1 - update_weights) + all_encoder_layers.append(hidden_states) + if not self.should_continue(halting_probability, n_updates): + break + return all_encoder_layers, torch.mean(n_updates + remainders) + + +class BERTPooler(nn.Module): + + def __init__(self, config): + super(BERTPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = modeling.BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config: BertConfig): + """Constructor for BertModel. + + Args: + config: `BertConfig` instance. + """ + super(BertModel, self).__init__() + self.config = config + self.embeddings = BERTEmbeddings(config) + if config.transformer_type.lower() == 'original': + self.encoder = BERTEncoder(config) + elif config.transformer_type.lower() == 'universal': + self.encoder = BERTEncoderRolled(config) + elif config.transformer_type.lower() == 'albert': + self.encoder = BERTEncoderRolled(config) + elif config.transformer_type.lower() == 'act': + self.encoder = BERTEncoderACT(config) + elif config.transformer_type.lower() == 'textnas': + from textnas_final import input_dict, op_dict, skip_dict + self.encoder = TextNASEncoder(config, op_dict, input_dict, + skip_dict) + else: + raise ValueError('Not support transformer type: {}'.format( + config.transformer_type.lower())) + self.pooler = BERTPooler(config) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + epoch_id=-1, + head_masks=None, + adv_embedding=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output, word_embeddings = self.embeddings( + input_ids, token_type_ids, adv_embedding) + if self.config.transformer_type.lower() == 'act': + all_encoder_layers, act_loss = self.encoder( + embedding_output, extended_attention_mask) + elif self.config.transformer_type.lower() == 'reformer': + sequence_output = self.encoder(embedding_output) + all_encoder_layers = [sequence_output, sequence_output] + else: + all_encoder_layers = self.encoder(embedding_output, + extended_attention_mask, + epoch_id, head_masks) + all_encoder_layers.insert(0, word_embeddings) + sequence_output = all_encoder_layers[-1] + if not self.config.safer_fp16: + pooled_output = self.pooler(sequence_output) + else: + pooled_output = sequence_output[:, 0] + return all_encoder_layers, pooled_output + + +class BertForSequenceClassificationMultiTask(nn.Module): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, label_list, core_encoder): + super(BertForSequenceClassificationMultiTask, self).__init__() + if core_encoder.lower() == 'bert': + self.bert = BertModel(config) + elif core_encoder.lower() == 'lstm': + self.bert = LSTMModel(config) + else: + raise ValueError( + 'Only support lstm or bert, but got {}'.format(core_encoder)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.ModuleList() + for label in label_list: + self.classifier.append(nn.Linear(config.hidden_size, len(label))) + self.label_list = label_list + + def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=config.initializer_range) + elif isinstance(module, BERTLayerNorm): + module.beta.data.normal_( + mean=0.0, std=config.initializer_range) + module.gamma.data.normal_( + mean=0.0, std=config.initializer_range) + if isinstance(module, nn.Linear): + module.bias.data.zero_() + + self.apply(init_weights) + + def forward(self, + input_ids, + token_type_ids, + attention_mask, + labels=None, + labels_index=None, + epoch_id=-1, + head_masks=None, + adv_embedding=None, + return_embedding=False, + loss_weight=None): + all_encoder_layers, pooled_output = self.bert(input_ids, + token_type_ids, + attention_mask, epoch_id, + head_masks, + adv_embedding) + pooled_output = self.dropout(pooled_output) + logits = [classifier(pooled_output) for classifier in self.classifier] + if labels is not None: + loss_fct = CrossEntropyLoss(reduction='none') + regression_loss_fct = nn.MSELoss(reduction='none') + labels_lst = torch.unbind(labels, 1) + loss_lst = [] + for index, (label, logit) in enumerate(zip(labels_lst, logits)): + if len(self.label_list[index]) != 1: + loss = loss_fct(logit, label.long()) + else: + loss = regression_loss_fct(logit.squeeze(-1), label) + labels_mask = (labels_index == index).to( + dtype=next(self.parameters()).dtype) + if loss_weight is not None: + loss = loss * loss_weight[index] + loss = torch.mean(loss * labels_mask) + loss_lst.append(loss) + if not return_embedding: + return sum(loss_lst), logits + else: + return sum(loss_lst), logits, all_encoder_layers[0] + else: + return logits diff --git a/modelscope/models/multi_modal/diffusion/tokenizer.py b/modelscope/models/multi_modal/diffusion/tokenizer.py new file mode 100644 index 00000000..82c09661 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/tokenizer.py @@ -0,0 +1,333 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team and Alibaba inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function +import collections +import unicodedata + +import six + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode('utf-8', 'ignore') + elif isinstance(text, unicode): + return text + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode('utf-8') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r') as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_tokens_to_ids(vocab, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(vocab[token]) + return ids + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a peice of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_tokens_to_ids(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return [self.inv_vocab[i] for i in ids] + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F)): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False diff --git a/modelscope/models/multi_modal/diffusion/unet_generator.py b/modelscope/models/multi_modal/diffusion/unet_generator.py new file mode 100644 index 00000000..539d3996 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/unet_generator.py @@ -0,0 +1,322 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['DiffusionGenerator'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None, mask=None): + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + if mask is not None: + assert context is not None + full_mask = x.new_ones((b, 1, q.size(-1), k.size(-1))) + full_mask[:, 0, :, :-q.size(-1)] = mask.unsqueeze(1) + attn = attn.masked_fill(full_mask == 0, float('-inf')) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class DiffusionGenerator(nn.Module): + + def __init__(self, + in_dim=3, + dim=512, + text_dim=1024, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(DiffusionGenerator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.text_dim = text_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.pool_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim)) + self.text_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim), + nn.SiLU(), nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, context, mask=None): + # embeddings + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.pool_embedding(y) + context = self.text_embedding(context) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context, mask) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context, mask) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context, mask) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context, mask): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context, mask) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, mask) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/diffusion/unet_upsampler_1024.py b/modelscope/models/multi_modal/diffusion/unet_upsampler_1024.py new file mode 100644 index 00000000..38cff6a2 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/unet_upsampler_1024.py @@ -0,0 +1,243 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['SuperResUNet1024'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class SuperResUNet1024(nn.Module): + + def __init__(self, + in_dim=6, + dim=192, + out_dim=3, + dim_mult=[1, 1, 2, 2, 4, 4], + num_res_blocks=2, + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(SuperResUNet1024, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embedding + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, concat): + # embedding + if concat is not None: + if concat.shape[-2:] != x.shape[-2:]: + concat = F.interpolate( + concat, x.shape[-2:], mode='bilinear', align_corners=False) + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding(t, self.dim)) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/diffusion/unet_upsampler_256.py b/modelscope/models/multi_modal/diffusion/unet_upsampler_256.py new file mode 100644 index 00000000..ca5cd7d6 --- /dev/null +++ b/modelscope/models/multi_modal/diffusion/unet_upsampler_256.py @@ -0,0 +1,340 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['SuperResUNet256'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample_x = Resample(in_dim, in_dim, scale_factor) + self.resample_i = Resample(in_dim, in_dim, scale_factor) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample_i(x) + x = self.layer1[-1](self.resample_x(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None, mask=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + mask: [B, L] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([k, ck], dim=-1) + v = torch.cat([v, cv], dim=-1) + + # compute attention + attn = torch.einsum('bndi,bndj->bnij', q * self.scale, k * self.scale) + if mask is not None: + pad_mask = mask.new_ones((b, 1, 1, h * w)) + mask = torch.cat((pad_mask, mask.unsqueeze(1).unsqueeze(1)), + dim=-1) + attn = attn.masked_fill(mask == 0, float('-inf')) + + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bnij,bndj->bndi', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class SuperResUNet256(nn.Module): + + def __init__(self, + in_dim=6, + out_dim=3, + dim=256, + text_dim=1024, + context_dim=512, + dim_mult=[1, 2, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=2, + attn_scales=[1 / 16], + resblock_resample=True, + use_conv=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(SuperResUNet256, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.noise_time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.pool_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim)) + self.text_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim), + nn.SiLU(), nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample(out_dim, out_dim, 0.5, use_conv) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample(out_dim, out_dim, 2.0, use_conv) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, lx, lt, y, context, mask): + assert context.shape[:-1] == mask.shape + + # embeddings + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) \ + + self.noise_time_embedding(sinusoidal_embedding(lt, self.dim)) \ + + self.pool_embedding(y) + + context = self.text_embedding(context) + + if lx.shape[-2:] != x.shape[-2:]: + lx = F.interpolate( + lx, x.shape[-2:], mode='bilinear', align_corners=False) + x = torch.cat([x, lx], dim=1) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, context, mask) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, context, mask) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, context, mask) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, context, mask): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, AttentionBlock): + x = module(x, context, mask) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, context, mask) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/gemm/__init__.py b/modelscope/models/multi_modal/gemm/__init__.py new file mode 100644 index 00000000..fe5df1fe --- /dev/null +++ b/modelscope/models/multi_modal/gemm/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .gemm_model import GEMMForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/gemm/gemm_base.py b/modelscope/models/multi_modal/gemm/gemm_base.py new file mode 100644 index 00000000..806c469c --- /dev/null +++ b/modelscope/models/multi_modal/gemm/gemm_base.py @@ -0,0 +1,556 @@ +# Copyright 2021 The OpenAI Team Authors. +# Copyright 2022 Phil Wang. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# +# The implementation here is modified based on OpenAI CLIP, +# originally MIT License, Copyright (c) 2021 OpenAI, +# and publicly available at https://github.com/openai/CLIP/. +# The implementation here is modified based on Coca-pytorch, +# originally MIT License, Copyright (c) 2022 Phil Wang, +# and publicly available at https://github.com/lucidrains/CoCa-pytorch/, +""" Generative Multimodal Model Architecture.""" + +import os +from collections import OrderedDict +from typing import Tuple, Union + +import json +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm + +from modelscope.models.multi_modal.gemm.tokenizer import (SimpleTokenizer, + clip_tokenize) + + +class Bottleneck(nn.Module): + """ ResNet style bottleneck module + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + if stride > 1 or inplanes != planes * Bottleneck.expansion: + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + + +class QuickGELU(nn.Module): + """ A quick version of GELU module + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """ Multihead attention block with residual link + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + attn_mask = self.attn_mask + if attn_mask is not None and attn_mask.shape[0] > x.shape[0]: + attn_mask = self.attn_mask[:x.shape[0], :x.shape[0]] + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """ Transformer encoder module + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + use_gc: bool = False): + super().__init__() + self.use_gc = use_gc + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class AttentionPool2d(nn.Module): + """ Pool layer with attention module + Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, 1) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) + x = x + self.positional_embedding[:, None, :].to(x.dtype) + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + return x.permute(1, 0, 2).contiguous() + + +class CrossAttention(nn.Module): + """ Cross attention module with query and context as input + Adapted from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py + """ + + def __init__(self, + dim, + *, + context_dim=None, + dim_head=64, + heads=8, + parallel_ff=False, + ff_mult=4, + norm_context=False): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + inner_dim = heads * dim_head + context_dim = dim if context_dim is None else context_dim + self.norm = LayerNorm(dim) + self.context_norm = LayerNorm( + context_dim) if norm_context else nn.Identity() + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + ff_inner_dim = ff_mult * dim + self.ff = nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False)) if parallel_ff else None + + def forward(self, x, context): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + x = self.norm(x) + context = self.context_norm(context) + + q = self.to_q(x) + q = q.view(q.shape[0], q.shape[1], self.heads, + -1).permute(0, 2, 1, 3).contiguous() + q = q * self.scale + k, v = self.to_kv(context).chunk(2, dim=-1) + sim = torch.einsum('b h i d, b j d -> b h i j', q, k) + sim = sim - sim.amax(dim=-1, keepdim=True) + attn = sim.softmax(dim=-1) + out = torch.einsum('b h i j, b j d -> b h i d', attn, v) + out = out.permute(0, 2, 1, + 3).contiguous().reshape(out.shape[0], out.shape[2], + -1) + out = self.to_out(out) + if self.ff is not None: + out = out + self.ff(x) + return out + + +class ModifiedResNet(nn.Module): + """ Modified ResNet backbone + From https://github.com/openai/CLIP/blob/main/clip/model.py + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + self._inplanes = width + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class VisualTransformer(nn.Module): + """ ViT transformer backbone + From https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int, use_gc: bool): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + z = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([self.class_embedding.to(x.dtype) + z, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + return x + + +class GEVL(nn.Module): + """ Generative vision-language model + Support learning from both generative and contrastive loss. + Given image and text input, it could output the features of + image and text respectively. Furthermore, caption could also + be produced when image input is available. + """ + + def __init__(self, embed_dim: int, image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], + int], vision_width: int, + vision_patch_size: int, context_length: int, vocab_size: int, + transformer_width: int, transformer_heads: int, + transformer_layers: int, use_gc: bool, tokenizer): + nn.Module.__init__(self) + self.context_length = context_length + self.vis_token_size = context_length + self.tokenizer = tokenizer + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + use_gc=use_gc) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + use_gc=use_gc) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.vis_token_projection = nn.Parameter( + torch.empty(embed_dim, transformer_width)) + nn.init.normal_( + self.vis_token_projection, std=self.transformer.width**-0.5) + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.decoder = Transformer( + width=transformer_width, + layers=4, + heads=transformer_heads, + attn_mask=self.build_attention_mask( + self.vis_token_size + self.context_length, + self.vis_token_size), + use_gc=use_gc) + self.to_logits = nn.Sequential( + LayerNorm(transformer_width), + nn.Linear(transformer_width, transformer_width), + nn.Linear(transformer_width, vocab_size, bias=False)) + self.gen_logit_scale = nn.Parameter( + torch.ones([]) * np.log(np.log(vocab_size))) + self.bias = nn.Parameter(torch.ones(vocab_size)) + self.to_logits[-1].weight = self.token_embedding.weight + self.to_logits[-1].bias = self.bias + self.img_queries = nn.Parameter( + torch.randn(self.vis_token_size, transformer_width)) + self.img_attn_pool = CrossAttention( + dim=transformer_width, norm_context=True) + self.img_attn_pool_norm = LayerNorm(transformer_width) + + def build_attention_mask(self, seq_length=None, prefix_length=0): + seq_length = self.context_length if seq_length is None else seq_length + mask = torch.empty(seq_length, seq_length) + mask.fill_(torch.tensor(torch.finfo(torch.float16).min)) + mask.triu_(1) + if prefix_length > 0: + mask[:prefix_length, :prefix_length] = 0 + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, return_tokens=False): + image_outputs = self.visual(image) + image_features = image_outputs[:, 0, :] + image_features = image_features / image_features.norm( + dim=-1, p=2, keepdim=True) + if return_tokens: + image_tokens = image_outputs[:, 1:, :] @ self.vis_token_projection + return image_features, image_tokens + else: + return image_features + + def encode_text(self, text, return_tokens=False): + x = self.token_embedding(text) + x = x + self.positional_embedding[:x.shape[1], :] + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_final(x) + text_features = x[torch.arange(x.shape[0]), + text.argmax(dim=-1), ...] @ self.text_projection + text_features = text_features / text_features.norm( + dim=-1, p=2, keepdim=True) + if return_tokens: + text_tokens = x + return text_features, text_tokens + else: + return text_features + + def image_to_text(self, image): + image_features, image_tokens = self.encode_image( + image, return_tokens=True) + img_queries = self.img_queries.expand(image_tokens.shape[0], -1, -1) + img_token_features = self.img_attn_pool(img_queries, image_tokens) + img_token_features = self.img_attn_pool_norm(img_token_features) + sot_token = self.tokenizer.encoder['<|startoftext|>'] + eot_token = self.tokenizer.encoder['<|endoftext|>'] + text_input = image.new_ones( + image.shape[0], 1, dtype=torch.long) * sot_token + input_tokens = img_token_features + pred_tokens = [] + for text_idx in range(self.context_length): + text_features, text_tokens = self.encode_text( + text_input, return_tokens=True) + input_tokens = torch.cat([img_token_features, text_tokens], axis=1) + out_embs = self.decoder(input_tokens.permute(1, 0, 2).contiguous()) + gen_logits = self.to_logits(out_embs[-1:, ...]) + probs = F.softmax(self.gen_logit_scale.exp() * gen_logits, dim=-1) + pred = torch.argmax( + probs * (2.0 + torch.rand_like(probs)), axis=-1) + if int(pred) >= eot_token or int(pred) <= 0: + break + pred_tokens.append(pred) + text_input = torch.cat( + [text_input, pred.permute(1, 0).contiguous()], axis=1) + pred_text_tokens = torch.cat(pred_tokens, axis=0).permute(1, 0) + text_list = [] + for out_tokens in pred_text_tokens: + tokens = [] + for x in out_tokens: + tokens.append(int(x)) + out_text = self.tokenizer.decode(tokens) + out_text = out_text.strip() + text_list.append(out_text) + return image_features, text_list[0] + + +class GEMMModel(nn.Module): + """ Generative multi-modal model, wrapper of GEVL module. + It takes image or text or both of them as input, and output + features of input or caption when image input is available. + """ + + def __init__(self, model_dir): + super().__init__() + with open('{}/encoder_config.json'.format(model_dir), 'r') as f: + model_config = json.loads(f.read()) + model_name = list(model_config.keys())[0] + config_args = model_config[model_name] + bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz') + self.tokenizer = SimpleTokenizer(bpe_path) + self.model = GEVL(*config_args, self.tokenizer) + + def tokenize(self, text_str): + text_tensor = clip_tokenize(self.tokenizer, [text_str])[0] + return text_tensor + + def parse_feat(self, feat): + out = feat.cpu().numpy() + return out + + @torch.no_grad() + def forward(self, image=None, text=None, captioning=True): + img_feature, text_feature, caption = None, None, None + if captioning and image is not None: + img_feature, caption = self.model.image_to_text(image) + img_feature = self.parse_feat(img_feature) + elif image is not None: + img_feature = self.parse_feat(self.model.encode_image(image)) + if text is not None: + text_feature = self.parse_feat(self.model.encode_text(text)) + out = { + 'image_feature': img_feature, + 'text_feature': text_feature, + 'caption': caption, + } + return out diff --git a/modelscope/models/multi_modal/gemm/gemm_model.py b/modelscope/models/multi_modal/gemm/gemm_model.py new file mode 100644 index 00000000..c90b35d4 --- /dev/null +++ b/modelscope/models/multi_modal/gemm/gemm_model.py @@ -0,0 +1,94 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +""" Generative Multimodal Model Wrapper.""" +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms as T + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.gemm.gemm_base import GEMMModel +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['GEMMForMultiModalEmbedding'] + + +@MODELS.register_module( + Tasks.generative_multi_modal_embedding, module_name=Models.gemm) +class GEMMForMultiModalEmbedding(TorchModel): + """ Generative multi-modal model for multi-modal embedding + Inputs could be image or text or both of them. + Outputs could be features of input image or text, + image caption could also be produced when image is available. + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.gemm_model = GEMMModel(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.gemm_model.load_state_dict(pretrained_params) + self.gemm_model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.gemm_model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + self.img_preprocessor = T.Compose([ + T.Resize(224), + T.CenterCrop(224), + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) + ]) + + def parse_image(self, input_img): + if input_img is None: + return None + input_img = LoadImage.convert_to_img(input_img) + img_tensor = self.img_preprocessor(input_img)[None, ...] + if self.device_id >= 0: + img_tensor = img_tensor.to('cuda:{}'.format(self.device_id)) + return img_tensor + + def parse_text(self, text_str): + if text_str is None or len(text_str) == 0: + return None + if isinstance(text_str, str): + text_ids_tensor = self.gemm_model.tokenize(text_str) + else: + raise TypeError(f'text should be str, but got {type(text_str)}') + if self.device_id >= 0: + text_ids_tensor = text_ids_tensor.to('cuda:{}'.format( + self.device_id)) + return text_ids_tensor.view(1, -1) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + image_input = input.get('image', input.get('img', None)) + text_input = input.get('text', input.get('txt', None)) + captioning_input = input.get('captioning', None) + image = self.parse_image(image_input) + text = self.parse_text(text_input) + captioning = captioning_input is True or text_input == '' + out = self.gemm_model(image, text, captioning) + output = { + OutputKeys.IMG_EMBEDDING: out.get('image_feature', None), + OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None), + OutputKeys.CAPTION: out.get('caption', None) + } + return output diff --git a/modelscope/models/multi_modal/gemm/tokenizer.py b/modelscope/models/multi_modal/gemm/tokenizer.py new file mode 100644 index 00000000..8b7cc094 --- /dev/null +++ b/modelscope/models/multi_modal/gemm/tokenizer.py @@ -0,0 +1,201 @@ +# Copyright 2021 The OpenAI Team Authors. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# +# The implementation here is modified based on OpenAI CLIP, +# originally MIT License, Copyright (c) 2021 OpenAI, +# and publicly available at https://github.com/openai/CLIP/. +""" CLIP Tokenizer.""" + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + error_list = [] + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception as err: + error_list.append(err) + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + if len(error_list) > 100: + print(error_list[-1]) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +def clip_tokenize(tokenizer, texts, context_length=77, truncate=True): + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = tokenizer.encoder['<|startoftext|>'] + eot_token = tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/models/multi_modal/mmr/__init__.py b/modelscope/models/multi_modal/mmr/__init__.py new file mode 100644 index 00000000..9dac8409 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from .models import VideoCLIPForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/mmr/dataloaders/__init__.py b/modelscope/models/multi_modal/mmr/dataloaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py b/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py new file mode 100644 index 00000000..c7ac3f94 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/dataloaders/rawvideo_util.py @@ -0,0 +1,117 @@ +# The implementation is adopted from Huaishao Luo, +# made pubicly available under the MIT License at https://github.com/ArrowLuo/CLIP4Clip + +import cv2 +import numpy as np +import torch as th +from PIL import Image +from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, + Normalize, Resize, ToTensor) + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class RawVideoExtractorCV2(): + + def __init__( + self, + centercrop=False, + size=224, + frame_rate=-1, + ): + self.centercrop = centercrop + self.size = size + self.framerate = frame_rate + self.transform = self._transform(self.size) + + def _transform(self, n_px): + return Compose([ + Resize(n_px, interpolation=InterpolationMode.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert('RGB'), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + def video_to_tensor(self, + video_file, + preprocess, + sample_fp=0, + start_time=None, + end_time=None): + if start_time is not None or end_time is not None: + assert isinstance(start_time, int) and isinstance(end_time, int) \ + and start_time > -1 and end_time > start_time + assert sample_fp > -1 + + # Samples a frame sample_fp X frames. + cap = cv2.VideoCapture(video_file) + frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + + if fps == 0: + logger.info(f'{video_file} with fps 0!!!') + total_duration = (frameCount + fps - 1) // fps + start_sec, end_sec = 0, total_duration + + if start_time is not None: + start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration + cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) + + interval = 1 + if sample_fp > 0: + interval = fps // sample_fp + else: + sample_fp = fps + if interval == 0: + interval = 1 + + inds = [ind for ind in np.arange(0, fps, interval)] + assert len(inds) >= sample_fp + inds = inds[:sample_fp] + + ret = True + images = [] + + for sec in np.arange(start_sec, end_sec + 1): + if not ret: + break + sec_base = int(sec * fps) + for ind in inds: + cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) + ret, frame = cap.read() + if not ret: + break + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + images.append( + preprocess(Image.fromarray(frame_rgb).convert('RGB'))) + + cap.release() + + if len(images) > 0: + video_data = th.tensor(np.stack(images)) + else: + video_data = th.zeros(1) + return {'video': video_data} + + def get_video_data(self, video_path, start_time=None, end_time=None): + image_input = self.video_to_tensor( + video_path, + self.transform, + sample_fp=self.framerate, + start_time=start_time, + end_time=end_time) + return image_input + + def process_raw_data(self, raw_video_data): + tensor_size = raw_video_data.size() + tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], + tensor_size[-1]) + return tensor + + +# An ordinary video frame extractor based CV2 +RawVideoExtractor = RawVideoExtractorCV2 diff --git a/modelscope/models/multi_modal/mmr/models/__init__.py b/modelscope/models/multi_modal/mmr/models/__init__.py new file mode 100644 index 00000000..da832719 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from .clip_for_mm_video_embedding import VideoCLIPForMultiModalEmbedding diff --git a/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py new file mode 100644 index 00000000..0cc040c6 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py @@ -0,0 +1,246 @@ +# The implementation is adopted from the CLIP4Clip implementation, +# made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip + +import os +import random +import uuid +from os.path import exists +from tempfile import TemporaryDirectory +from typing import Any, Dict +from urllib.parse import urlparse + +import json +import numpy as np +import torch +from decord import VideoReader, cpu +from PIL import Image + +from modelscope.hub.file_download import http_get_file +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..dataloaders.rawvideo_util import RawVideoExtractor +from .modeling import CLIP4Clip +from .tokenization_clip import SimpleTokenizer as ClipTokenizer + +logger = get_logger() + + +@MODELS.register_module( + Tasks.video_multi_modal_embedding, module_name=Models.video_clip) +class VideoCLIPForMultiModalEmbedding(TorchModel): + + def __init__(self, model_dir, **kwargs): + super().__init__(model_dir=model_dir, **kwargs) + # model config parameters + with open(f'{model_dir}/{ModelFile.CONFIGURATION}', 'r') as json_file: + model_config = json.load(json_file) + model_config = model_config['paras'] + model_config['model_dir'] = model_dir + self.SPECIAL_TOKEN = { + 'CLS_TOKEN': '<|startoftext|>', + 'SEP_TOKEN': '<|endoftext|>', + 'MASK_TOKEN': '[MASK]', + 'UNK_TOKEN': '[UNK]', + 'PAD_TOKEN': '[PAD]' + } + self.max_words = model_config['max_words'] + self.max_frames = model_config['max_frames'] + self.feature_framerate = model_config['feature_framerate'] + self.image_resolution = 224 + if torch.cuda.is_available(): + self.device = model_config['device'] + else: + self.device = 'cpu' + self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}' + + self.tokenizer = ClipTokenizer(model_dir) + self.rawVideoExtractor = RawVideoExtractor( + frame_rate=self.feature_framerate, size=self.image_resolution) + self.local_transform = self.rawVideoExtractor.transform + + self.model = CLIP4Clip(model_config) + if hasattr(self.model, 'module'): + self.model = self.model.module.to(self.device) + else: + self.model = self.model.to(self.device) + if self.init_model: + assert exists(self.init_model) + model_state_dict = torch.load(self.init_model, map_location='cpu') + self.model.load_state_dict(model_state_dict, strict=False) + self.model.to(self.device) + + def _get_text(self, caption, tokenizer, enable_zh=False): + + if type(caption) is str: + _caption_text, s, e = caption, None, None + elif type(caption) is tuple: + if len(caption) == 3: + _caption_text, s, e = caption + elif len(caption) == 4: + _caption_text, s, e, pos = caption + else: + NotImplementedError + + if isinstance(_caption_text, list): + caption_text = random.choice(_caption_text) + else: + caption_text = _caption_text + if enable_zh: + _token = tokenizer.encode(caption_text) + input_ids = _token.ids + input_mask = _token.attention_mask + segment_ids = _token.type_ids + else: + words = tokenizer.tokenize(caption_text) + + words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words + total_length_with_CLS = self.max_words - 1 + if len(words) > total_length_with_CLS: + words = words[:total_length_with_CLS] + words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']] + + input_ids = tokenizer.convert_tokens_to_ids(words) + input_mask = [1] * len(input_ids) + segment_ids = [0] * len(input_ids) + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text = np.array(input_ids) + pairs_mask = np.array(input_mask) + pairs_segment = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, s, e + + def _get_rawvideo_dec(self, + video_path, + rawVideoExtractor, + local_transform, + s=None, + e=None): + video_mask = np.zeros(self.max_frames, dtype=np.long) + max_video_length = 0 + + # T x 3 x H x W + video = np.zeros((self.max_frames, 3, rawVideoExtractor.size, + rawVideoExtractor.size), + dtype=np.float) + + if s is None: + start_time, end_time = None, None + else: + start_time = int(s) + end_time = int(e) + start_time = start_time if start_time >= 0. else 0. + end_time = end_time if end_time >= 0. else 0. + if start_time > end_time: + start_time, end_time = end_time, start_time + elif start_time == end_time: + end_time = end_time + 1 + + url_parsed = urlparse(video_path) + if url_parsed.scheme in ('file', '') and exists( + url_parsed.path): # Possibly a local file + vreader = VideoReader(video_path, ctx=cpu(0)) + else: + try: + with TemporaryDirectory() as temporary_cache_dir: + random_str = uuid.uuid4().hex + http_get_file( + url=video_path, + local_dir=temporary_cache_dir, + file_name=random_str, + cookies=None) + temp_file_path = os.path.join(temporary_cache_dir, + random_str) + vreader = VideoReader(temp_file_path, ctx=cpu(0)) + except Exception as ex: + logger.error('non video input, output is {}!!!'.format(ex)) + return video, video_mask + + fps = vreader.get_avg_fps() + f_start = 0 if start_time is None else int(start_time * fps) + f_end = int( + min(1000000000 if end_time is None else end_time * fps, + len(vreader) - 1)) + num_frames = f_end - f_start + 1 + if num_frames > 0: + # L x T x 3 x H x W + sample_fps = int(self.feature_framerate) + t_stride = int(round(float(fps) / sample_fps)) + + all_pos = list(range(f_start, f_end + 1, t_stride)) + if len(all_pos) > self.max_frames: + sample_pos = [ + all_pos[_] for _ in np.linspace( + 0, len(all_pos) - 1, num=self.max_frames, dtype=int) + ] + else: + sample_pos = all_pos + patch_images = [ + Image.fromarray(f) + for f in vreader.get_batch(sample_pos).asnumpy() + ] + patch_images = torch.stack( + [local_transform(img) for img in patch_images]) + slice_len = patch_images.shape[0] + max_video_length = max_video_length if max_video_length > slice_len else slice_len + if slice_len < 1: + pass + else: + video[:slice_len, ...] = patch_images + else: + logger.error('video path: {} error. video id: {}'.format( + video_path, video_id)) + + video_mask[:max_video_length] = [1] * max_video_length + + return video, video_mask + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + from modelscope.outputs import OutputKeys + output = {} + + if 'video' in input and input['video'] is not None: + video_path = input['video'] + video, video_mask = self._get_rawvideo_dec(video_path, + self.rawVideoExtractor, + self.local_transform) + video = torch.unsqueeze( + torch.from_numpy(video), dim=0).to(self.device) + video_mask = torch.unsqueeze( + torch.from_numpy(video_mask), dim=0).to(self.device) + + if 'text' in input and input['text'] is not None: + caption = input['text'] + pairs_text, pairs_mask, pairs_segment, s, e = self._get_text( + caption, self.tokenizer, enable_zh=False) + input_ids = torch.unsqueeze( + torch.from_numpy(pairs_text), dim=0).to(self.device) + input_mask = torch.unsqueeze( + torch.from_numpy(pairs_mask), dim=0).to(self.device) + segment_ids = torch.unsqueeze( + torch.from_numpy(pairs_segment), dim=0).to(self.device) + + sequence_output, visual_output = self.model.get_sequence_visual_output( + input_ids, segment_ids, input_mask, video, video_mask) + logger.info('text feature: {}'.format(sequence_output[0][0][0])) + logger.info('video feature: {}'.format(visual_output[0][0][0])) + + output[ + OutputKeys.VIDEO_EMBEDDING] = visual_output.cpu().detach().numpy() + output[OutputKeys.TEXT_EMBEDDING] = sequence_output.cpu().detach( + ).numpy() + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py new file mode 100644 index 00000000..c2d96275 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/dynamic_inverted_softmax.py @@ -0,0 +1,45 @@ +# The implementation is adopted from the CLIP4Clip implementation, +# made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip + +import numpy as np + + +def get_retrieved_videos(sims, k): + """ + Returns list of retrieved top k videos based on the sims matrix + Args: + sims: similar matrix. + K: top k number of videos + """ + argm = np.argsort(-sims, axis=1) + topk = argm[:, :k].reshape(-1) + retrieved_videos = np.unique(topk) + return retrieved_videos + + +def get_index_to_normalize(sims, videos): + """ + Returns list of indices to normalize from sims based on videos + Args: + sims: similar matrix. + videos: video array. + """ + argm = np.argsort(-sims, axis=1)[:, 0] + result = np.array(list(map(lambda x: x in videos, argm))) + result = np.nonzero(result) + return result + + +def qb_norm(train_test, test_test, args): + k = args.get('k', 1) + beta = args.get('beta', 20) + retrieved_videos = get_retrieved_videos(train_test, k) + test_test_normalized = test_test + train_test = np.exp(train_test * beta) + test_test = np.exp(test_test * beta) + + normalizing_sum = np.sum(train_test, axis=0) + index_for_normalizing = get_index_to_normalize(test_test, retrieved_videos) + test_test_normalized[index_for_normalizing, :] = \ + np.divide(test_test[index_for_normalizing, :], normalizing_sum) + return test_test_normalized diff --git a/modelscope/models/multi_modal/mmr/models/modeling.py b/modelscope/models/multi_modal/mmr/models/modeling.py new file mode 100644 index 00000000..dc6510bf --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/modeling.py @@ -0,0 +1,507 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +import platform +from collections import OrderedDict +from types import SimpleNamespace + +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from modelscope.models.multi_modal.mmr.models.module_clip import ( + _PT_NAME, CLIP, QuickGELU, convert_weights) +from modelscope.models.multi_modal.mmr.models.module_cross import \ + Transformer as TransformerClip +from modelscope.models.multi_modal.mmr.models.until_module import (AllGather, + CrossEn, + LayerNorm) +from modelscope.utils.logger import get_logger + +allgather = AllGather.apply + +logger = get_logger() +__all__ = ['CLIP4Clip'] + + +class CLIP4Clip(nn.Module): + + def __init__(self, config): + super(CLIP4Clip, self).__init__() + + self.config = config + self.loose_type = config['loose_type'] + self.sim_header = config['sim_header'] + if self.sim_header in [ + 'tightTransf', 'tightFc1', 'tightFc2', 'tightFc3', 'tightFc4', + 'tightMean', 'tightFc5' + ]: + assert self.loose_type is False + + backbone = config['pretrained_clip_name'] + + # fix backbone without downlond + model_path = '{}/ViT-B-16.pt'.format(config['model_dir']) + if not os.path.exists(model_path): + logger.info('no model loaded!!!') + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location='cpu').eval() + state_dict = model.state_dict() + except RuntimeError: + state_dict = torch.load(model_path, map_location='cpu') + + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + cut_top_layer = 0 + self.clip = CLIP( + embed_dim, + image_resolution, + vision_layers - cut_top_layer, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers - cut_top_layer, + linear_patch=config['linear_patch'], + use_gc=config['use_gc']).float() + + if backbone in ['ViT-B/32', 'ViT-B/16']: + cross_config = SimpleNamespace(**{ + 'hidden_size': 512, + 'max_position_embeddings': 128, + }) + elif backbone in ['ViT-L/14', 'ViT-B/14-336px']: + cross_config = SimpleNamespace(**{ + 'hidden_size': 768, + 'max_position_embeddings': 128, + }) + else: + raise ValueError + + cross_config.max_position_embeddings = context_length + self.cross_config = cross_config + + self.text_weight_fc = nn.Sequential( + nn.Linear(transformer_width, transformer_width), + nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) + self.video_weight_fc = nn.Sequential( + nn.Linear(transformer_width, transformer_width), + nn.ReLU(inplace=True), nn.Linear(transformer_width, 1)) + + if self.loose_type is False: + raise NotImplementedError + + if self.sim_header in ['seqLSTM', 'seqTransf', 'tightFc1']: + self.frame_position_embeddings = nn.Embedding( + cross_config.max_position_embeddings, cross_config.hidden_size) + if self.sim_header in ['seqTransf', 'tightFc1']: + self.transformerClip = TransformerClip( + width=transformer_width, + layers=config['cross_num_hidden_layers'], + heads=transformer_heads, + ) + if self.sim_header == 'seqLSTM': + self.lstm_visual = nn.LSTM( + input_size=cross_config.hidden_size, + hidden_size=cross_config.hidden_size, + batch_first=True, + bidirectional=False, + num_layers=1) + + self.loss_fct = CrossEn(config) + + self.apply(self.init_weights) + self.clip.load_state_dict(state_dict, strict=False) + + # ===> Initialization trick [HARD CODE] + if backbone not in _PT_NAME: + raise NotImplementedError + # reload + else: + if config['linear_patch'] == '3d': + raise NotImplementedError + + new_state_dict = OrderedDict() + if self.sim_header == 'tightTransf': + raise NotImplementedError + + if self.sim_header in ['seqLSTM', 'seqTransf', 'seqFc1']: + contain_frame_position = False + for key in state_dict.keys(): + if key.find('frame_position_embeddings') > -1: + contain_frame_position = True + break + if contain_frame_position is False: + for key, val in state_dict.items(): + if key == 'positional_embedding': + new_state_dict[ + 'frame_position_embeddings.weight'] = val.clone() + continue + if self.sim_header in [ + 'seqTransf', 'seqFc1' + ] and key.find('transformer.resblocks') == 0: + num_layer = int(key.split('.')[2]) + # cut from beginning + if num_layer < config['cross_num_hidden_layers']: + new_state_dict[key.replace( + 'transformer.', + 'transformerClip.')] = val.clone() + continue + # <=== End of initialization trick + + self.load_state_dict( + new_state_dict, strict=False + ) # only update new state (seqTransf/seqLSTM/tightTransf) + if self.sim_header == 'tightFc5': + raise ValueError + + def forward(self, + input_ids, + token_type_ids, + attention_mask, + video, + video_mask=None): + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + # B x T x 3 x H x W - > (B x T) x 3 x H x W + video = torch.as_tensor(video).float() + if len(video.shape) == 6: # image + b, bs, ts, channel, h, w = video.shape + b = b * bs + else: # video + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + sequence_output, visual_output = self.get_sequence_visual_output( + input_ids, + token_type_ids, + attention_mask, + video, + video_mask, + shaped=True) + + if self.training: + loss = 0. + sim_matrix1, sim_matrix2, barlow_loss = self.get_similarity_logits( + sequence_output, + visual_output, + attention_mask, + video_mask, + shaped=True, + loose_type=self.loose_type) + sim_loss = (self.loss_fct(sim_matrix1) + + self.loss_fct(sim_matrix2)) / 2 + loss += sim_loss + barlow_loss * self.config.cdcr_lambda + + return loss + else: + return None + + def get_sequence_output(self, + input_ids, + token_type_ids, + attention_mask, + shaped=False): + if shaped is False: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + + bs_pair = input_ids.size(0) + sequence_hidden = self.clip.encode_text( + input_ids, return_hidden=True, prompt=None)[1].float() + sequence_hidden = sequence_hidden.view(bs_pair, -1, + sequence_hidden.size(-1)) + + return sequence_hidden + + def get_visual_output(self, video, video_mask, shaped=False): + if shaped is False: + video_mask = video_mask.view(-1, video_mask.shape[-1]) + video = torch.as_tensor(video).float() + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + bs_pair = video_mask.size(0) + visual_hidden = self.clip.encode_image(video).float() + visual_hidden = visual_hidden.float().view(bs_pair, -1, + visual_hidden.size(-1)) + + return visual_hidden + + def get_sequence_visual_output(self, + input_ids, + token_type_ids, + attention_mask, + video, + video_mask, + shaped=False): + if shaped is False: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + video = torch.as_tensor(video).float() + if len(video.shape) == 6: # image + b, bs, ts, channel, h, w = video.shape + b = b * bs + else: # video + b, ts, channel, h, w = video.shape + video = video.view(b * ts, channel, h, w) + + sequence_output = self.get_sequence_output( + input_ids, token_type_ids, attention_mask, shaped=True) + visual_output = self.get_visual_output(video, video_mask, shaped=True) + + return sequence_output, visual_output + + def agg_video_feat(self, visual_output, video_mask, sim_header='meanP'): + if self.config.max_sum == 0: + raise ValueError + + if sim_header == 'meanP': + # Default: Parameter-free type + pass + elif sim_header == 'seqLSTM': + # Sequential type: LSTM + visual_output_original = visual_output + visual_output = pack_padded_sequence( + visual_output, + torch.sum(video_mask, dim=-1).cpu(), + batch_first=True, + enforce_sorted=False) + visual_output, _ = self.lstm_visual(visual_output) + if self.training: + self.lstm_visual.flatten_parameters() + visual_output, _ = pad_packed_sequence( + visual_output, batch_first=True) + visual_output = torch.cat( + (visual_output, visual_output_original[:, + visual_output.size(1):, + ...].contiguous()), + dim=1) + visual_output = visual_output + visual_output_original + elif sim_header == 'seqTransf': + # Sequential type: Transformer Encoder + visual_output_original = visual_output + seq_length = visual_output.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=visual_output.device) + position_ids = position_ids.unsqueeze(0).expand( + visual_output.size(0), -1) + frame_position_embeddings = self.frame_position_embeddings( + position_ids) + visual_output = visual_output + frame_position_embeddings + + extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 + extended_video_mask = extended_video_mask.expand( + -1, video_mask.size(1), -1) + visual_output = visual_output.permute(1, 0, 2) # NLD -> LND + visual_output = self.transformerClip(visual_output, + extended_video_mask) + visual_output = visual_output.permute(1, 0, 2) # LND -> NLD + visual_output = visual_output + visual_output_original + + return visual_output + + def wti_interaction(self, text_feat, video_feat, text_mask, video_mask): + text_weight = self.text_weight_fc(text_feat).squeeze( + 2) # B x N_t x D -> B x N_t + text_weight.masked_fill_( + torch.tensor((1 - text_mask), dtype=torch.bool), float('-inf')) + text_weight = torch.softmax(text_weight, dim=-1) # B x N_t + + video_weight = self.video_weight_fc(video_feat).squeeze( + 2) # B x N_v x D -> B x N_v + video_weight.masked_fill_( + torch.tensor((1 - video_mask), dtype=torch.bool), float('-inf')) + video_weight = torch.softmax(video_weight, dim=-1) # B x N_v + + text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True) + + retrieve_logits = torch.einsum('atd,bvd->abtv', + [text_feat, video_feat]) + retrieve_logits = torch.einsum('abtv,at->abtv', + [retrieve_logits, text_mask]) + retrieve_logits = torch.einsum('abtv,bv->abtv', + [retrieve_logits, video_mask]) + + t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt + t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight]) + + v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv + v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight]) + retrieve_logits = (t2v_logits + v2t_logits) / 2.0 + + if self.training: + logit_scale = self.clip.logit_scale.exp() + retrieve_logits = logit_scale * retrieve_logits + + # selecet max + max_idx1 = max_idx1[torch.arange(max_idx1.shape[0]), + torch.arange(max_idx1.shape[1])] + max_idx2 = max_idx2[torch.arange(max_idx2.shape[0]), + torch.arange(max_idx2.shape[1])] + + max_t_feat = text_feat[torch.arange(max_idx2.shape[0]). + repeat_interleave(max_idx2.shape[1]), + max_idx2.flatten()].squeeze(1) + max_v_feat = video_feat[torch.arange(max_idx1.shape[0]). + repeat_interleave(max_idx1.shape[1]), + max_idx1.flatten()].squeeze(1) + + t_feat = text_feat.reshape(-1, text_feat.shape[-1]) + t_mask = text_mask.flatten().type(torch.bool) + v_feat = video_feat.reshape(-1, video_feat.shape[-1]) + v_mask = video_mask.flatten().type(torch.bool) + t_feat = t_feat[t_mask] + v_feat = v_feat[v_mask] + max_t_feat = max_t_feat[v_mask] + max_v_feat = max_v_feat[t_mask] + text_weight = text_weight.flatten()[t_mask] + video_weight = video_weight.flatten()[v_mask] + + z_a_norm = (t_feat - t_feat.mean(0)) / t_feat.std(0) # (BxN_t)xD + z_b_norm = (max_v_feat - max_v_feat.mean(0)) / max_v_feat.std( + 0) # (BxN_t)xD + + x_a_norm = (v_feat - v_feat.mean(0)) / v_feat.std(0) # (BxN_v)xD + x_b_norm = (max_t_feat - max_t_feat.mean(0)) / max_t_feat.std( + 0) # (BxN_v)xD + + # cross-correlation matrix + N, D = z_a_norm.shape + B = text_feat.shape[0] + c1 = torch.einsum('acd,a->cd', + torch.einsum('ac,ad->acd', z_a_norm, z_b_norm), + text_weight) / B # DxD + c2 = torch.einsum('acd,a->cd', + torch.einsum('ac,ad->acd', x_a_norm, x_b_norm), + video_weight) / B # DxD + c = (c1 + c2) / 2.0 + # loss + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = c.flatten()[1:].view(D - 1, D + 1)[:, :-1].pow_(2).sum() + cdcr_loss = ( + on_diag * self.config.cdcr_alpha1 + + off_diag * self.config.cdcr_alpha2) + return retrieve_logits, retrieve_logits.T, cdcr_loss + else: + return retrieve_logits, retrieve_logits.T + + def _loose_similarity(self, + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header='seqTransf'): + sequence_output, visual_output = sequence_output.contiguous( + ), visual_output.contiguous() + + visual_output = self.agg_video_feat(visual_output, video_mask, + sim_header) + + if self.training: # batch merge here + visual_output = allgather(visual_output, self.config) + attention_mask = allgather(attention_mask, self.config) + video_mask = allgather(video_mask, self.config) + sequence_output = allgather(sequence_output, self.config) + torch.distributed.barrier() # force sync + + return self.wti_interaction(sequence_output, visual_output, + attention_mask, video_mask) + + def get_similarity_logits(self, + sequence_output, + visual_output, + attention_mask, + video_mask, + shaped=False, + loose_type=False): + if shaped is False: + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + + if loose_type: + assert self.sim_header in ['meanP', 'seqLSTM', 'seqTransf'] + + if self.training: + retrieve_logits1, retrieve_logits2, barlow_loss = self._loose_similarity( + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header=self.sim_header) + return retrieve_logits1, retrieve_logits2, barlow_loss + else: + retrieve_logits1, retrieve_logits2 = self._loose_similarity( + sequence_output, + visual_output, + attention_mask, + video_mask, + sim_header=self.sim_header) + return retrieve_logits1, retrieve_logits2 + else: + raise NotImplementedError + + @property + def dtype(self): + """ + :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + try: + return next(self.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + def find_tensor_attributes(module: nn.Module): + tuples = [(k, v) for k, v in module.__dict__.items() + if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + def init_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, LayerNorm): + if 'beta' in dir(module) and 'gamma' in dir(module): + module.beta.data.zero_() + module.gamma.data.fill_(1.0) + else: + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() diff --git a/modelscope/models/multi_modal/mmr/models/module_clip.py b/modelscope/models/multi_modal/mmr/models/module_clip.py new file mode 100644 index 00000000..53501720 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/module_clip.py @@ -0,0 +1,527 @@ +# The implementation is adopated from the CLIP4Clip implementation, +# made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip + +import hashlib +import os +import urllib +import warnings +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import nn +from tqdm import tqdm + +_MODELS = {} +_PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'} + + +def available_models(): + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super(Bottleneck, self).__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super(AttentionPool2d, self).__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super(ModifiedResNet, self).__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, d_model: int, n_head: int, attn_mask=None): + super(ResidualAttentionBlock, self).__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + attn_mask_ = self.attn_mask + if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'): + attn_mask_ = self.attn_mask(x.size(0)) # LND + + attn_mask_ = attn_mask_.to( + dtype=x.dtype, device=x.device) if attn_mask_ is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] + + def forward(self, x): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask=None, + use_gc=0): + super(Transformer, self).__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + self.use_gc = use_gc + + def forward(self, x: torch.Tensor): + if self.use_gc > 0: + for blk in self.resblocks: + x = checkpoint.checkpoint(blk, x) + return x + else: + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + linear_patch: str = '2d', + use_gc: int = 0): + super(VisualTransformer, self).__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + # For 3D + assert linear_patch in ['2d', '3d'] + self.linear_patch = linear_patch + if self.linear_patch == '3d': + self.conv2 = nn.Conv3d( + in_channels=3, + out_channels=width, + kernel_size=(3, patch_size, patch_size), + stride=(1, patch_size, patch_size), + padding=(1, 0, 0), + bias=False) + + def forward(self, x: torch.Tensor, video_frame=-1): + + if self.linear_patch == '3d': + assert video_frame != -1 + x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], + x.shape[-1]) + x_3d = x_3d.permute(0, 2, 1, 3, 4) + x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid] + x_3d = x_3d.permute(0, 2, 1, 3, + 4) # shape = [*, frame, width, grid, grid] + x = x_3d.reshape( + -1, x_3d.shape[-3], x_3d.shape[-2], + x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid] + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + _x = self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([_x, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + # vision linear of patch + linear_patch: str = '2d', + use_gc: int = 0): + super(CLIP, self).__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + linear_patch=linear_patch, + use_gc=use_gc) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self, context_length): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.zeros(context_length, context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, return_hidden=False): + hidden = self.visual(image.type(self.dtype)) + hidden = self.visual.ln_post(hidden) @ self.visual.proj + + x = hidden[:, 0, :] + + if return_hidden: + return x, hidden + + return x + + def encode_text(self, text, return_hidden=False, prompt=None): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + if prompt: + x = prompt(x) + + pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) + x = x + pos_emd + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + hidden = self.ln_final(x).type(self.dtype) @ self.text_projection + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] + + if return_hidden: + return x, hidden + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(lay): + # l = lay + if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): + lay.weight.data = lay.weight.data.half() + if lay.bias is not None: + lay.bias.data = lay.bias.data.half() + + if isinstance(lay, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(lay, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(lay, name): + attr = getattr(lay, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) diff --git a/modelscope/models/multi_modal/mmr/models/module_cross.py b/modelscope/models/multi_modal/mmr/models/module_cross.py new file mode 100644 index 00000000..b958d5bc --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/module_cross.py @@ -0,0 +1,103 @@ +# The implementation is adopated from the CLIP4Clip implementation, +# made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip + +from __future__ import absolute_import, division, print_function +import logging +from collections import OrderedDict + +import json +import torch +from torch import nn + +from .until_module import ACT2FN, LayerNorm + +logger = logging.getLogger(__name__) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, d_model: int, n_head: int): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.n_head = n_head + + def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): + attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] + + def forward(self, para_tuple: tuple): + # x: torch.Tensor, attn_mask: torch.Tensor + x, attn_mask = para_tuple + x = x + self.attention(self.ln_1(x), attn_mask) + x = x + self.mlp(self.ln_2(x)) + return (x, attn_mask) + + +class Transformer(nn.Module): + + def __init__(self, width: int, layers: int, heads: int): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads) for _ in range(layers)]) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): + return self.resblocks((x, attn_mask))[0] + + +class CrossEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(CrossEmbeddings, self).__init__() + + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, concat_embeddings, concat_type=None): + + _, seq_length = concat_embeddings.size(0), concat_embeddings.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=concat_embeddings.device) + position_ids = position_ids.unsqueeze(0).expand( + concat_embeddings.size(0), -1) + + position_embeddings = self.position_embeddings(position_ids) + + embeddings = concat_embeddings + position_embeddings # + token_type_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +class CrossPooler(nn.Module): + + def __init__(self, config): + super(CrossPooler, self).__init__() + self.ln_pool = LayerNorm(config.hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = QuickGELU() + + def forward(self, hidden_states, hidden_mask): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + hidden_states = self.ln_pool(hidden_states) + pooled_output = hidden_states[:, 0] + pooled_output = self.dense(pooled_output) + pooled_output = self.activation(pooled_output) + return pooled_output diff --git a/modelscope/models/multi_modal/mmr/models/tokenization_clip.py b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py new file mode 100644 index 00000000..97ee7156 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/tokenization_clip.py @@ -0,0 +1,161 @@ +# The implementation is adopted from the CLIP4Clip implementation, +# made pubicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip + +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, model_dir): + bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir) + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + self.vocab = self.encoder + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + def tokenize(self, text): + tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + tokens.extend( + bpe_token for bpe_token in self.bpe(token).split(' ')) + return tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder[bpe_token] for bpe_token in tokens] diff --git a/modelscope/models/multi_modal/mmr/models/until_module.py b/modelscope/models/multi_modal/mmr/models/until_module.py new file mode 100644 index 00000000..24e886b0 --- /dev/null +++ b/modelscope/models/multi_modal/mmr/models/until_module.py @@ -0,0 +1,120 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +import logging +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class CrossEn(nn.Module): + + def __init__(self, config=None): + super(CrossEn, self).__init__() + + def forward(self, sim_matrix): + logpt = F.log_softmax(sim_matrix, dim=-1) + logpt = torch.diag(logpt) + nce_loss = -logpt + sim_loss = nce_loss.mean() + return sim_loss + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + + @staticmethod + def forward(ctx, tensor, args): + if args.world_size == 1: + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return tensor + else: + output = [torch.empty_like(tensor) for _ in range(args.world_size)] + torch.distributed.all_gather(output, tensor) + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, dim=0) + + @staticmethod + def backward(ctx, grad_output): + return ( + grad_output[ctx.batch_size * ctx.rank:ctx.batch_size + * (ctx.rank + 1)], + None, + ) + + +class AllGather2(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 + @staticmethod + def forward(ctx, tensor, args): + if args.world_size == 1: + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return tensor + else: + output = [torch.empty_like(tensor) for _ in range(args.world_size)] + torch.distributed.all_gather(output, tensor) + ctx.rank = args.local_rank + ctx.batch_size = tensor.shape[0] + return torch.cat(output, dim=0) + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + torch.distributed.all_reduce( + grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) + * ctx.batch_size], None) diff --git a/modelscope/models/multi_modal/mplug/__init__.py b/modelscope/models/multi_modal/mplug/__init__.py new file mode 100644 index 00000000..955c87e2 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_mplug import MPlugConfig +from .modeling_mplug import CONFIG_NAME, MPlug diff --git a/modelscope/models/multi_modal/mplug/clip/__init__.py b/modelscope/models/multi_modal/mplug/clip/__init__.py new file mode 100644 index 00000000..e6007a04 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/clip/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .clip import load_from_config diff --git a/modelscope/models/multi_modal/mplug/clip/clip.py b/modelscope/models/multi_modal/mplug/clip/clip.py new file mode 100644 index 00000000..aa56e39b --- /dev/null +++ b/modelscope/models/multi_modal/mplug/clip/clip.py @@ -0,0 +1,461 @@ +# Copyright 2021 The OpenAI CLIP Authors. All rights reserved. + +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import nn + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + use_grad_ckp: bool = True): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + self.use_grad_ckp = use_grad_ckp + + def forward(self, x: torch.Tensor): + if self.use_grad_ckp: + for each_block in self.resblocks: + x = checkpoint.checkpoint(each_block, x) + return x + else: + return self.resblocks(x) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, + 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + if self.training: + dropout = 0.1 + else: + dropout = 0.0 + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=dropout, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, + layers, + output_dim, + heads, + input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d( + width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, skip_last_layer=False): + + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + if not skip_last_layer: + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x) + return ret.type(orig_type) + + +class VisualTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.heads = heads + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, + x: torch.Tensor, + skip_last_layer=False, + text_embedding=None, + text_mask=None): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + cls_emb = self.class_embedding.to(x.dtype) + x_zeros = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([cls_emb + x_zeros, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + + x = x + self.positional_embedding.to(x.dtype)[:x.size(1), :] + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + x = self.transformer(x) + + x = x.permute(1, 0, 2) # LND -> NLD + + if skip_last_layer: + x = self.ln_post(x) + # x = x @ self.proj + else: + x = x @ self.proj + return x + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, self.visual.layer2, self.visual.layer3, + self.visual.layer4 + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith('bn3.weight'): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def load_from_config(config): + return CLIP(config.clip_embed_dim, config.clip_image_resolution, + config.clip_vision_layers, config.clip_vision_width, + config.clip_vision_patch_size, config.clip_context_length, + config.clip_vocab_size, config.clip_transformer_width, + config.clip_transformer_heads, config.clip_transformer_layers) diff --git a/modelscope/models/multi_modal/mplug/configuration_mplug.py b/modelscope/models/multi_modal/mplug/configuration_mplug.py new file mode 100644 index 00000000..914678c5 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/configuration_mplug.py @@ -0,0 +1,116 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MPLUG model configuration """ +import os +from typing import Any, Dict, Union + +import yaml +from transformers import PretrainedConfig +from transformers.utils import logging + +from modelscope.utils.constant import Tasks + +logger = logging.get_logger(__name__) + + +class MPlugConfig(PretrainedConfig): + + model_type = 'mplug' + + def __init__( + self, + task=Tasks.visual_question_answering, + bert_config='config_bert.json', + image_res=504, + batch_size_train=128, + vision_width=1024, + distill=True, + clip_name='ViT-L-14', # ViT-B-16 | ViT-L-14 + batch_size_test=64, + k_test=128, + alpha=0.4, + warm_up=True, + eos='[SEP]', + optimizer=None, + schedular=None, + min_length=1, + max_length=10, + beam_size=5, + add_ocr=False, + add_object=False, + text_encoder='bert-base-uncased', + text_decoder='bert-base-uncased', + # clip + clip_embed_dim=768, + clip_image_resolution=224, + clip_vision_layers=24, + clip_vision_width=1024, + clip_vision_patch_size=14, + clip_context_length=77, + clip_vocab_size=49408, + clip_transformer_width=768, + clip_transformer_heads=12, + clip_transformer_layers=12, + # retrieval + queue_size=65536, + embed_dim=256, + temp=0.07, + **kwargs): + + super().__init__(**kwargs) + self.task = task + self.bert_config = bert_config + self.image_res = image_res + self.batch_size_train = batch_size_train + self.vision_width = vision_width + self.distill = distill + self.clip_name = clip_name + self.batch_size_test = batch_size_test + self.k_test = k_test + self.alpha = alpha + self.warm_up = warm_up + self.eos = eos + self.optimizer = optimizer + self.schedular = schedular + self.min_length = min_length + self.max_length = max_length + self.beam_size = beam_size + self.add_ocr = add_ocr + self.add_object = add_object + self.text_encoder = text_encoder + self.text_decoder = text_decoder + # clip + self.clip_embed_dim = clip_embed_dim + self.clip_image_resolution = clip_image_resolution + self.clip_vision_layers = clip_vision_layers + self.clip_vision_width = clip_vision_width + self.clip_vision_patch_size = clip_vision_patch_size + self.clip_context_length = clip_context_length + self.clip_vocab_size = clip_vocab_size + self.clip_transformer_width = clip_transformer_width + self.clip_transformer_heads = clip_transformer_heads + self.clip_transformer_layers = clip_transformer_layers + # retrieval + self.queue_size = queue_size + self.embed_dim = embed_dim + self.temp = temp + + @classmethod + def from_yaml_file(cls, yaml_file: Union[str, + os.PathLike]) -> Dict[str, Any]: + with open(yaml_file, 'r') as reader: + config_dict = yaml.load(reader, Loader=yaml.Loader) + return cls(**config_dict) diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py new file mode 100755 index 00000000..ec491f1d --- /dev/null +++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py @@ -0,0 +1,2483 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MPLUG model. """ + +import math +import os +from typing import Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from torch import Tensor, device, nn +from torch.nn import CrossEntropyLoss +from transformers import BertConfig, BertTokenizer +from transformers.activations import ACT2FN +from transformers.file_utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import logging + +from modelscope.models.multi_modal.mplug.configuration_mplug import MPlugConfig +from modelscope.models.multi_modal.mplug.predictor import TextGenerator +from modelscope.utils.constant import ModelFile + +transformers.logging.set_verbosity_error() + +logger = logging.get_logger(__name__) + +CONFIG_NAME = 'config.yaml' + +_CONFIG_FOR_DOC = 'BertConfig' +_TOKENIZER_FOR_DOC = 'BertTokenizer' + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info('Converting TensorFlow checkpoint from {}'.format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info('Loading TF weight {} with shape {}'.format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', 'global_step' + ] for n in name): + logger.info('Skipping {}'.format('/'.join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info('Skipping {}'.format('/'.join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info('Initialize PyTorch weight {}'.format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def clamp_inf(tensor): + if tensor.dtype == torch.float16 and torch.isinf(tensor).any(): + clamp_value = torch.finfo(tensor.dtype).max - 1000 + tensor = torch.clamp(tensor, min=-clamp_value, max=clamp_value) + return tensor + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = clamp_inf(attention_scores) + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = clamp_inf(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = clamp_inf(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FusionLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.stride_layer = getattr(self.config, 'stride_layer', 100) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.crossattention = BertAttention(config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + layer_nums=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + if layer_nums == 0 or layer_nums % self.stride_layer != 0: + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + assert encoder_hidden_states is not None, 'encoder_hidden_states must be given for cross-attention layers' + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + elif layer_nums != 0 and layer_nums % self.stride_layer == 0: + self_attention_outputs = self.attention( + torch.cat([encoder_hidden_states, hidden_states], 1), + torch.cat([encoder_attention_mask, attention_mask], 3), + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value[0], present_key_value[1]) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = getattr(self.config, 'add_cross_attention', + False) + if self.has_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, 'encoder_hidden_states must be given for cross-attention layers' + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[(self.layer_num + - self.config.fusion_layer) + % len(encoder_hidden_states)], + encoder_attention_mask[(self.layer_num + - self.config.fusion_layer) + % len(encoder_hidden_states)], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1: + -1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value[0], present_key_value[1]) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class FusionEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [FusionLayer(config, i) for i in range(config.num_hidden_layers)]) + self.start_layer = max(0, + config.num_hidden_layers - config.fusion_layers) + + def forward(self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + self.stride_layer = getattr(self.config, 'stride_layer', 100) + image_length = encoder_hidden_states.shape[1] + text_length = hidden_states.shape[1] + + for i in range(self.start_layer, len(self.layer)): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple( + module(*inputs, past_key_value, output_attentions)) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + i - self.start_layer, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + i - self.start_layer, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if hidden_states.shape[1] == (image_length + text_length): + encoder_hidden_states_new, hidden_states = torch.split( + hidden_states, (image_length, text_length), 1) + encoder_hidden_states += encoder_hidden_states_new + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + return [encoder_hidden_states, hidden_states] + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward(self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(len(self.layer)): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + if use_cache: + logger.warn( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple( + module(*inputs, past_key_value, output_attentions)) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, + attention_mask, layer_head_mask, encoder_hidden_states, + encoder_attention_mask) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Bert Model transformer outputting raw hidden-states without any specific head on top.', + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint='bert-base-uncased', + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, + None, :, :] * attention_mask[:, + None, + None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class FusionModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.encoder = FusionEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + processor_class=_TOKENIZER_FOR_DOC, + checkpoint='bert-base-uncased', + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat( + batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, + None, :, :] * attention_mask[:, + None, + None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + encoder_hidden_states, sequence_output = encoder_outputs + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return [encoder_hidden_states, sequence_output] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, + BERT_START_DOCSTRING) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @replace_return_docstrings( + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values + (:obj:`tuple(tuple(torch.FloatTensor))` of length + :obj:`config.n_layers` with each tuple having 4 tensors of shape + :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, + dim=-1) + loss_distill = (loss_distill * (labels != -100)).sum(1) + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +class BertPrefixModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint='bert-base-uncased', + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + soft_labels=None, + alpha=0, + return_logits=False, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, + dim=-1) + loss_distill = loss_distill[labels != -100].mean() + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class MPlug(PreTrainedModel): + config_class = MPlugConfig + + def __init__(self, config): + super().__init__(config) + self.config = config + self.tokenizer = BertTokenizer.from_pretrained( + os.path.join(config.model_dir, ModelFile.VOCAB_FILE)) + self.module_setting(config) + self.visual_encoder = self._initialize_clip(config) + self.text_encoder = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder = FusionModel( + self.config_fusion, add_pooling_layer=False) + + @classmethod + def from_pretrained(cls, model_dir, load_checkpoint=True): + from modelscope.utils.constant import Tasks + + task_mapping = { + Tasks.visual_question_answering: MPlugForVisualQuestionAnswering, + Tasks.image_captioning: MPlugForImageCaption, + Tasks.image_text_retrieval: MPlugForImageTextRetrieval, + } + config = cls.config_class.from_yaml_file( + os.path.join(model_dir, CONFIG_NAME)) + config.model_dir = model_dir + model = task_mapping[config.task](config) + if load_checkpoint: + checkpoint_path = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if 'model' in checkpoint: + checkpoint = checkpoint['model'] + if 'module' in checkpoint: + checkpoint = checkpoint['module'] + checkpoint = { + k.replace('model.', ''): v + for k, v in checkpoint.items() + } + + msg = model.load_state_dict(checkpoint, strict=False) + print('load checkpoint from %s' % checkpoint_path) + print(msg) + return model + + @staticmethod + def _initialize_clip(config, num_patches=240): + + def resize_pos_embed(posemb, posemb_new): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + # _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, + -1).permute(0, 3, 1, 2) + orig = posemb_grid.dtype + posemb_grid = F.interpolate( + posemb_grid.float(), size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.to(orig) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape( + 1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + from .clip import clip + clip_model = clip.load_from_config(config) + if 'ViT-B-16' in config.clip_name: + num_patches = int(config.image_res * config.image_res / (16 * 16)) + pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float()) + else: + num_patches = int(config.image_res * config.image_res / (14 * 14)) + pos_embed = nn.Parameter( + torch.zeros(num_patches + 1, 1024).float()) + pos_embed.weight = resize_pos_embed( + clip_model.visual.positional_embedding.unsqueeze(0), + pos_embed.unsqueeze(0)) + clip_model.visual.positional_embedding = pos_embed + return clip_model + + def init_distill(self, config): + self.distill = config.distill + if self.distill: + self.visual_encoder_m = self._initialize_clip(config) + self.text_encoder_m = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder_m = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.text_decoder_m = BertLMHeadModel(self.config_decoder) + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_decoder, self.text_decoder_m], + ] + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc_m = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm_m = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout_m = nn.Dropout( + self.config_encoder.hidden_dropout_prob) + self.model_pairs.extend( + [[self.visn_fc, self.visn_fc_m], + [self.visn_layer_norm, self.visn_layer_norm_m]]) + self.copy_params() + self.momentum = 0.995 + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def module_setting(self, config): + bert_config_path = os.path.join(config.model_dir, config.bert_config) + self.config_encoder = BertConfig.from_json_file(bert_config_path) + self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers + self.config_fusion = BertConfig.from_json_file(bert_config_path) + self.config_decoder = BertConfig.from_json_file(bert_config_path) + self.config_decoder.add_cross_attention = True + self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers + self.large = False + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob) + self.large = True + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), + model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * ( + 1. - self.momentum) + + def generation(self, question_states, question_atts, out_size=1): + encoder_inputs = [question_states, question_atts] + topk_ids, topk_scores = self.beam_generator.translate_batch( + encoder_inputs, out_size=out_size) + return topk_ids, topk_scores + + @staticmethod + def _tile(x, dim, n_tile): + import numpy as np + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate( + [init_dim * np.arange(n_tile) + i for i in range(init_dim)])) + return torch.index_select(x, dim, order_index.to(x.device)) + + +class MPlugForVisualQuestionAnswering(MPlug): + + def __init__(self, config): + super().__init__(config) + self.text_decoder = BertLMHeadModel(self.config_decoder) + self.beam_generator = TextGenerator(config, self.text_decoder) + self.init_distill(config) + + def forward(self, + image, + question, + answer=None, + alpha=0, + k=None, + weights=None, + train=True): + image = image.to(dtype=next(self.parameters()).dtype) + image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if train: + ''' + k: number of answers for each question + weights: weight for each answer + ''' + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False) + + image_output, question_output = fusion_output + + question_output = torch.cat([image_output, question_output], 1) + merge_text_attention = torch.cat( + [image_atts, question.attention_mask], 1) + + if k is None: + k = [1] * question_output.shape[0] + question_states = [] + question_atts = [] + for b, n in enumerate(k): + question_states += [question_output[b]] * n + question_atts += [merge_text_attention[b]] * n + question_states = torch.stack(question_states, 0) + question_atts = torch.stack(question_atts, 0) + + if self.distill: + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m.visual( + image, skip_last_layer=True) + if self.large: + image_embeds_m = self.dropout_m( + self.visn_layer_norm_m( + self.visn_fc_m(image_embeds_m))) + text_output_m = self.text_encoder_m( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds_m = text_output_m.last_hidden_state + fusion_output_m = self.fusion_encoder_m( + encoder_embeds=text_embeds_m, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds_m, + encoder_attention_mask=image_atts, + return_dict=False) + + image_output_m, question_output_m = fusion_output_m + question_output_m = torch.cat( + [image_output_m, question_output_m], 1) + + question_states_m = [] + for b, n in enumerate(k): + question_states_m += [question_output_m[b]] * n + question_states_m = torch.stack(question_states_m, 0) + + logits_m = self.text_decoder_m( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states_m, + encoder_attention_mask=question_atts, + return_logits=True, + ) + + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + soft_labels=F.softmax(logits_m, dim=-1), + reduction='none', + ) + else: + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + reduction='none', + ) + if weights is None: + weights = 1 + loss = weights * answer_output.loss + loss = loss.sum() / image.size(0) + + return loss + + else: + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False) + image_output, question_output = fusion_output + question_output = torch.cat([image_output, question_output], 1) + merge_text_attention = torch.cat( + [image_atts, question.attention_mask], 1) + topk_ids, topk_probs = self.generation(question_output, + merge_text_attention) + return topk_ids, topk_probs + + +class MPlugForImageCaption(MPlug): + + def __init__(self, config): + super().__init__(config) + self.text_decoder = BertPrefixModel(self.config_decoder) + self.beam_generator = TextGenerator(config, self.text_decoder) + + def beam_search(self, + image, + question, + answer=None, + train=True, + out_size=5): + image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False) + image_output, question_output = fusion_output + question_output = torch.cat([image_output, question_output], 1) + merge_text_attention = torch.cat([image_atts, question.attention_mask], + 1) + topk_ids, topk_probs = self.generation( + question_output, merge_text_attention, out_size=out_size) + return topk_ids, topk_probs + + def forward(self, + image, + question, + answer=None, + train=True, + out_size=5, + scst=False): + if (scst): + return self.beam_search( + image, question, answer, train=True, out_size=out_size) + image = image.to(dtype=next(self.parameters()).dtype) + image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if train: + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=answer_targets, + return_dict=True, + reduction='none') + loss = answer_output.loss + + return loss + else: + topk_ids, topk_probs = self.generation(image_embeds, image_atts) + return topk_ids, topk_probs + + +class MPlugForImageTextRetrieval(MPlug): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.embed_dim + self.temp = nn.Parameter(torch.ones([]) * config.temp) + self.queue_size = config.queue_size + self.momentum = config.momentum + self.alpha = config.alpha + + self.queue_size = config.queue_size + self.text_width = self.config_encoder.hidden_size + self.embed_dim = config.embed_dim + + self.vision_proj = nn.Linear(self.text_width, self.embed_dim) + self.text_proj = nn.Linear(self.text_width, self.embed_dim) + self.itm_head = nn.Linear(self.text_width, 2) + + self.register_buffer('image_queue', + torch.randn(self.embed_dim, self.queue_size)) + self.register_buffer('text_queue', + torch.randn(self.embed_dim, self.queue_size)) + self.register_buffer('idx_queue', torch.full((1, self.queue_size), + -100)) + self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) + + self.image_queue = F.normalize(self.image_queue, dim=0) + self.text_queue = F.normalize(self.text_queue, dim=0) + self.init_distill(config) + + def init_distill(self, config): + self.distill = config.distill + if self.distill: + self.visual_encoder_m = self._initialize_clip(config) + self.text_encoder_m = BertModel( + self.config_encoder, add_pooling_layer=False) + self.fusion_encoder_m = FusionModel( + self.config_fusion, add_pooling_layer=False) + self.vision_proj_m = nn.Linear(self.text_width, self.embed_dim) + self.text_proj_m = nn.Linear(self.text_width, self.embed_dim) + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_proj, self.text_proj_m], + [self.vision_proj, self.vision_proj_m], + ] + if self.config_encoder.hidden_size != config.vision_width: + self.visn_fc_m = nn.Linear(config.vision_width, + self.config_encoder.hidden_size) + self.visn_layer_norm_m = nn.LayerNorm( + self.config_encoder.hidden_size, eps=1e-12) + self.dropout_m = nn.Dropout( + self.config_encoder.hidden_dropout_prob) + self.model_pairs.extend( + [[self.visn_fc, self.visn_fc_m], + [self.visn_layer_norm, self.visn_layer_norm_m]]) + self.copy_params() + self.momentum = 0.995 + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idx): + + def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + if not torch.distributed.is_initialized(): + return tensor + tensors_gather = [ + torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather( + tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + idxs = concat_all_gather(idx) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + # assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.queue_ptr[0] = ptr + + def forward(self, image, text, idx=None, train=True): + if train: + image_embeds = self.visual_encoder.visual( + image, skip_last_layer=True) + if self.large: + image_embeds = self.dropout( + self.visn_layer_norm(self.visn_fc(image_embeds))) + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + image_feat = F.normalize( + self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize( + self.text_proj(text_embeds[:, 0, :]), dim=-1) + + idx = idx.view(-1, 1) + idx_all = torch.cat( + [idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m.visual( + image, skip_last_layer=True) + if self.large: + image_embeds_m = self.dropout_m( + self.visn_layer_norm_m(self.visn_fc_m(image_embeds_m))) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1) + image_feat_all = torch.cat( + [image_feat_m.t(), + self.image_queue.clone().detach()], + dim=1) + text_output_m = self.text_encoder_m( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True) + text_feat_m = F.normalize( + self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), + dim=-1) + text_feat_all = torch.cat( + [text_feat_m.t(), + self.text_queue.clone().detach()], dim=1) + + if self.distill: + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_i2t_targets = self.alpha * F.softmax( + sim_i2t_m, dim=1) + (1 - self.alpha) * sim_targets + sim_t2i_targets = self.alpha * F.softmax( + sim_t2i_m, dim=1) + (1 - self.alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + if self.distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, + dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, + dim=1).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean() + + loss_ita = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) + + # forward the positve image-text pair + _, output_pos = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=False, + ) + with torch.no_grad(): + bs = image.size(0) + weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) + weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) + + mask = torch.eq(idx, idx.T) + weights_i2t.masked_fill_(mask, 0) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], + dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], + dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + _, output_neg = self.fusion_encoder( + encoder_embeds=text_embeds_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=False, + ) + + vl_embeddings = torch.cat( + [output_pos[:, 0, :], output_neg[:, 0, :]], dim=0) + vl_output = self.itm_head(vl_embeddings) + + ones_tmp = torch.ones(bs, dtype=torch.long) + zeros_tmp = torch.zeros(2 * bs, dtype=torch.long) + itm_labels = torch.cat([ones_tmp, zeros_tmp], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + return loss_ita + loss_itm + else: + text_output = self.text_encoder( + text.input_ids, attention_mask=text.attention_mask) + text_feat = text_output.last_hidden_state + image_feat = self.visual_encoder.visual( + image, skip_last_layer=True) + image_feat = self.visn_layer_norm(self.visn_fc(image_feat)) + image_att = torch.ones( + image_feat.size()[:-1], + dtype=torch.long, + device=image_feat.device) + _, output = self.fusion_encoder( + encoder_embeds=text_feat, + attention_mask=text.attention_mask, + encoder_hidden_states=image_feat, + encoder_attention_mask=image_att, + return_dict=False, + ) + scores = self.itm_head(output[:, 0, :]) + scores = F.softmax(scores, dim=-1) + + return scores diff --git a/modelscope/models/multi_modal/mplug/predictor.py b/modelscope/models/multi_modal/mplug/predictor.py new file mode 100755 index 00000000..6375d1d7 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/predictor.py @@ -0,0 +1,551 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import torch +import torch.nn.functional as F + + +def build_predictor(args, tokenizer, symbols, model, logger=None): + scorer = None + + translator = TextGenerator( + args, model, tokenizer, symbols, global_scorer=scorer, logger=logger) + return translator + + +class TextGenerator(object): + """ + Uses a model to translate a batch of sentences. + + + Args: + model (:obj:`onmt.modules.NMTModel`): + NMT model to use for translation + fields (dict of Fields): data fields + beam_size (int): size of beam to use + n_best (int): number of translations produced + max_length (int): maximum length output to produce + global_scores (:obj:`GlobalScorer`): + object to rescore final translations + copy_attn (bool): use copy attention during translation + cuda (bool): use cuda + beam_trace (bool): trace beam search for debugging + logger(logging.Logger): logger. + """ + + def __init__(self, + args, + model, + vocab=None, + symbols=None, + global_scorer=None, + logger=None, + dump_beam=''): + self.alpha = 0.6 + + self.logger = logger + self.cuda = (torch.cuda.device_count() > 0) + + self.args = args + self.model = model + + self.vocab = vocab + self.symbols = symbols + self.start_token = 101 # ['[PAD]'] + self.end_token = 102 # ['[PAD]'] + + self.global_scorer = global_scorer + self.beam_size = args.beam_size + self.min_length = args.min_length + self.max_length = args.max_length + + self.dump_beam = dump_beam + + # for debugging + self.beam_trace = self.dump_beam != '' + self.beam_accum = None + + if self.beam_trace: + self.beam_accum = { + 'predicted_ids': [], + 'beam_parent_ids': [], + 'scores': [], + 'log_probs': [] + } + + def _build_target_tokens(self, pred): + tokens = [] + for tok in pred: + tok = int(tok) + tokens.append(tok) + if tokens[-1] == self.end_token: + tokens = tokens[:-1] + break + tokens = [t for t in tokens if t < len(self.vocab)] + tokens = self.vocab.DecodeIds(tokens).split(' ') + return tokens + + def translate_batch(self, encoder_inputs, do_sample=False, out_size=1): + """ + Translate a batch of sentences. + + Mostly a wrapper around :obj:`Beam`. + + Args: + batch (:obj:`Batch`): a batch from a dataset object + data (:obj:`Dataset`): the dataset object + fast (bool): enables fast beam search (may not support all features) + + Todo: + Shouldn't need the original dataset. + """ + if do_sample: + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + else: + with torch.no_grad(): + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + + def translate_batch_scst(self, + encoder_inputs, + do_sample=False, + out_size=1): + return self._fast_translate_batch( + encoder_inputs, + self.max_length, + min_length=self.min_length, + do_sample=do_sample, + out_size=out_size) + + def _fast_translate_batch(self, + encoder_inputs, + max_length, + min_length=0, + do_sample=False, + out_size=1): + + assert not self.dump_beam + if do_sample: + beam_size = 1 + else: + beam_size = self.beam_size + if len(encoder_inputs) == 3: + src_features, padding_mask, input_ids = encoder_inputs + elif len(encoder_inputs) == 2: + src_features, padding_mask = encoder_inputs + input_ids = None + + device = src_features.device + + # Tile states and memory beam_size times. + batch_size = src_features.size(0) + src_features = tile(src_features, beam_size, dim=0) + attention_mask = tile(padding_mask, beam_size, dim=0) + + batch_offset = torch.arange( + batch_size, dtype=torch.long, device=device) + beam_offset = torch.arange( + 0, + batch_size * beam_size, + step=beam_size, + dtype=torch.long, + device=device) + if input_ids is not None: + alive_seq = tile(input_ids, beam_size, dim=0) + else: + alive_seq = torch.full([batch_size * beam_size, 1], + self.start_token, + dtype=torch.long, + device=device) + + # Give full probability to the first beam on the first step. + topk_log_probs = ( + torch.tensor( + [0.0] + [float('-inf')] * (beam_size - 1), + device=device).repeat(batch_size)) + + # Structure that holds finished hypotheses. + hypotheses = [[] for _ in range(batch_size)] # noqa: F812 + + results = {} + results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812 + results['scores'] = [[] for _ in range(batch_size)] # noqa: F812 + results['gold_score'] = [0] * batch_size + results['batch'] = [] + + for step in range(max_length): + dec_feat_seq = self.model( + alive_seq, + encoder_hidden_states=src_features, + encoder_attention_mask=attention_mask, + return_dict=True, + reduction='none') + + dec_feat_seq = dec_feat_seq.logits[:, -1, :] + vocab_size = dec_feat_seq.size(-1) + log_probs = torch.log( + torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1)) + if step < min_length: + log_probs[:, self.end_token] = -1e20 + alpha = self.alpha + if do_sample: + length_penalty = 1.0 + else: + length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha + + if do_sample: + _scores = log_probs / self.args.temperature + _scores = top_k_top_p_filtering( + _scores, + top_k=self.args.top_k, + top_p=self.args.top_p, + min_tokens_to_keep=1 + ) # (batch_size * num_beams, vocab_size) + # Sample 2 next words for each beam + # (so we have some spare tokens and match output of greedy beam search) + topk_ids = torch.multinomial( + F.softmax(_scores, dim=-1), + num_samples=1) # (batch_size * num_beams, 2) + # Compute next scores + _scores = F.log_softmax( + _scores, dim=1) # (batch_size * num_beams, vocab_size) + + _scores += topk_log_probs.view(-1).unsqueeze(1) + topk_scores = torch.gather( + _scores, -1, topk_ids) # (batch_size * num_beams, 2) + # log_probs += # (batch_size * num_beams, 2) + # Match shape of greedy beam search + topk_ids = topk_ids.view( + -1, beam_size) # (batch_size, 2 * num_beams) + topk_scores = topk_scores.view( + -1, beam_size) # (batch_size, 2 * num_beams) + else: + log_probs += topk_log_probs.view(-1).unsqueeze(1) + curr_scores = log_probs / length_penalty + + curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) + topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) + topk_log_probs = topk_scores * length_penalty + + # Resolve beam origin and true word ids. + # topk_beam_index = topk_ids.div(vocab_size) + topk_beam_index = torch.div( + topk_ids, vocab_size, rounding_mode='floor') + topk_ids = topk_ids.fmod(vocab_size) + + # Map beam_index to batch_index in the flat representation. + batch_index = ( + topk_beam_index + + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) + select_indices = batch_index.view(-1) + + # Append last prediction. + alive_seq = torch.cat([ + alive_seq.index_select(0, select_indices), + topk_ids.view(-1, 1) + ], -1) + + is_finished = topk_ids.eq(self.end_token) + if step + 1 == max_length: + is_finished.fill_(1) # self.end_token) + # End condition is top beam is finished. + end_condition = is_finished[:, 0].eq(1) # self.end_token) + # Save finished hypotheses. + if is_finished.any(): + predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) + for i in range(is_finished.size(0)): + b = batch_offset[i] + if end_condition[i]: + is_finished[i].fill_(1) # self.end_token) + finished_hyp = is_finished[i].nonzero().view(-1) + # Store finished hypotheses for this batch. + for j in finished_hyp: + hypotheses[b].append( + (topk_scores[i, j], predictions[i, j, 0:])) + # If the batch reached the end, save the n_best hypotheses. + if end_condition[i]: + best_hyp = sorted( + hypotheses[b], key=lambda x: x[0], reverse=True) + + for each in best_hyp[:beam_size]: + score, pred = each + results['scores'][b].append(score) + results['predictions'][b].append(pred) + non_finished = end_condition.eq(0).nonzero().view(-1) + # If all sentences are translated, no need to go further. + if len(non_finished) == 0: + break + # Remove finished batches for the next step. + topk_log_probs = topk_log_probs.index_select(0, non_finished) + batch_index = batch_index.index_select(0, non_finished) + batch_offset = batch_offset.index_select(0, non_finished) + alive_seq = predictions.index_select(0, non_finished) \ + .view(-1, alive_seq.size(-1)) + # Reorder states. + select_indices = batch_index.view(-1) + src_features = src_features.index_select(0, select_indices) + attention_mask = attention_mask.index_select(0, select_indices) + pred_ids = [] + scores = [] + # print (pred_ids, scores) + for each in results['scores']: + scores.append(each[:out_size]) + for each in results['predictions']: + pred_ids.append(each[:out_size]) + return pred_ids, scores + + def _generate_no_beam_search( + self, + input_ids, + cur_len, + max_length, + do_sample, + temperature, + top_k, + top_p, + repetition_penalty, + pad_token_id, + eos_token_ids, + batch_size, + ): + """ Generate sequences for each example without beam search (num_beams == 1). + All returned sequence are generated independantly. + """ + assert self.num_keep_best == 1, 'cannot generate >1 sentences in greedy search' + # current position / max lengths / length of generated sentences / unfinished sentences + unfinished_sents = [] + cur_unfinished = input_ids.new(batch_size).fill_(1) + + # log of scores for each sentence in the batch + logprobs = [] + + past = None + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation( + input_ids, past=past) + outputs = self(**model_inputs) + if cur_len == 1: + token_len = 2 + self.od_labels_len + next_token_idx = 1 + else: + assert cur_len > 1 + if not self._do_output_past(outputs): + token_len = cur_len + 1 + self.od_labels_len + next_token_idx = cur_len + else: + token_len = 2 + next_token_idx = 1 + assert outputs[0].shape[1] == token_len + + next_token_logits = outputs[0][:, next_token_idx, :] + + # if model has past, then set the past variable to speed up decoding + if self._do_output_past(outputs): + past = outputs[1] + + # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + if repetition_penalty != 1.0: + for i in range(batch_size): + for previous_token in set(input_ids[i].tolist()): + # if score < 0 then repetition penalty has to multiplied + # to reduce the previous token probability + if next_token_logits[i, previous_token] < 0: + next_token_logits[ + i, previous_token] *= repetition_penalty + else: + next_token_logits[ + i, previous_token] /= repetition_penalty + + if do_sample: + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + # Top-p/top-k filtering + next_token_logits = top_k_top_p_filtering( + next_token_logits, top_k=top_k, top_p=top_p) + # Sample + next_token = torch.multinomial( + F.softmax(next_token_logits, dim=-1), + num_samples=1).squeeze(1) + else: + # Greedy decoding + next_token = torch.argmax(next_token_logits, dim=-1) + + # Compute scores + _scores = F.log_softmax( + next_token_logits, dim=-1) # (batch_size, vocab_size) + _scores = torch.gather(_scores, -1, + next_token.unsqueeze(-1)) # (batch_size, 1) + logprobs.append(_scores) # (batch_size, 1) + unfinished_sents.append(cur_unfinished) + + # update generations and finished sentences + tokens_to_add = next_token * cur_unfinished + pad_token_id * ( + 1 - cur_unfinished) + input_ids = torch.cat( + [input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + + for eos_token_id in eos_token_ids: + cur_unfinished = cur_unfinished.mul( + tokens_to_add.ne(eos_token_id).long()) + cur_len = cur_len + 1 + + # stop when there is a in each sentence, or if we exceed the maximul length + if cur_unfinished.max() == 0: + break + + # add eos_token_ids to unfinished sentences + if cur_len == max_length: + input_ids[:, -1].masked_fill_( + cur_unfinished.to(dtype=torch.bool), eos_token_ids[0]) + + logprobs = torch.cat(logprobs, dim=1) + unfinished_sents = torch.stack(unfinished_sents, dim=1).float() + sum_logprobs = (logprobs * unfinished_sents).sum(dim=1) + # return logprobs to keep consistent with beam search output + logprobs = sum_logprobs / unfinished_sents.sum(dim=1) + + # pad to the same length, otherwise DataParallel will give error + pad_len = max_length - input_ids.shape[1] + if pad_len > 0: + padding_ids = input_ids.new(batch_size, + pad_len).fill_(pad_token_id) + input_ids = torch.cat([input_ids, padding_ids], dim=1) + + # (batch_size, n_best, max_len), (batch_size, n_best) + return input_ids.unsqueeze(1), logprobs.unsqueeze(1) + + +def top_k_top_p_filtering(logits, + top_k=10, + top_p=1.0, + filter_value=-float('Inf'), + min_tokens_to_keep=1): + + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), + logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +class Translation(object): + """ + Container for a translated sentence. + + Attributes: + src (`LongTensor`): src word ids + src_raw ([str]): raw src words + + pred_sents ([[str]]): words from the n-best translations + pred_scores ([[float]]): log-probs of n-best translations + attns ([`FloatTensor`]) : attention dist for each translation + gold_sent ([str]): words from gold translation + gold_score ([float]): log-prob of gold translation + + """ + + def __init__(self, fname, src, src_raw, pred_sents, attn, pred_scores, + tgt_sent, gold_score): + self.fname = fname + self.src = src + self.src_raw = src_raw + self.pred_sents = pred_sents + self.attns = attn + self.pred_scores = pred_scores + self.gold_sent = tgt_sent + self.gold_score = gold_score + + def log(self, sent_number): + """ + Log translation. + """ + + output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) + + best_pred = self.pred_sents[0] + best_score = self.pred_scores[0] + pred_sent = ' '.join(best_pred) + output += 'PRED {}: {}\n'.format(sent_number, pred_sent) + output += 'PRED SCORE: {:.4f}\n'.format(best_score) + + if self.gold_sent is not None: + tgt_sent = ' '.join(self.gold_sent) + output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) + output += ('GOLD SCORE: {:.4f}\n'.format(self.gold_score)) + if len(self.pred_sents) > 1: + output += '\nBEST HYP:\n' + for score, sent in zip(self.pred_scores, self.pred_sents): + output += '[{:.4f}] {}\n'.format(score, sent) + + return output + + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py new file mode 100644 index 00000000..7de8d291 --- /dev/null +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Dict, List + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['MPlugForAllTasks'] + + +@MODELS.register_module( + Tasks.visual_question_answering, module_name=Models.mplug) +@MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug) +@MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug) +class MPlugForAllTasks(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the mplug model from the `model_dir` path. + Args: + model_dir (str): the model path. + """ + + super().__init__(model_dir, *args, **kwargs) + from modelscope.models.multi_modal.mplug import MPlug + self.model = MPlug.from_pretrained(model_dir) + self.tokenizer = self.model.tokenizer + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), + } + """ + + # get task from config file + task = Config.from_file( + osp.join(self.model_dir, ModelFile.CONFIGURATION)).task + + # inference + if not self.training and 'question' in input: + output = self.model(input['image'], input['question'], train=False) + if task == Tasks.image_text_retrieval: + return {OutputKeys.SCORES: output[0].tolist()} + topk_ids, _ = output + pred_string: List[str] = \ + self.tokenizer.decode(topk_ids[0][0], skip_special_tokens=True) + output_key = OutputKeys.CAPTION \ + if task == Tasks.image_captioning else OutputKeys.TEXT + return {output_key: pred_string} + + # train and evaluate + import addict + image = input['image'] + answer = addict.Dict( + input_ids=input['answer_input_ids'], + attention_mask=input['answer_attention_mask']) + if 'index' not in input: + question = addict.Dict( + input_ids=input['question_input_ids'], + attention_mask=input['question_attention_mask']) + output = self.model(image, question, answer, train=self.training) + else: + index = input['index'] + output = self.model(image, answer, index, train=self.training) + if self.training: + return {OutputKeys.LOSS: output} + + # evaluate + topk_ids, _ = output + return {'sequences': [list_tensor[0] for list_tensor in topk_ids]} diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py new file mode 100644 index 00000000..1b3f445b --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .model import MultiStageDiffusionForTextToImageSynthesis diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/clip.py b/modelscope/models/multi_modal/multi_stage_diffusion/clip.py new file mode 100644 index 00000000..98727066 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/clip.py @@ -0,0 +1,319 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly avaialbe at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['CLIP'] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. + """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. + """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, xt.t()) + logits_t2i = scale * torch.mm(xt, xi.t()) + return logits_i2t, logits_t2i + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def fp16(self): + return self.apply(to_fp16) diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py b/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py new file mode 100644 index 00000000..eb52a48b --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py @@ -0,0 +1,322 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Decoder'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Decoder(nn.Module): + + def __init__(self, + in_dim=3, + dim=512, + y_dim=512, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(Decoder, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.context_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, context_dim * 4)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y): + # embeddings + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + context = self.context_embedding(y).view(-1, 4, self.context_dim) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py b/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py new file mode 100644 index 00000000..9677d7c4 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py @@ -0,0 +1,642 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + u1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + u2 = ((mu1 - mu2)**2) * torch.exp(-logvar2) + return 0.5 * (u1 + u2) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. + """ + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn_t1 = math.cos((t1 + 0.008) / 1.008 * math.pi / 2)**2 + fn_t2 = math.cos((t2 + 0.008) / 1.008 * math.pi / 2)**2 + betas.append(min(1.0 - fn_t2 / fn_t1, 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + u1 = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + u2 = _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + return u1 + u2 + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + shape = (-1, *((1, ) * (xt.ndim - 1))) + mask = t.ne(0).float().view(shape) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). + """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + cond = self.var_type.startswith('fixed') + dim = y_out.size(1) if cond else y_out.size(1) // 2 + u1 = u_out[:, :dim] + u2 = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) + out = torch.cat([u1 + u2, y_out[:, dim:]], dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + u1 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu + u2 = _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + x0 = u1 - u2 + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + x0 = u1 - u2 + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # derive variables + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + u1 = (1 - alphas_prev) / (1 - alphas) + u2 = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(u1 * u2) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # derive eps + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, input_x0=None): + noise = torch.randn_like(x0) if noise is None else noise + input_x0 = x0 if input_x0 is None else input_x0 + xt = self.q_sample(input_x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/model.py b/modelscope/models/multi_modal/multi_stage_diffusion/model.py new file mode 100644 index 00000000..59bd837d --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/model.py @@ -0,0 +1,265 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.multi_stage_diffusion.clip import CLIP +from modelscope.models.multi_modal.multi_stage_diffusion.decoder import Decoder +from modelscope.models.multi_modal.multi_stage_diffusion.gaussian_diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.multi_stage_diffusion.prior import Prior +from modelscope.models.multi_modal.multi_stage_diffusion.tokenizer import ( + CLIPTokenizer, XGLMTokenizer) +from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import ( + Upsampler256, Upsampler1024) +from modelscope.models.multi_modal.multi_stage_diffusion.xglm import XGLM +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['MultiStageDiffusionForTextToImageSynthesis'] + + +def make_diffusion(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None, + mean_type='eps', + var_type='fixed_small'): + betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta) + diffusion = GaussianDiffusion( + betas, mean_type=mean_type, var_type=var_type) + return diffusion + + +class UnCLIP(nn.Module): + + def __init__(self, model_dir): + super(UnCLIP, self).__init__() + self.model_dir = model_dir + self.config = json.load(open(f'{model_dir}/{ModelFile.CONFIGURATION}')) + + # modules + self.clip = CLIP(**self.config['clip']).fp16() + self.xglm = XGLM(**self.config['xglm']) + self.prior = Prior(**self.config['prior']) + self.decoder = Decoder(**self.config['decoder']) + self.upsampler256 = Upsampler256(**self.config['upsampler256']) + self.upsampler1024 = Upsampler1024(**self.config['upsampler1024']) + + # diffusions + self.prior_diffusion = make_diffusion(**self.config['prior_diffusion']) + self.decoder_diffusion = make_diffusion( + **self.config['decoder_diffusion']) + self.upsampler256_diffusion = make_diffusion( + **self.config['upsampler256_diffusion']) + self.upsampler1024_diffusion = make_diffusion( + **self.config['upsampler1024_diffusion']) + + # tokenizers + self.clip_tokenizer = CLIPTokenizer( + bpe_path=f'{model_dir}/bpe_simple_vocab_16e6.txt.gz') + self.xglm_tokenizer = XGLMTokenizer(model_dir=model_dir) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + '"forward" is not implemented. Use "synthesis" instead.') + + @torch.no_grad() + def synthesis(self, + text='A photo of a confused grizzly bear in calculus class.', + tokenizer='clip', + batch_size=4, + timesteps_prior=100, + timesteps_64=50, + timesteps_256=20, + timesteps_1024=20, + guide_prior=3.0, + guide_64=7.0, + guide_256=3.0, + guide_1024=3.0, + eta_prior=0.0, + eta_64=0.0, + eta_256=0.0, + eta_1024=0.0): + device = next(self.parameters()).device + + # check params + assert all([ + t > 0 and t <= 1000 for t in + [timesteps_prior, timesteps_64, timesteps_256, timesteps_1024] + ]) + assert all([ + g > 1 and g < 15 + for g in [guide_prior, guide_64, guide_256, guide_1024] + ]) + assert all([ + e >= 0 and e <= 1.0 + for e in [eta_prior, eta_64, eta_256, eta_1024] + ]) + assert batch_size >= 1 and batch_size <= 16 + + # tokenize the text + if tokenizer == 'clip': + y = F.normalize( + self.clip.textual(self.clip_tokenizer([text]).to(device)), + p=2, + dim=1) + zero_y = F.normalize( + self.clip.textual(self.clip_tokenizer(['']).to(device)), + p=2, + dim=1) + elif tokenizer == 'xglm': + y = F.normalize( + self.xglm(*to_device(self.xglm_tokenizer([text]), device)), + p=2, + dim=1) + zero_y = F.normalize( + self.xglm(*to_device(self.xglm_tokenizer(['']), device)), + p=2, + dim=1) + else: + raise ValueError( + f'Expected tokenizer to be one of "clip" or "xglm", but got {tokenizer}' + ) + y = math.sqrt(y.size(1)) * y.repeat(batch_size, 1) + zero_y = math.sqrt(zero_y.size(1)) * zero_y.repeat(batch_size, 1) + + # synthesis + with amp.autocast(enabled=True): + # prior + x0 = self.prior_diffusion.ddim_sample_loop( + noise=torch.randn_like(y), + model=self.prior, + model_kwargs=[{ + 'y': y + }, { + 'y': zero_y + }], + guide_scale=guide_prior, + ddim_timesteps=timesteps_prior, + eta=eta_prior) + + # decoder + imgs64 = self.decoder_diffusion.ddim_sample_loop( + noise=torch.randn(batch_size, 3, 64, 64).to(device), + model=self.decoder, + model_kwargs=[{ + 'y': x0 + }, { + 'y': torch.zeros_like(x0) + }], + guide_scale=guide_64, + percentile=0.995, + ddim_timesteps=timesteps_64, + eta=eta_64).clamp_(-1, 1) + + # upsampler256 + imgs256 = F.interpolate( + imgs64, scale_factor=4.0, mode='bilinear', align_corners=False) + imgs256 = self.upsampler256_diffusion.ddim_sample_loop( + noise=torch.randn_like(imgs256), + model=self.upsampler256, + model_kwargs=[{ + 'y': y, + 'concat': imgs256 + }, { + 'y': zero_y, + 'concat': imgs256 + }], + guide_scale=guide_256, + percentile=0.995, + ddim_timesteps=timesteps_256, + eta=eta_256).clamp_(-1, 1) + + # upsampler1024 + imgs1024 = F.interpolate( + imgs256, + scale_factor=4.0, + mode='bilinear', + align_corners=False) + imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( + noise=torch.randn_like(imgs1024), + model=self.upsampler1024, + model_kwargs=[{ + 'y': y, + 'concat': imgs1024 + }, { + 'y': zero_y, + 'concat': imgs1024 + }], + guide_scale=guide_1024, + percentile=0.995, + ddim_timesteps=timesteps_1024, + eta=eta_1024).clamp_(-1, 1) + + # output ([B, C, H, W] within range [0, 1]) + imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() + imgs1024 = [ + Image.fromarray(np.array(u, dtype=np.uint8)) for u in imgs1024 + ] + return imgs1024 + + +@MODELS.register_module( + Tasks.text_to_image_synthesis, module_name=Models.multi_stage_diffusion) +class MultiStageDiffusionForTextToImageSynthesis(TorchModel): + + def __init__(self, model_dir, device_id=-1): + super().__init__(model_dir=model_dir, device_id=device_id) + model = UnCLIP(model_dir=model_dir) + pretrained_params = torch.load( + osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu') + model.load_state_dict(pretrained_params) + model.eval() + + self.device_id = device_id + if self.device_id >= 0: + self.device = torch.device(f'cuda:{self.device_id}') + model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device = torch.device('cpu') + logger.info('Use CPU for inference') + self.model = model + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(input, dict): + raise ValueError( + f'Expected the input to be a dictionary, but got {type(input)}' + ) + if 'text' not in input: + raise ValueError('input should contain "text", but not found') + + # ddim sampling + imgs = self.model.synthesis( + text=input.get('text'), + tokenizer=input.get('tokenizer', 'clip'), + batch_size=input.get('batch_size', 4), + timesteps_prior=input.get('timesteps_prior', 100), + timesteps_64=input.get('timesteps_64', 50), + timesteps_256=input.get('timesteps_256', 20), + timesteps_1024=input.get('timesteps_1024', 20), + guide_prior=input.get('guide_prior', 3.0), + guide_64=input.get('guide_64', 7.0), + guide_256=input.get('guide_256', 3.0), + guide_1024=input.get('guide_1024', 3.0), + eta_prior=input.get('eta_prior', 0.0), + eta_64=input.get('eta_64', 0.0), + eta_256=input.get('eta_256', 0.0), + eta_1024=input.get('eta_1024', 0.0)) + imgs = [np.array(u)[..., ::-1] for u in imgs] + return imgs diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/prior.py b/modelscope/models/multi_modal/multi_stage_diffusion/prior.py new file mode 100644 index 00000000..9f4ef2d5 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/prior.py @@ -0,0 +1,170 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Prior'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, mask): + b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, l, n * 3, c).chunk(3, dim=2) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + if mask is not None: + attn = attn.masked_fill(mask[:, :, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type(attn.dtype) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, l, -1) + + # output + x = self.proj(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = nn.LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads) + self.norm2 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class Prior(nn.Module): + + def __init__(self, dim=2048, clip_dim=768, num_heads=32, num_layers=24): + super(Prior, self).__init__() + self.dim = dim + self.clip_dim = clip_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.text_embedding = nn.Sequential( + nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential( + nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.vision_embedding = nn.Sequential( + nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.eos_embedding = nn.Parameter(torch.zeros(1, 1, dim)) + self.pos_embedding = nn.Parameter(torch.zeros(1, 4, dim)) + + # transformer + self.blocks = nn.ModuleList( + [AttentionBlock(dim, num_heads) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(dim) + + # head + self.head = nn.Linear(dim, clip_dim) + + # causal attention mask + self.register_buffer('attn_mask', torch.tril(torch.ones(1, 1, 4, 4))) + + # initialize weights + self.init_weights() + + def forward(self, x, t, y): + r"""x: [B, C]. + t: [B]. + y: [B, C]. + """ + b = x.size(0) + + # embeddings of shape [B, L + 4, C] + u1 = sinusoidal_embedding(t, self.dim) + u2 = [ + self.text_embedding(y).unsqueeze(1), + self.time_embedding(u1).unsqueeze(1), + self.vision_embedding(x).unsqueeze(1), + self.eos_embedding.repeat(b, 1, 1) + ] + x = self.pos_embedding + torch.cat(u2, dim=1) + + # transformer + for block in self.blocks: + x = block(x, self.attn_mask) + x = self.norm(x) + + # head + x = self.head(x[:, -1]) + return x + + def init_weights(self): + std = 0.02 / math.sqrt(2.0 * self.num_layers) + for name, m in self.named_modules(): + if name.endswith('attn.proj') or name.endswith('ffn.2'): + # smaller std for output layers + nn.init.normal_(m.weight, std=std) + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Linear, nn.Embedding)): + nn.init.normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py b/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py new file mode 100644 index 00000000..59d6b304 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py @@ -0,0 +1,200 @@ +# Part of the implementation is borrowed and modified from CLIP, publicly avaialbe at https://github.com/openai/CLIP. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import gzip +import html +from functools import lru_cache + +import ftfy +import regex as re +import torch +from transformers import AutoTokenizer + +__all__ = ['CLIPTokenizer', 'XGLMTokenizer'] + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +class CLIPTokenizer(object): + r"""CLIP tokenizer, adapted from https://github.com/openai/CLIP. + """ + + def __init__(self, bpe_path, length=77): + self.bpe_path = bpe_path + self.length = length + + # init tokenizer + self.tokenizer = SimpleTokenizer(bpe_path=bpe_path) + self.sos_token = self.tokenizer.encoder['<|startoftext|>'] + self.eos_token = self.tokenizer.encoder['<|endoftext|>'] + self.vocab_size = len(self.tokenizer.encoder) + + def __call__(self, sequence): + if isinstance(sequence, str): + return torch.LongTensor(self._tokenizer(sequence)) + elif isinstance(sequence, list): + return torch.LongTensor([self._tokenizer(u) for u in sequence]) + else: + raise TypeError( + f'Expected the "sequence" to be a string or a list, but got {type(sequence)}' + ) + + def _tokenizer(self, text): + tokens = self.tokenizer.encode(text)[:self.length - 2] + tokens = [self.sos_token] + tokens + [self.eos_token] + tokens = tokens + [0] * (self.length - len(tokens)) + return tokens + + +class XGLMTokenizer(object): + r"""A wrapper of HuggingFace's XGLM tokenizer. + """ + + def __init__(self, model_dir, length=77, **kwargs): + self.length = length + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + _kwargs = { + 'return_tensors': 'pt', + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.length + } + _kwargs.update(**kwargs) + tokens = self.tokenizer(sequence, **_kwargs) + return tokens.input_ids, tokens.attention_mask diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py b/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py new file mode 100644 index 00000000..a292edae --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py @@ -0,0 +1,466 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Upsampler256', 'Upsampler1024'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Upsampler256(nn.Module): + + def __init__(self, + in_dim=6, + dim=320, + y_dim=768, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(Upsampler256, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.context_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, context_dim * 4)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat): + # embeddings + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + context = self.context_embedding(y).view(-1, 4, self.context_dim) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context) + else: + x = module(x) + return x + + +class Upsampler1024(nn.Module): + + def __init__(self, + in_dim=6, + dim=192, + y_dim=768, + out_dim=3, + dim_mult=[1, 1, 2, 2, 4, 4], + num_res_blocks=2, + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(Upsampler1024, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embedding + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat): + # embedding + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py b/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py new file mode 100644 index 00000000..133da50b --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py @@ -0,0 +1,206 @@ +# Part of the implementation is borrowed and modified from HuggingFace XGLM, +# publicly avaialbe at https://github.com/huggingface/transformers. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XGLM'] + + +def sinusoidal_embedding(seq_len, dim, pad_token=None): + half = dim // 2 + sinusoid = torch.outer( + torch.arange(seq_len, dtype=torch.float32), + torch.pow(10000, + -torch.arange(half, dtype=torch.float32).div(half - 1))) + x = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=1) + if dim % 2 == 1: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + if pad_token is not None: + x[pad_token, :] = 0 + return x + + +class SinusoidalEmbedding(nn.Module): + + def __init__(self, seq_len, dim, pad_token): + super(SinusoidalEmbedding, self).__init__() + self.seq_len = seq_len + self.dim = dim + self.pad_token = pad_token + self.register_buffer('weight', + sinusoidal_embedding(seq_len + 2, dim, pad_token)) + + def forward(self, tokens): + mask = tokens.ne(self.pad_token).long() + indices = torch.cumsum(mask, dim=1) * mask + self.pad_token + pos_embeds = self.weight.index_select(0, indices.view(-1)).view( + *tokens.shape, -1) + return pos_embeds + + +class GELU(nn.Module): + + def forward(self, x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [B, *, L, L] or None. + """ + b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, l, n, c) + k = self.k(x).view(b, l, n, c) + v = self.v(x).view(b, l, n, c) + + # compute attention + attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k) + if mask is not None: + attn = attn.masked_fill(mask == 0, float('-inf')) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, l, -1) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, ffn_dim, ffn_act, num_heads, dropout=0.1): + assert ffn_act in ['gelu', 'relu'] + super(AttentionBlock, self).__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.ffn_act = ffn_act + self.num_heads = num_heads + + # layers + self.norm1 = nn.LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + GELU() if ffn_act == 'gelu' else nn.ReLU(inplace=True), + nn.Linear(ffn_dim, dim), nn.Dropout(dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XGLM(nn.Module): + r"""A multilingual GPT model with an embedding head. + """ + + def __init__(self, + vocab_size=256008, + max_seq_len=2048, + dim=1024, + ffn_dim=4096, + ffn_act='gelu', + embed_dim=768, + num_heads=16, + num_layers=24, + pad_token=1, + dropout=0.1): + super(XGLM, self).__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.dim = dim + self.ffn_dim = ffn_dim + self.ffn_act = ffn_act + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pad_token = pad_token + self.scale = math.sqrt(dim) # rescale token embedings + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim, pad_token) + self.pos_embedding = SinusoidalEmbedding(max_seq_len, dim, pad_token) + self.eos_embedding = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + AttentionBlock(dim, ffn_dim, ffn_act, num_heads, dropout) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(dim) + self.head = nn.Linear(dim, embed_dim, bias=False) + + # causal attention mask + self.register_buffer( + 'attn_mask', + torch.tril(torch.ones(1, 1, 1 + max_seq_len, 1 + max_seq_len))) + + # init weights + self.apply(self.init_weights) + + def forward(self, tokens, mask=None): + r"""tokens: [B, L]. + mask: [B, L]. + """ + b, seq_len = tokens.size(0), 1 + tokens.size(1) + + # embeddings + x = self.scale * self.token_embedding(tokens) + x = torch.cat([x, self.eos_embedding.repeat(b, 1, 1)], dim=1) + # x = x + self.pos_embedding(tokens) + x = self.dropout(x) + + # attention mask + if mask is None: + mask = self.attn_mask[:, :, :seq_len, :seq_len].repeat(b, 1, 1, 1) + else: + mask = self.attn_mask[:, :, :seq_len, :seq_len] * torch.cat( + [mask, torch.zeros_like(mask[:, :1])], dim=1).view( + b, 1, 1, seq_len) + + # transformer + for block in self.blocks: + x = block(x, mask) + x = self.norm(x) + + # head + logits = self.head(x[:, -1]) + return logits + + def init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + if m.padding_idx is not None: + nn.init.zeros_(m.weight[m.padding_idx]) diff --git a/modelscope/models/multi_modal/ofa/__init__.py b/modelscope/models/multi_modal/ofa/__init__.py new file mode 100644 index 00000000..3e8e59f4 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .modeling_ofa import OFADecoder, OFAEncoder, OFAModel, OFAPreTrainedModel +from .tokenization_ofa import OFATokenizer, OFATokenizerZH +from .tokenization_ofa_fast import OFATokenizerFast, OFATokenizerZHFast diff --git a/modelscope/models/multi_modal/ofa/adaptor/__init__.py b/modelscope/models/multi_modal/ofa/adaptor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py new file mode 100644 index 00000000..2edc651e --- /dev/null +++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py @@ -0,0 +1,211 @@ +# Copyright 2022 Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" OFA model configuration""" +import warnings + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +OFA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'ofa-medium': 'https://huggingface.co/ofa-base/resolve/main/config.json', + # OFA models are implemeted to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + + +class OFAConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~OFAModel`]. It is used to instantiate an OFA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the OFA [ofa-base](https://huggingface.co/ofa-base) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the OFA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~OFAModel`] or [`~TFOFAModel`]. + d_model (`int`, *optional*, defaults to 1024): + Dimension of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = 'ofa' + keys_to_ignore_at_inference = ['past_key_values'] + + attribute_map = { + 'num_attention_heads': 'encoder_attention_heads', + 'hidden_size': 'd_model' + } + + def __init__(self, + vocab_size=59457, + max_position_embeddings=1024, + encoder_layers=4, + encoder_ffn_dim=512 * 4, + encoder_attention_heads=8, + decoder_layers=4, + decoder_ffn_dim=512 * 4, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function='gelu', + d_model=512, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + pad_token_id=1, + bos_token_id=0, + decoder_start_token_id=0, + eos_token_id=2, + forced_eos_token_id=2, + encoder_normalize_before=True, + decoder_normalize_before=True, + normformer=True, + encoder_drop_path_rate=0.0, + decoder_drop_path_rate=0.0, + layernorm_embedding=True, + patch_layernorm_embedding=True, + resnet_type='resnet101', + resnet_model_path=None, + resnet_drop_path_rate=0.0, + token_bucket_size=256, + image_bucket_size=42, + add_type_embedding=True, + share_decoder_input_output_embed=True, + attn_scale_factor=2., + code_layernorm_embedding=True, + code_image_size=128, + entangle_position_embedding=False, + interpolate_position=False, + orig_patch_image_size=224, + share_attn_bias=False, + use_image_feature=True, + disable_entangle=False, + use_ofasys=False, + vit_type='vit_base', + vit_drop_path_rate=0.0, + **kwargs): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.encoder_normalize_before = encoder_normalize_before + self.decoder_normalize_before = decoder_normalize_before + self.normformer = normformer + self.encoder_drop_path_rate = encoder_drop_path_rate + self.decoder_drop_path_rate = decoder_drop_path_rate + self.layernorm_embedding = layernorm_embedding + self.patch_layernorm_embedding = patch_layernorm_embedding + self.resnet_type = resnet_type + self.resnet_model_path = resnet_model_path + self.resnet_drop_path_rate = resnet_drop_path_rate + self.token_bucket_size = token_bucket_size + self.image_bucket_size = image_bucket_size + self.add_type_embedding = add_type_embedding + self.share_decoder_input_output_embed = share_decoder_input_output_embed + self.attn_scale_factor = attn_scale_factor + self.code_layernorm_embedding = code_layernorm_embedding + self.code_image_size = code_image_size + self.entangle_position_embedding = entangle_position_embedding + self.interpolate_position = interpolate_position + self.orig_patch_image_size = orig_patch_image_size + + self.share_attn_bias = share_attn_bias + self.use_image_feature = use_image_feature + self.disable_entangle = disable_entangle + self.use_ofasys = use_ofasys + self.vit_type = vit_type + self.vit_drop_path_rate = vit_drop_path_rate + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None and kwargs.get( + 'force_bos_token_to_be_generated', False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f'Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. ' + 'The config can simply be saved and uploaded again to be fixed.' + ) diff --git a/modelscope/models/multi_modal/ofa/generate/__init__.py b/modelscope/models/multi_modal/ofa/generate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py b/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py new file mode 100644 index 00000000..db0df9b2 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/incremental_decoding_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import uuid +from typing import Dict, Optional + +from torch import Tensor + + +class FairseqIncrementalState(object): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return '{}.{}'.format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState, ) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState) + return cls diff --git a/modelscope/models/multi_modal/ofa/generate/multihead_attention.py b/modelscope/models/multi_modal/ofa/generate/multihead_attention.py new file mode 100644 index 00000000..9101d52d --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/multihead_attention.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__) + + self.head_dim = embed_dim // num_heads + assert (self.head_dim * num_heads == self.embed_dim + ), 'embed_dim must be divisible by num_heads' + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + 'Self-attention requires query, key and ' + 'value to be of the same size') + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, + Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == 'xla' + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim, f'query dim {embed_dim} != {self.embed_dim}' + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if (not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting()): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat( + (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros( + key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous().view(tgt_len, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + if k is not None: + k = ( + k.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + if v is not None: + v = ( + v.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1)) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + _prev_key = saved_state['prev_key'] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, + self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if 'prev_value' in saved_state: + _prev_value = saved_state['prev_value'] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, + self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if 'prev_key_padding_mask' in saved_state: + prev_key_padding_mask = saved_state['prev_key_padding_mask'] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, + self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, + self.head_dim) + saved_state['prev_key_padding_mask'] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, + saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], + dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], + dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, + attn_mask.new_zeros(attn_mask.size(0), 1)], + dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), + 1).type_as(key_padding_mask), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, + bsz) + + assert list( + attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float('-inf'), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill( + key_padding_mask, float('-inf')) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list( + attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, + 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, + tgt_len, + src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), + key_padding_mask.float()], + dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), + filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size( + 0) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, + input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, + Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, 'attn_state') + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, 'attn_state', + buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, + bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + '.' if name != '' else '' + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + 'in_proj_weight'): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim] + items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2 + * dim] + items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2 + * dim:] + + keys_to_remove.append(k) + + k_bias = prefix + 'in_proj_bias' + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + + 'q_proj.bias'] = state_dict[k_bias][:dim] + items_to_add[prefix + + 'k_proj.bias'] = state_dict[k_bias][dim:2 + * dim] + items_to_add[prefix + + 'v_proj.bias'] = state_dict[k_bias][2 + * dim:] + + keys_to_remove.append(prefix + 'in_proj_bias') + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py b/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py new file mode 100644 index 00000000..4bccfa76 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/ngram_repeat_block.py @@ -0,0 +1,155 @@ +# Originally from Microsoft Corporation. +# Licensed under the MIT License. +""" Wrapper for ngram_repeat_block cuda extension """ +import math +import warnings +from typing import Dict, List + +import torch +from torch import nn + +try: + from fairseq import ngram_repeat_block_cuda + + EXTENSION_BUILT = True +except ImportError: + EXTENSION_BUILT = False + + +def is_cuda_extension_usable() -> bool: + """Check whether ngram_repeat_block_cuda is built properly""" + if not EXTENSION_BUILT or not torch.cuda.is_available(): + return False + bsz = 2 + tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], + dtype=torch.long, + device='cuda') + lprobs = torch.rand((8, 12), device='cuda') + try: + outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) + outputs = outputs + 4 # This line breaks if the extension is built incorrectly. + return True + except RuntimeError: + warnings.warn( + 'NGramRepeatBlock extension must be rebuilt.' + 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' + ) + return False + + +class NGramRepeatBlock(nn.Module): + """ Wrapper class for calling ngram_repeat_block cuda extension """ + + def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): + super().__init__() + self.use_extension = is_cuda_extension_usable( + ) if use_extension else False + self.no_repeat_ngram_size = no_repeat_ngram_size + + def reset_parameters(self): + pass + + @torch.jit.unused + def call_cuda_extension( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + return ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, step, + beam_size, + self.no_repeat_ngram_size) + + def forward( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + msg = f'expected {bsz * beam_size} got' + assert tokens.size(0) == bsz * beam_size, f'{msg} {tokens.size(0)}' + assert lprobs.size(0) == bsz * beam_size, f'{msg} {lprobs.size(0)}' + if self.use_extension: + return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, + step) + + else: + return self._no_repeat_ngram( + tokens, + lprobs, + bsz, + beam_size, + step, + ) + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, + step: int): + """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" + gen_ngrams: List[Dict[str, List[int]]] = [ + torch.jit.annotate(Dict[str, List[int]], {}) + for bbsz_idx in range(bsz * beam_size) + ] + cpu_tokens = tokens.cpu() + for bbsz_idx in range(bsz * beam_size): + gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() + for ngram in self.transpose_list([ + gen_tokens[i:] for i in range(self.no_repeat_ngram_size) + ]): # noqa + key = ','.join([str(x) for x in ngram[:-1]]) + gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( + key, torch.jit.annotate(List[int], [])) + [ngram[-1]] + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [ + self.calculate_banned_tokens(tokens, step, gen_ngrams, + self.no_repeat_ngram_size, + bbsz_idx) + for bbsz_idx in range(bsz * beam_size) + ] + else: + banned_tokens = [ + torch.jit.annotate(List[int], []) + for bbsz_idx in range(bsz * beam_size) + ] + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][torch.tensor( + banned_tokens[bbsz_idx], + dtype=torch.int64)] = torch.tensor(-math.inf).to(lprobs) + return lprobs + + @staticmethod + def calculate_banned_tokens( + tokens, + step: int, + gen_ngrams: List[Dict[str, List[int]]], + no_repeat_ngram_size: int, + bbsz_idx: int, + ): + tokens_list: List[int] = tokens[bbsz_idx, + step + 2 - no_repeat_ngram_size:step + + 1].tolist() # noqa + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = ','.join([str(x) for x in tokens_list]) + return gen_ngrams[bbsz_idx].get(ngram_index, + torch.jit.annotate(List[int], [])) + + @staticmethod + def transpose_list(l: List[List[int]]): # noqa + # GeneratorExp aren't supported in TS so ignoring the lint + min_len = min([len(x) for x in l]) # noqa + l2 = [[row[i] for row in l] for i in range(min_len)] + return l2 diff --git a/modelscope/models/multi_modal/ofa/generate/search.py b/modelscope/models/multi_modal/ofa/generate/search.py new file mode 100644 index 00000000..0dcaf6b3 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/search.py @@ -0,0 +1,848 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from .token_generation_constraints import (ConstraintState, + OrderedConstraintState, + UnorderedConstraintState) + + +class Search(nn.Module): + + def __init__(self, tokenizer): + super().__init__() + self.pad = tokenizer.pad_token_id + self.unk = tokenizer.unk_token_id + self.eos = tokenizer.eos_token_id + tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + tgt_dict.update(added) + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step(self, + step, + lprobs, + scores, + prev_output_tokens=None, + original_batch_idxs=None): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], + beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +class BeamSearch(Search): + + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor') + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat( + (1, beam_size)).flatten().tolist()) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs)): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tokenizer, representation): + super().__init__(tokenizer) + self.representation = representation + tgt_dict = {value: key for key, value in tokenizer.get_vocab().items()} + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + tgt_dict.update(added) + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], + beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == 'ordered': + constraint_state = OrderedConstraintState.create( + constraint_tensor) + elif self.representation == 'unordered': + constraint_state = UnorderedConstraintState.create( + constraint_tensor) + + self.constraint_states.append( + [constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] + for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) + - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[not_finished_indices, + self.eos] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange( + 0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), + device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), + device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), + device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. + """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor( + list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device).repeat( + next_tokens.size(0)).long()) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], + device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [ + offset * (len(banks) + 1) for offset in range(len(banks) + 1) + ] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[ + cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[:self.num_cands] + indices_buf = indices_buf[:self.num_cands] + beams_buf = beams_buf[:self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We only implement the Hamming Diversity penalty here, which performed best + in the original paper. + """ + + def __init__(self, tgt_dict, num_groups, diversity_strength): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + 'DiverseBeamSearch requires --beam to be divisible by the number of groups' + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, indices_G, beams_G = [], [], [] + for g in range(self.num_groups): + lprobs_g = lprobs[:, g::self.num_groups, :] + scores_g = scores[:, g::self.num_groups, :] if step > 0 else None + + # apply diversity penalty + if g > 0: + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + indices_G.append(indices_buf.clone()) + beams_G.append(beams_buf.clone()) + + # update diversity penalty + diversity_buf.scatter_add_( + 1, indices_buf, + torch.ones(indices_buf.size()).to(diversity_buf)) + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, :max_dim + 1] + truncated_probs = sorted_probs[:, :, :max_dim + 1] + truncated_indices = sorted_indices[:, :, :max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather( + probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, + beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [ + torch.LongTensor().to(device=lprobs.device) + for i in range(beam_size) + ] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk( + lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py new file mode 100644 index 00000000..e42d3c8e --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py @@ -0,0 +1,972 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 + +import math +import sys +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from modelscope.models.multi_modal.ofa.generate import search +from modelscope.models.multi_modal.ofa.generate.ngram_repeat_block import \ + NGramRepeatBlock + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +class SequenceGenerator(nn.Module): + + def __init__(self, + tokenizer, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + constraint_trie=None, + constraint_range=None, + gen_code=False, + gen_box=False, + ignore_eos=False, + zero_shot=False): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + self.gen_code = gen_code + self.gen_box = gen_box + self.ignore_eos = ignore_eos + self.tokenizer = tokenizer + self.tgt_dict = { + value: key + for key, value in tokenizer.get_vocab().items() + } + added = { + value: key + for key, value in tokenizer.get_added_vocab().items() + } + self.tgt_dict.update(added) + self.pad = tokenizer.pad_token_id + self.unk = tokenizer.unk_token_id + self.bos = tokenizer.bos_token_id + self.eos = tokenizer.eos_token_id + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) if + symbols_to_strip_from_output is not None else {self.bos, self.eos}) + self.vocab_size = len(self.tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + self.zero_shot = zero_shot + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, '--temperature must be greater than 0' + + self.search = ( + search.BeamSearch(self.tokenizer) + if search_strategy is None else search_strategy) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, 'needs_src_lengths') + and self.search.needs_src_lengths) + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + self.constraint_trie = constraint_trie + + self.constraint_start = None + self.constraint_end = None + if constraint_range is not None: + constraint_start, constraint_end = constraint_range.split(',') + self.constraint_start = int(constraint_start) + self.constraint_end = int(constraint_end) + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], + **kwargs) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(models, sample, **kwargs) + + def _generate( + self, + models, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + model = EnsembleModel(models) + incremental_states = torch.jit.annotate( + List[Tuple[Tuple[torch.Tensor]]], + [ + torch.jit.annotate(Tuple[Tuple[torch.Tensor]], {}) + for i in range(model.models_size) + ], + ) + net_input = sample['net_input'] + + if 'src_tokens' in net_input: + src_tokens = net_input['src_tokens'] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ((src_tokens.ne(self.eos) + & src_tokens.ne(self.pad)).long().sum(dim=1)) + elif 'input_ids' in net_input: + src_tokens = net_input['input_ids'] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ((src_tokens.ne(self.eos) + & src_tokens.ne(self.pad)).long().sum(dim=1)) + elif 'source' in net_input: + src_tokens = net_input['source'] + src_lengths = ( + net_input['padding_mask'].size(-1) + - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None else torch.tensor( + src_tokens.size(-1)).to(src_tokens)) + elif 'features' in net_input: + src_tokens = net_input['features'] + src_lengths = ( + net_input['padding_mask'].size(-1) + - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None else torch.tensor( + src_tokens.size(-1)).to(src_tokens)) + else: + raise Exception( + 'expected src_tokens or source in net input. input keys: ' + + str(net_input.keys())) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = int(self.max_len_a * src_len + self.max_len_b) + assert ( + self.min_len <= max_len + ), 'min_len cannot be larger than max_len, please adjust these!' + # compute the encoder output for each beam + with torch.autograd.profiler.record_function( + 'EnsembleModel: forward_encoder'): + encoder_outs = model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + scores = (torch.zeros(bsz * beam_size, + max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = (torch.zeros(bsz * beam_size, + max_len + 2).to(src_tokens).long().fill_( + self.pad)) # +2 for eos and pad + tokens[:, 0] = self.bos + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = (torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [ + torch.jit.annotate(List[Dict[str, Tensor]], []) + for i in range(bsz) + ], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ((torch.arange(0, bsz) + * beam_size).unsqueeze(1).type_as(tokens).to( + src_tokens.device)) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to( + src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if 'id' in sample and isinstance(sample['id'], Tensor): + original_batch_idxs = sample['id'] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange( + batch_idxs.numel()).type_as(batch_idxs) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size) + original_batch_idxs = original_batch_idxs[batch_idxs] + model.reorder_incremental_state(incremental_states, + reorder_state) + encoder_outs = model.reorder_encoder_out( + encoder_outs, reorder_state) + + with torch.autograd.profiler.record_function( + 'EnsembleModel: forward_decoder'): + lprobs, avg_attn_scores = model.forward_decoder( + tokens[:, :step + 1], + encoder_outs, + incremental_states, + self.temperature, + constraint_trie=self.constraint_trie, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end, + gen_code=self.gen_code, + zero_shot=self.zero_shot, + prefix_tokens=prefix_tokens) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, :step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + # handle prefix tokens (possibly with different lengths) + if (prefix_tokens is not None and step < prefix_tokens.size(1) + and step < max_len): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size) + elif step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + if (self.gen_code or self.gen_box) and step < max_len: + lprobs[:, :4] = -math.inf + if self.gen_box: + lprobs[:, -1] = -math.inf + if (step + 1) % 5 == 0: + lprobs[:, self.constraint_start:59457] = -math.inf + else: + lprobs[:, 59457:] = -math.inf + + # handle max length constraint + if step >= max_len: + lprobs[:, :self.eos] = -math.inf + lprobs[:, self.eos + 1:] = -math.inf + if self.ignore_eos: + lprobs[:, self.eos] = 1 + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty(bsz * beam_size, + avg_attn_scores.size(1), + max_len + 2).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + # process prefix_tokens + p_toks_len = prefix_tokens.ne(self.pad).sum( + dim=1) if prefix_tokens is not None else None + if p_toks_len is not None: + p_toks_len_beam = p_toks_len.unsqueeze(-1).repeat( + 1, beam_size).view(-1) + no_repeat_ngram_size = self.repeat_ngram_blocker.no_repeat_ngram_size + out_prefix = p_toks_len_beam < ( + step + no_repeat_ngram_size - 1) + else: + out_prefix = torch.ones(bsz * beam_size).bool() + ngram_blocker_tokens = tokens[out_prefix] + ngram_blocker_lprobs = lprobs[out_prefix] + ngram_blocker_bsz = torch.div( + out_prefix.sum(), beam_size, rounding_mode='trunc') + + lprobs[out_prefix] = self.repeat_ngram_blocker( + tokens=ngram_blocker_tokens, + lprobs=ngram_blocker_lprobs, + bsz=ngram_blocker_bsz, + beam_size=beam_size, + step=step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, :step + 1], + original_batch_idxs, + ) + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to( + eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f'{step} < {max_len}' + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device) + batch_mask[finalized_sents] = False + batch_idxs = torch.arange( + bsz, device=cand_indices.device).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~( # noqa + (~cands_to_ignore) & (~eos_mask[:, :beam_size])) # noqa + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[:eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather( + cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather( + cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, :step + 1] = torch.index_select( + tokens[:, :step + 1], dim=0, index=active_bbsz_idx) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, :step + 2] = torch.index_select( + attn[:, :, :step + 2], dim=0, index=active_bbsz_idx) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem['score'].item()) for elem in finalized[sent]]) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [ + finalized[sent][ssi] for ssi in sorted_scores_indices + ] + finalized[sent] = torch.jit.annotate(List[Dict[str, Tensor]], + finalized[sent]) + return finalized + + def _prefix_tokens(self, step: int, lprobs, scores, tokens, prefix_tokens, + beam_size: int): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat( + 1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + if self.constraint_trie is None: + lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1 + else: + lprobs[prefix_mask] = -math.inf + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), + prefix_lprobs[prefix_mask]) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, + tokens.size(-1))[:, 0, + 1:step + 1] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, + beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, + beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, + beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select( + 0, bbsz_idx)[:, 1:step + 2] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1:step + + 2] if attn is not None else None) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1)**self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor') + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), + eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append({ + 'tokens': + tokens_clone[i], + 'score': + eos_scores[i], + 'attention': + hypo_attn, # src_len x tgt_len + 'alignment': + torch.empty(0), + 'positional_scores': + pos_scores[i], + }) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len( + finalized[unique_sent]), beam_size): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. + """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + # self.has_incremental: bool = False + # if all( + # hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + # for m in models + # ): + # self.has_incremental = True + + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, 'encoder') + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([ + m.max_decoder_positions() + for m in self.models if hasattr(m, 'max_decoder_positions') + ] + [sys.maxsize]) # + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + encoder_input = { + k: v + for k, v in net_input.items() if k != 'decoder_input_ids' + } + encoder_input['output_hidden_states'] = True + return [ + model.encoder.forward(**encoder_input) for model in self.models + ] + + @torch.jit.export + def forward_decoder(self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Optional[torch.Tensor]], + temperature: float = 1.0, + constraint_trie=None, + constraint_start=None, + constraint_end=None, + gen_code=False, + zero_shot=False, + prefix_tokens=None): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool() + + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + encoder_hidden_states = encoder_out.last_hidden_state + encoder_attention_mask = _expand_mask( + encoder_out.padding_mask, encoder_hidden_states.dtype, + tokens.shape[-1]) + src_pos_embed = encoder_out.position_embedding + + # if tokens.eq(self.single_model.config.pad_token_id).any(): + attention_mask = tokens.eq(self.single_model.padding_idx) + + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + input_ids=tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_mask, + src_pos_embed=src_pos_embed, + past_key_values=incremental_states[i], + use_cache=True, + output_attentions=True) + else: + if hasattr(model, 'decoder'): + # decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out) + decoder_out = model.decoder.forward( + input_ids=tokens, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_mask, + src_pos_embed=src_pos_embed) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + + if 'cross_attentions' in decoder_out: + attn = decoder_out['cross_attentions'][-1].transpose(1, 0) + attn = attn.mean(dim=0) # (B, tgt_len, src_len) + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else attn, + ) + + beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size( + 0) if prefix_tokens is not None else 0 + if constraint_trie is not None and not zero_shot: + assert constraint_start is None and constraint_end is None + constraint_masks = decoder_out_tuple[0].new_zeros( + decoder_out_tuple[0].size()).bool() + constraint_prefix_tokens = tokens.tolist() + for token_index, constraint_prefix_token in enumerate( + constraint_prefix_tokens): + prefix_len = prefix_tokens[token_index // beam_size].ne( + 1).sum().item() if prefix_tokens is not None else 0 + if len(constraint_prefix_token) > prefix_len: + constraint_prefix_token = [ + 0 + ] + constraint_prefix_token[prefix_len + 1:] + constraint_nodes = constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_masks[token_index][:, + constraint_nodes] = True + else: + constraint_masks[token_index] = True + decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf) + if constraint_start is not None and constraint_end is not None and not zero_shot: + assert constraint_trie is None + decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf + decoder_out_tuple[0][:, :, constraint_end:] = -math.inf + + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None) + if constraint_trie is not None and zero_shot: + assert constraint_start is None and constraint_end is None + constraint_masks = decoder_out_tuple[0].new_zeros( + decoder_out_tuple[0].size()).bool() + constraint_prefix_tokens = tokens.tolist() + for token_index, constraint_prefix_token in enumerate( + constraint_prefix_tokens): + constraint_nodes = constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_masks[token_index][:, constraint_nodes] = True + probs.masked_fill_(~constraint_masks, -math.inf) + if constraint_start is not None and constraint_end is not None and zero_shot: + assert constraint_trie is None + probs[:, :, 4:constraint_start] = -math.inf + probs[:, :, constraint_end:] = -math.inf + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp( + torch.stack(log_probs, dim=0), dim=0) - math.log(self.models_size) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out(self, + encoder_outs: Optional[List[Dict[str, + List[Tensor]]]], + new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order)) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Optional[torch.Tensor]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order) diff --git a/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py b/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py new file mode 100644 index 00000000..13fb3fcf --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/token_generation_constraints.py @@ -0,0 +1,512 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +from collections import Counter +from typing import List, Set + +import torch + + +class ConstraintState: + + def __init__(self): + pass + + +def pack_constraints( + batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = (1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints)) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset:offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. + """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f'[{self.token}].{term}#{self.num_constraints}' + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: 'ConstraintNode'): + if len(node.children) == 0: + return str(node) + else: + s = f'({node}' + for child in node.children.values(): + s += ' ' + ConstraintNode.print_graph(child) + s += ')' + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. + """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, + node: ConstraintNode, + copy_from: 'ConstraintState' = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... + self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ','.join([str(node) for node in self.generated]) + return f'{self.name}/{self.bank}({gen_str})x{self.num_completed}' + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return 'ROOT' + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[ + self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. + """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState( + self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False + for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. + """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f'{self.state}/{self.bank}x{self.num_completed}' + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list( + filter(lambda x: x, + self.sequence.endpoints[0:self.state + 1]))) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return 'ROOT' + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/modelscope/models/multi_modal/ofa/generate/utils.py b/modelscope/models/multi_modal/ofa/generate/utils.py new file mode 100644 index 00000000..8c8abf99 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/generate/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license which can be found at +# https://github.com/facebookresearch/fairseq/blob/main/LICENSE + +import collections +from collections import abc +from itertools import accumulate + +import torch +import torch.nn.functional as F + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + +MANIFOLD_PATH_SEP = '|' + + +def apply_to_sample(f, sample): + if hasattr(sample, '__len__') and len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items()) + od.__dict__ = x.__dict__ + return od + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +def move_to_device(batch, device): + r"""Puts each data field to the device""" + if isinstance(batch, torch.Tensor): + return batch.to(device) + elif isinstance(batch, (list, tuple)): + return tuple(move_to_device(item, device) for item in batch) + elif isinstance(batch, abc.Mapping): + return { + key: move_to_device(value, device) + for key, value in batch.items() + } + else: + return batch + + +def strip_pad(tensor, pad): + return tensor[tensor.ne(pad)] + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = (((tgt_sent != pad) & # noqa + (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)) + src_invalid = (((src_sent == pad) | # noqa + (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float('-inf') + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append(( + src_token_to_word[src_idx.item()] - 1, + tgt_token_to_word[tgt_idx.item()] - 1, + )) + return alignment + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.log_softmax(x.float(), dim=dim) + else: + return F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False) + src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [['{:.6f}'.format(p) for p in src_probs.tolist()] + for src_probs in attn_valid] + return alignment diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py new file mode 100644 index 00000000..69005ef0 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -0,0 +1,2261 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OFA model.""" + +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from packaging import version +from torch import Tensor, nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.file_utils import (ModelOutput, add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, + Seq2SeqModelOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_ofa import OFAConfig +from .generate import utils +from .resnet import ResNet +from .utils.utils import DropPath +from .vit import vit_base, vit_huge, vit_large, vit_large_336 + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = 'ofa-base' +_CONFIG_FOR_DOC = 'OFAConfig' +_TOKENIZER_FOR_DOC = 'OFATokenizer' +TORCH_VERSION = version.parse(torch.__version__) +TORCH_MESH_GRID_WARNING_VERSION = version.parse('1.9.1') + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +OFA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'ofa-tiny', + 'ofa-medium', + 'ofa-base', + 'ofa-large', + 'ofa-huge', +] + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, + eps=1e-5, + elementwise_affine=True, + export=False): + r""" + Layer normalization. + If apex is available, use `FusedLayerNorm` instead. + """ + if torch.jit.is_scripting(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def make_token_bucket_position(bucket_size, + max_position=DEFAULT_MAX_SOURCE_POSITIONS): + r""" + Make relative position indices for the text. + """ + context_pos = torch.arange(max_position, dtype=torch.long)[:, None] + memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] + relative_pos = context_pos - memory_pos + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), + mid - 1, torch.abs(relative_pos)) + log_pos = torch.ceil( # noqa + torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * # noqa + (mid - 1)) + mid # noqa + log_pos = log_pos.int() + bucket_pos = torch.where(abs_pos.le(mid), relative_pos, + log_pos * sign).long() + return bucket_pos + bucket_size - 1 + + +def make_image_bucket_position(bucket_size, num_relative_distance): + r""" + Make relative position indices for the image. + """ + coords_h = torch.arange(bucket_size) + coords_w = torch.arange(bucket_size) + if TORCH_VERSION > TORCH_MESH_GRID_WARNING_VERSION: + coords = torch.stack( + torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + else: + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += bucket_size - 1 + relative_coords[:, :, 0] *= 2 * bucket_size - 1 + relative_position_index = torch.zeros( + size=(bucket_size * bucket_size + 1, ) * 2, + dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return relative_position_index + + +def new_arange(x, *size): + r""" + Return a Tensor of `size` filled with a range function on the device of x. + If size is empty, using the size of the variable x. + """ + if len(size) == 0: + size = x.size() + return torch.arange(size[-1], device=x.device).expand(*size).contiguous() + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, + decoder_start_token_id: int): + r""" + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, 'self.model.config.pad_token_id has to be defined.' + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + past_key_values_length: int = 0): + r""" + Make causal mask used for uni-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float('-inf')) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +def Embedding(num_embeddings, + embedding_dim, + padding_idx=None, + zero_init=False): + r""" + Embedding for tokens + """ + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + if zero_init: + nn.init.constant_(m.weight, 0) + return m + + +def Linear(in_features, out_features, bias=True): + r""" + Implementation of linear projection with xavier initialization + """ + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +class LayerDropModuleList(nn.ModuleList): + r""" + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p, modules=None): + super().__init__(modules) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m + + +class OFAAttention(nn.Module): + r""" + Multi-headed attention, with additional implementation for NormFormer. + + Args: + embed_dim (`int`): embedding dimension. + num_heads (`int`): the number of attention heads. + dropout (`float32`): the ratio for dropout. + is_decoder (`bool`): whether or not decoder attention. + bias (`bool`): whether to add bias. + scale_heads (`bool`): whether to learn scaling heads, only for Normformer. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + scale_heads: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f'embed_dim must be divisible by num_heads ' \ + f'(got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads}).' + # 1. difference + scale_factor = 2 + self.scaling = float(self.head_dim * scale_factor)**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_dropout = nn.Dropout(p=dropout) + self.c_attn = nn.Parameter( + torch.ones((self.num_heads, )), + requires_grad=True) if scale_heads else None + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + r""" + Reshape tensors for multi-head attention. + """ + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + attn_bias: Optional[torch.Tensor] = None, + ): + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`)`: input states. + key_value_states (`torch.FloatTensor` of shape (bsz, tgt_len, embed_dim), *optional*): key value states. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + cached past key value states for fast inference. + attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, seq_len)`, *optional*): attention mask. + output_attentions (`bool`, *optional*): whether to output attention weights of all layers. + attn_bias (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`, *optional*): + the attention bias for positional information. + + Returns: + attn_output (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): attention outputs. + attn_weights_reshaped (`torch.FloatTensor`, *optional*): attention weights of all layers. + past_key_value (`torch.FloatTensor`, *optional*): cached key value states for fast inference. + """ + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, + bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError(f'Attention weights should be of size ' + f'{(bsz * self.num_heads, tgt_len, src_len)}, ' + f'but is {attn_weights.size()}') + + # Add attention bias for positional information + if attn_bias is not None: + attn_weights += attn_bias + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, + src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if output_attentions: + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, + tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = self.attn_dropout(attn_weights) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, + self.head_dim): + raise ValueError( + f'`attn_output` should be of size ' + f'{(bsz, self.num_heads, tgt_len, self.head_dim)}, ' + f'but is {attn_output.size()}') + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, + self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + if self.c_attn is not None: + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, + self.head_dim) + attn_output = torch.einsum('bthd,h->bthd', attn_output, + self.c_attn) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OFAEncoderLayer(nn.Module): + r""" + OFA encoder layer implementation. + + Args: + config: configuration for OFA. + drop_path_rate: the ratio for drop path. + """ + + def __init__(self, config: OFAConfig, drop_path_rate=0.0): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = OFAAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.self_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.dropout = nn.Dropout(config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = nn.Dropout(config.activation_dropout) + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.ffn_layer_norm = LayerNorm( + config.encoder_ffn_dim) if config.normformer else None + self.final_layer_norm = LayerNorm(self.embed_dim) + self.normalize_before = config.encoder_normalize_before + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def residual_connection(self, x, residual): + r""" + Residual connection with drop path. + """ + return residual + self.drop_path(x) + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool = False, + attn_bias: Optional[torch.Tensor] = None): + r""" + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(bsz, src_len, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(bsz, 1, src_len, src_len)* where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + whether to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + attn_bias (`torch.FloatTensor`): bias for positional information. + + Returns: + outputs (`tuple(torch.FloatTensor)`): + output hidden states of size (bsz, src_len, embed_dim), optionally with attention weights. + """ + + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=attn_bias, + ) + if self.self_attn_mid_layer_norm: + hidden_states = self.self_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states) + if self.ffn_layer_norm: + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class OFADecoderLayer(nn.Module): + r""" + OFA decoder layer implementation. + + Args: + config: configuration for OFA. + drop_path_rate: the ratio for drop path. + """ + + def __init__(self, config: OFAConfig, drop_path_rate=0.0): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = OFAAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = nn.Dropout(p=config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = nn.Dropout(p=config.activation_dropout) + + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.self_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.cross_attn = OFAAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.cross_attn_layer_norm = LayerNorm(self.embed_dim) + self.cross_attn_mid_layer_norm = LayerNorm( + self.embed_dim) if config.normformer else None + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.ffn_layer_norm = LayerNorm( + config.decoder_ffn_dim) if config.normformer else None + self.final_layer_norm = LayerNorm(self.embed_dim) + self.normalize_before = config.decoder_normalize_before + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def residual_connection(self, x, residual): + r""" + Residual connection with drop path. + """ + return residual + self.drop_path(x) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + self_attn_bias: Optional[torch.Tensor] = None, + cross_attn_bias: Optional[torch.Tensor] = None, + ): + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): input to the layer. + attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): + attention mask where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`): + cross attention input to the layer. + encoder_attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`): + encoder attention mask where padding elements are indicated by very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers. + use_cache (`bool`, *optional*): whether to use cache + self_attn_bias (`torch.FloatTensor`): self attention bias for positional information. + cross_attn_bias (`torch.FloatTensor`): cross attention bias for positional information. + """ + + # Self attention with intermediate layernorm + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + # add present self-attn cache to position 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + attn_bias=self_attn_bias, + ) + if self.self_attn_mid_layer_norm: + hidden_states = self.self_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross attention with intermediate layernorm + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + if self.normalize_before: + hidden_states = self.cross_attn_layer_norm(hidden_states) + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + attn_bias=cross_attn_bias, + ) + if self.cross_attn_mid_layer_norm: + hidden_states = self.cross_attn_mid_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.cross_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # FFN with intermediate layernorm + residual = hidden_states + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states) + if self.ffn_layer_norm: + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.residual_connection(hidden_states, residual) + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +class OFAPreTrainedModel(PreTrainedModel): + r""" + Base class OFA + """ + + config_class = OFAConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + + def _init_weights(self, module): + r""" + Weight initialization which follows BERT. + """ + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + r""" + Turn on the switch of gradient checkpointing. + """ + if isinstance(module, (OFADecoder, OFAEncoder)): + module.gradient_checkpointing = value + + +@dataclass +class OFAEncoderOutput(ModelOutput): + r""" + Base class for OFA's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + Sequence of hidden-states at the output of the last layer of the model. + + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed + or when `config.output_hidden_states=True`): + + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(bsz, seq_len, hidden)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed + or when `config.output_attentions=True`): + + Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): + postional embeddings of the inputs. + """ + + last_hidden_state: torch.FloatTensor = None + padding_mask: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + position_embedding: Optional[torch.FloatTensor] = None + + +OFA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`~OFAConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OFA_GENERATION_EXAMPLE = r""" + Image captioning example: + + ```python + >>> from PIL import Image + >>> from torchvision import transforms + >>> from transformers import OFATokenizer, OFAForConditionalGeneration + + >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + >>> resolution = 256 + >>> patch_resize_transform = transforms.Compose([ + lambda image: image.convert("RGB"), + transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + + >>> model = OFAForConditionalGeneration.from_pretrained(ckpt_dir) + >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir) + + >>> txt = " what is the description of the image?" + >>> inputs = tokenizer([txt], max_length=1024, return_tensors="pt")["input_ids"] + >>> img = Image.open(path_to_image) + >>> patch_img = patch_resize_transform(img).unsqueeze(0) + + >>> gen = model.generate(inputs, patch_img=patch_img, num_beams=4) + >>> print(tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + +OFA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. +""" + + +class OFAEncoder(OFAPreTrainedModel): + r""" + OFA encoder consisting of layers of [`OFAEncoderLayer`]. + + Args: + config: OFAConfig + embed_tokens (`nn.Embedding`, *optional*): output embedding + """ + + def __init__(self, + config: OFAConfig, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = nn.Dropout(config.dropout) + self.encoder_layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt( + embed_dim) if config.scale_embedding else 1.0 + self.num_attention_heads = config.encoder_attention_heads + + if getattr(config, 'layernorm_embedding', False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, + self.padding_idx) + + if config.add_type_embedding: + if config.use_image_feature: + self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + else: + self.type_embedding = Embedding(1, embed_dim, padding_idx=None) + else: + self.type_embedding = None + + if config.use_image_feature: + if config.use_ofasys: + vit_backbone = { + 'vit_base': vit_base, + 'vit_large': vit_large, + 'vit_large_336': vit_large_336, + 'vit_huge': vit_huge, + }[config.vit_type] + self.embed_images = vit_backbone(config.vit_drop_path_rate) + + self.image_proj = Linear(self.embed_images.width, embed_dim) + + else: + if config.resnet_type == 'resnet18': + self.embed_images = ResNet( + [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet34': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet50': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet101': + self.embed_images = ResNet( + [3, 4, 23], + drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet152': + self.embed_images = ResNet( + [3, 8, 36], + drop_path_rate=config.resnet_drop_path_rate) + else: + raise NotImplementedError + + self.image_proj = Linear(1024, embed_dim) + + if not config.use_ofasys and config.resnet_model_path: + print('load resnet {}'.format(config.resnet_model_path)) + resnet_state_dict = torch.load(config.resnet_model_path) + self.embed_images.load_state_dict(resnet_state_dict) + if config.patch_layernorm_embedding: + self.patch_layernorm_embedding = LayerNorm(embed_dim) + else: + self.patch_layernorm_embedding = None + + self.embed_positions = Embedding(self.max_source_positions + 2, + embed_dim) + + if config.use_image_feature: + self.embed_image_positions = Embedding( + config.image_bucket_size**2 + 1, embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(embed_dim) + + if config.use_image_feature: + self.image_pos_ln = LayerNorm(embed_dim) + self.pos_scaling = float(embed_dim / self.num_attention_heads + * config.attn_scale_factor)**-0.5 + + if not (config.use_ofasys and config.entangle_position_embedding): + self.pos_q_linear = nn.Linear(embed_dim, embed_dim) + self.pos_k_linear = nn.Linear(embed_dim, embed_dim) + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + + dpr = [ + x.item() for x in torch.linspace(0, config.encoder_drop_path_rate, + config.encoder_layers) + ] + self.layers.extend([ + OFAEncoderLayer(config, drop_path_rate=dpr[i]) + for i in range(config.encoder_layers) + ]) + self.num_layers = len(self.layers) + + if config.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + self.token_bucket_size = config.token_bucket_size + token_num_rel_dis = 2 * config.token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers + self.token_rel_pos_table_list = nn.ModuleList([ + Embedding( + token_num_rel_dis, self.num_attention_heads, zero_init=True) + for _ in range(num_rel_pos_tables) + ]) + + if config.use_image_feature: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size + - 1) * (2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + + self.register_buffer('image_rp_bucket', image_rp_bucket) + + if config.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.entangle_position_embedding = config.entangle_position_embedding + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + self.use_ofasys = config.use_ofasys + + def get_input_embeddings(self): + r""" + Get the embedding weight. + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + r""" + Set the weight of embedding with the given tensor. + """ + self.embed_tokens = value + + def get_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the text, for attention. + """ + + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) + values = values.permute([0, 3, 1, 2]) + return values.contiguous() + + def get_image_rel_pos_bias(self, image_position_ids, idx): + r""" + Get the relative positional bias of the image, for attention. + """ + + bsz, seq_len = image_position_ids.shape + rp_bucket_size = self.image_rp_bucket.size(1) + + rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( + bsz, rp_bucket_size, rp_bucket_size).gather( + 1, image_position_ids[:, :, None].expand( + bsz, seq_len, rp_bucket_size)).gather( + 2, image_position_ids[:, None, :].expand( + bsz, seq_len, seq_len)) + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(0, 3, 1, 2) + return values + + def get_patch_images_info(self, patch_images, sample_patch_num, device): + r""" + Get the basic information of the resized image. + + Args: + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image. + sample_patch_num (`int`): + the number of patches to sample. If it is equal to -1, no sampling will be performed. + device: GPU device. + + Returns: + image_embed (`torch.FloatTensor` of shape `(bsz, h * w, hidden)`): the output of the visual encoder. + image_num_patches (`int`, equal to `h * w`): the number of patches. + image_padding_mask (`torch.BooleanTensor` of shape `(bsz, h*w)`): image padding mask. + image_position_ids (`torch.LongTensor` of shape `(bsz, h*w)`): image position ids. + image_pos_embed (`torch.FloatTensor` of shape (bsz, h*w, hidden)): the positional embedding. + """ + + image_embed = self.embed_images(patch_images) + h, w = image_embed.shape[-2:] + image_num_patches = h * w + image_padding_mask = patch_images.new_zeros( + (patch_images.size(0), image_num_patches)).bool() + image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w)\ + + torch.arange(h).unsqueeze(1) * self.image_bucket_size + 1 + image_position_idx = image_position_idx.view(-1).to(device) + image_position_ids = image_position_idx[None, :].expand( + patch_images.size(0), image_num_patches) + + image_embed = image_embed.flatten(2).transpose(1, 2) + if sample_patch_num is not None: + patch_orders = [ + random.sample(range(image_num_patches), k=sample_patch_num) + for _ in range(patch_images.size(0)) + ] + patch_orders = torch.LongTensor(patch_orders).to(device) + image_embed = image_embed.gather( + 1, + patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))) + image_num_patches = sample_patch_num + image_padding_mask = image_padding_mask.gather(1, patch_orders) + image_position_ids = image_position_ids.gather(1, patch_orders) + orig_num_patches = (self.config.orig_patch_image_size // 16)**2 + orig_hw = self.config.orig_patch_image_size // 16 + if self.config.interpolate_position and image_num_patches > orig_num_patches: + old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \ + torch.arange(orig_hw).unsqueeze(1) * \ + self.config.image_bucket_size + 1 # noqa + old_image_position_ids = old_image_position_ids.to(device) + old_image_pos_embed = self.embed_image_positions( + old_image_position_ids) + old_image_pos_embed = old_image_pos_embed.reshape( + 1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2) + image_pos_embed = F.interpolate( + old_image_pos_embed, size=(h, w), mode='bilinear') + image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape( + 1, image_num_patches, -1) + image_pos_embed = image_pos_embed.expand( + patch_images.size(0), -1, -1) + else: + image_pos_embed = self.embed_image_positions(image_position_ids) + + return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed + + def forward_embedding(self, + input_ids, + image_embed: Optional[torch.Tensor] = None, + image_embed_2: Optional[torch.Tensor] = None, + token_embedding: Optional[torch.Tensor] = None, + pos_embed: Optional[torch.Tensor] = None, + image_pos_embed: Optional[torch.Tensor] = None, + image_pos_embed_2: Optional[torch.Tensor] = None): + r""" + Generate embeddings of both the image and the text. + Actually since OFA unifies both unimodal and multimodal data, + image inputs are optional. + + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the tokens in the vocabulary. + image_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings. + image_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + image embeddings of the second image (if it exists). + token_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`, *optional*): + input token embeddings to replace the embeddings of input ids. + image_pos_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + positional embeddings of the image. + image_pos_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): + positional embeddings of the second image. + + Returns: + x (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings of the input. + embed (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): + embeddings without adding positional and type embeddings. + """ + + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(input_ids) + x = embed = self.embed_scale * token_embedding + if self.entangle_position_embedding and pos_embed is not None: + x += pos_embed + if self.type_embedding is not None: + x += self.type_embedding(input_ids.new_zeros(x.size()[:2])) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout(x) + + # embed raw images + if image_embed is not None: + image_embed = self.image_proj(image_embed) + image_x = image_embed = self.embed_scale * image_embed + if self.entangle_position_embedding and image_pos_embed is not None: + image_x += image_pos_embed + if self.type_embedding is not None: + image_x += self.type_embedding( + input_ids.new_ones(image_x.size()[:2])) + if self.patch_layernorm_embedding is not None: + image_x = self.patch_layernorm_embedding(image_x) + image_x = self.dropout(image_x) + x = torch.cat([image_x, x], dim=1) + embed = torch.cat([image_embed, embed], dim=1) + + if image_embed_2 is not None: + assert self.type_embedding is not None + image_embed_2 = self.image_proj(image_embed_2) + image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 + if self.entangle_position_embedding and image_pos_embed_2 is not None: + image_x_2 += image_pos_embed_2 + if self.type_embedding is not None: + image_x_2 += self.type_embedding( + input_ids.new_full(image_x_2.size()[:2], fill_value=2)) + if self.patch_layernorm_embedding is not None: + image_x_2 = self.patch_layernorm_embedding(image_x_2) + image_x_2 = self.dropout(image_x_2) + if self.quant_noise is not None: + image_x_2 = self.quant_noise(image_x_2) + x = torch.cat([image_x_2, x], dim=1) + embed = torch.cat([image_embed_2, embed], dim=1) + + return x, embed + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + # if encoder_out["last_hidden_state"] is None: + if 'last_hidden_state' not in encoder_out: + new_encoder_out = None + else: + new_encoder_out = encoder_out['last_hidden_state'].index_select( + 0, new_order) + # if encoder_out["padding_mask"] is None: + if 'padding_mask' not in encoder_out: + new_encoder_padding_mask = None + else: + new_encoder_padding_mask = encoder_out[ + 'padding_mask'].index_select(0, new_order) + + # if encoder_out["position_embedding"] is None: + if 'position_embedding' not in encoder_out: + new_position_embeddings = None + else: + new_position_embeddings = encoder_out[ + 'position_embedding'].index_select(0, new_order) + + if 'hidden_states' not in encoder_out: + new_encoer_states = None + else: + encoder_states = encoder_out['hidden_states'] + new_encoer_states = () + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + new_encoer_states += (state.index_select(0, new_order), ) + + if 'attentions' not in encoder_out: + attentions = None + else: + attentions = encoder_out['attentions'] + + return OFAEncoderOutput( + last_hidden_state=new_encoder_out, # B x T x C + padding_mask=new_encoder_padding_mask, # B x T + hidden_states=new_encoer_states, # List[T x B x C] + attentions=attentions, + position_embedding=new_position_embeddings # B x T x C + ) + + def forward( + self, + input_ids=None, + patch_images: Optional[torch.Tensor] = None, + patch_images_2: Optional[torch.Tensor] = None, + patch_masks: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + sample_patch_num: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + output_attentions (`bool`): whether to return all attention weights, + output_hidden_states (`bool`): whether to return all hidden states. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + + Returns: + [`OFAEncoderOutput`]: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the states of the last layer. + padding_mask (`torch.BoolTensor` of shape `(bsz, seq_len)`): + the padding mask of the source context. + hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the states of all layers including the embeddings. + attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): + the attention weights of all layers. + position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + positional embeddings of the input image and tokens. + """ + image_embed = None + image_embed_2 = None + image_pos_embed = None + image_pos_embed_2 = None + if patch_images is not None: + image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ + self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device) + image_padding_mask[~patch_masks] = True + if patch_images_2 is not None: + image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ + self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device) + image_padding_mask_2[~patch_masks] = True + + encoder_padding_mask = input_ids.eq(self.padding_idx) + if patch_images is not None: + encoder_padding_mask = torch.cat( + [image_padding_mask, encoder_padding_mask], dim=1) + if patch_images_2 is not None: + encoder_padding_mask = torch.cat( + [image_padding_mask_2, encoder_padding_mask], dim=1) + has_pads = encoder_padding_mask.any() + + pos_embed = self.embed_positions(new_arange(input_ids)) + x, encoder_embedding = self.forward_embedding( + input_ids, image_embed, image_embed_2, token_embeddings, pos_embed, + image_pos_embed, image_pos_embed_2) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + if self.use_ofasys: + if patch_images is not None: + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + else: + pos_embed = self.pos_ln(pos_embed) + if patch_images is not None: + image_pos_embed = self.image_pos_ln(image_pos_embed) + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + + def build_abs_pos_bias(pos_embed): + batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) + if not (self.use_ofasys and self.entangle_position_embedding): + pos_q = self.pos_q_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + seq_length, + seq_length, + dtype=pos_embed.dtype, + device=pos_embed.device) + return abs_pos_bias + + abs_pos_bias = build_abs_pos_bias(pos_embed) + + # expand attention_mask + if has_pads: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(encoder_padding_mask, dtype=x.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # encoder layers + for idx, layer in enumerate(self.layers): + if output_hidden_states: + encoder_states += (x, ) + self_attn_bias = abs_pos_bias.clone() + + real_idx = 0 if self.share_attn_bias else idx + + self_attn_bias[:, :, -input_ids.size(1):, + -input_ids.size(1):] += self.get_rel_pos_bias( + input_ids, real_idx) + if patch_images_2 is not None: + self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ + self.get_image_rel_pos_bias(image_position_ids_2, real_idx) + self_attn_bias[:, :, + image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa + image_num_patches_2:image_num_patches_2 + image_num_patches] += \ + self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa + elif patch_images is not None: + self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ + self.get_image_rel_pos_bias(image_position_ids, real_idx) + self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) + + hidden_outputs = layer( + x, + attention_mask if has_pads else None, + attn_bias=self_attn_bias, + output_attentions=output_attentions) + x = hidden_outputs[0] + + if output_attentions: + attention = hidden_outputs[1] + all_attentions = all_attentions + (attention, ) + + if output_hidden_states: + encoder_states += (x, ) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return OFAEncoderOutput( + last_hidden_state=x, + padding_mask=encoder_padding_mask, + hidden_states=encoder_states, + attentions=all_attentions, + position_embedding=pos_embed) + + +class OFADecoder(OFAPreTrainedModel): + r""" + OFA decoder consisting of layers of [`OFADecoderLayer`] + + Args: + config: OFAConfig + embed_tokens (`nn.Embedding`, *optional*): output embedding + """ + + def __init__(self, + config: OFAConfig, + embed_tokens: Optional[nn.Embedding] = None, + output_projection=None): + super().__init__(config) + self.dropout = nn.Dropout(config.dropout) + self.decoder_layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self._future_mask = torch.empty(0) + self.share_input_output_embed = config.share_decoder_input_output_embed + self.num_attention_heads = config.decoder_attention_heads + self.use_ofasys = config.use_ofasys + self.disable_entangle = config.disable_entangle + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, + self.padding_idx) + + self.embed_dim = config.d_model + self.output_embed_dim = config.d_model + + self.layers = nn.ModuleList( + [OFADecoderLayer(config) for _ in range(config.decoder_layers)]) + if config.layernorm_embedding: + self.layernorm_embedding = LayerNorm(self.embed_dim) + else: + self.layernorm_embedding = None + + if config.use_ofasys: + if config.add_type_embedding: + self.type_embedding = Embedding( + 1, self.embed_dim, padding_idx=None) + else: + self.type_embedding = None + + self.window_size = config.code_image_size // 8 + + self.embed_positions = Embedding(self.max_target_positions + 2, + self.embed_dim) + + if not config.use_ofasys: + self.embed_image_positions = Embedding( + config.image_bucket_size**2 + 1, self.embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(self.embed_dim) + self.image_pos_ln = LayerNorm(self.embed_dim) + self.pos_scaling = float(self.embed_dim / self.num_attention_heads + * config.attn_scale_factor)**-0.5 + + if not (config.use_ofasys and config.entangle_position_embedding): + self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + + self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + + if config.code_layernorm_embedding: + self.code_layernorm_embedding = LayerNorm(self.embed_dim) + else: + self.code_layernorm_embedding = None + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + + dpr = [ + x.item() for x in torch.linspace(0, config.decoder_drop_path_rate, + config.decoder_layers) + ] + self.layers.extend([ + OFADecoderLayer(config, drop_path_rate=dpr[i]) + for i in range(config.decoder_layers) + ]) + self.num_layers = len(self.layers) + + if config.decoder_normalize_before: + self.layer_norm = LayerNorm(self.embed_dim) + else: + self.layer_norm = None + + self.adaptive_softmax = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(config) + + self.token_bucket_size = config.token_bucket_size + token_num_rel_dis = 2 * config.token_bucket_size - 1 + token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers + self.token_rel_pos_table_list = nn.ModuleList([ + Embedding( + token_num_rel_dis, self.num_attention_heads, zero_init=True) + for _ in range(num_rel_pos_tables) + ]) + + if config.use_image_feature: + if not config.use_ofasys: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( + 2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ + torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa + image_position_idx = torch.cat( + [torch.tensor([0]), + image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + self.register_buffer('image_rp_bucket', image_rp_bucket) + + self.register_buffer('token_rp_bucket', token_rp_bucket) + self.entangle_position_embedding = config.entangle_position_embedding + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def build_output_projection(self, config): + if self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, config.vocab_size, bias=False) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=self.output_embed_dim**-0.5) + + def get_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the text, for attention. + """ + + seq_len = x.size(1) + rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] + values = F.embedding(rp_bucket, + self.token_rel_pos_table_list[idx].weight) + values = values.permute([2, 0, 1]) + return values.contiguous() + + def get_image_rel_pos_bias(self, x, idx): + r""" + Get the relative positional bias of the image, for attention. + """ + + seq_len = x.size(1) + image_position_idx = self.image_position_idx[:seq_len] + rp_bucket = self.image_rp_bucket[ + image_position_idx][:, image_position_idx] + values = F.embedding(rp_bucket, + self.image_rel_pos_table_list[idx].weight) + values = values.permute(2, 0, 1) + return values + + def get_pos_info(self, tgt_pos_embed, src_pos_embed=None, use_image=False): + r""" + Get the positional information. + + Args: + tgt_pos_embed (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): + the target-side positional embeddings. + src_pos_embed (`torch.FloatTensor` of shape `(bsz, src_len, embed_dim)`, *optional*): + the source-side positional embeddings. + use_image (`bool`): whether to use image. + + Returns: + abs_pos_bias (`torch.FloatTensor` of shape `(bsz, src_len, tgt_len, src_len)`): + absolute positional bias for attention. + """ + + batch_size = tgt_pos_embed.size(0) + tgt_len = tgt_pos_embed.size(1) + if not self.use_ofasys: + tgt_pos_embed = self.image_pos_ln( + tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) + + if src_pos_embed is not None: + src_len = src_pos_embed.size(1) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embed).view( + batch_size, src_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + src_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) + else: + # batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.self_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.self_pos_k_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + tgt_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) + + return abs_pos_bias + + def get_input_embeddings(self): + r""" + Get the input embeddings + """ + return self.embed_tokens + + def set_input_embeddings(self, value): + r""" + Set the weights of the embeddings with the given tensor. + """ + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + dtype, past_key_values_length): + r""" + Create causal mask for unidirectional decoding. + [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + dtype, + past_key_values_length=past_key_values_length).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return self.max_target_positions + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, + 'adaptive_softmax') and self.adaptive_softmax is not None: + if sample is not None: + assert 'target' in sample + target = sample['target'] + else: + target = None + out = self.adaptive_softmax.get_log_prob( + net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1) + else: + return utils.softmax(logits, dim=-1) + + def reorder_incremental_state_scripting( + self, + past_key_values: Optional[torch.Tensor], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. + + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. + """ + input_buffer = past_key_values + new_past_key_values = [] + if input_buffer is not None: + for input_buffer_k in input_buffer: + new_input_buffer_k = [] + for input in input_buffer_k: + if input is None: + input = None + else: + input = input.index_select(0, new_order) + new_input_buffer_k.append(input) + new_past_key_values.append(new_input_buffer_k) + return new_past_key_values + + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + code_masks: Optional[torch.Tensor] = None, + src_pos_embed: torch.Tensor = None, + past_key_values: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): mask to avoid attention on padding tokens. + encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden state of the encoder. + encoder_attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): the padding mask of the source side. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + src_pos_embed (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the positional embeddings of the source side. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + + Returns: + BaseModelOutputWithPastAndCrossAttentions or a plain tuple: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden states. + past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. + hidden_states (`tuple(torch.FloatTensor)`): hidden states of all layers. + attentions (`tuple(torch.FloatTensor)): self attention weights of all layers. + cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. + """ # noqa + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if past_key_values is not None and len(past_key_values) > 0: + size = past_key_values[0][0].size() + bsz, tgt_len = size[0], size[-2] + 1 + token_position_idx = torch.arange( + tgt_len, + device=input_ids.device).expand([bsz, tgt_len]).contiguous() + else: + bsz, tgt_len = input_ids.shape + token_position_idx = new_arange(input_ids) + tgt_pos_embed = self.embed_positions(token_position_idx) + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:input_ids.size( + 1)].unsqueeze(0).expand(bsz, tgt_len) + tgt_pos_embed[code_masks] = self.embed_image_positions( + image_position_idx)[code_masks] + + # self attn position bias + self_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=False) + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, use_image=True) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + # cross attn position bias + cross_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, src_pos_embed=src_pos_embed) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info( + tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[ + code_masks] + cross_abs_pos_bias = cross_abs_pos_bias.reshape( + -1, + *cross_abs_pos_bias.size()[-2:]) + + all_prev_output_tokens = input_ids.clone() + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] + tgt_pos_embed = tgt_pos_embed[:, -1:, :] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(input_ids) + + if self.entangle_position_embedding and not self.disable_entangle: + x += tgt_pos_embed + + if self.layernorm_embedding is not None: + if code_masks is None or not code_masks.any( + ) or not self.code_layernorm_embedding: + x = self.layernorm_embedding(x) + elif code_masks is not None and code_masks.all(): + x = self.code_layernorm_embedding(x) + else: + x[~code_masks] = self.layernorm_embedding(x[~code_masks]) + x[code_masks] = self.code_layernorm_embedding(x[code_masks]) + + hidden_states = self.dropout(x) + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None and len( + past_key_values) > 0 else 0 + + shape, dtype = input_ids.shape, hidden_states.dtype + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, shape, dtype, past_key_values_length) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if ( + output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # decoder layers + for idx, layer in enumerate(self.layers): + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + past_key_value = past_key_values[ + idx] if past_key_values is not None and len( + past_key_values) > 0 else None + + self_attn_bias = self_abs_pos_bias.clone() + real_idx = 0 if self.share_attn_bias else idx + if code_masks is None or not code_masks.any(): + self_attn_bias += self.get_rel_pos_bias( + all_prev_output_tokens, real_idx).unsqueeze(0) + elif code_masks is not None and code_masks.all(): + self_attn_bias += self.get_image_rel_pos_bias( + all_prev_output_tokens, real_idx).unsqueeze(0) + else: + self_attn_bias[~code_masks] += self.get_rel_pos_bias( + all_prev_output_tokens, real_idx).unsqueeze(0) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias( + all_prev_output_tokens, real_idx).unsqueeze(0) + self_attn_bias = self_attn_bias.reshape( + -1, + *self_attn_bias.size()[-2:]) + if past_key_value is not None and len(past_key_values) > 0: + self_attn_bias = self_attn_bias[:, -1:, :] + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[3 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2], ) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + + if self.layer_norm is not None: + hidden_states = self.layer_norm(hidden_states) + + if self.output_projection is not None: + hidden_states = self.output_projection(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, # (bz, + past_key_values=next_cache, # (bz, n_heads, seq_len, head_dim) + hidden_states=all_hidden_states, + attentions=all_self_attns, # (bz, n_heads, tgt_len, src_len) + cross_attentions= # noqa + all_cross_attentions # (bz, n_heads, tgt_len, src_len) # noqa + ) + + +@add_start_docstrings( + 'The bare OFA Model outputting raw hidden-states without any specific head on top.', + OFA_START_DOCSTRING, +) +class OFAModel(OFAPreTrainedModel): + r""" + The OFA model built with an encoder and a decoder only, without any classification head. + + Args: + config (OFAConfig): OFA configuration. + """ + + def __init__(self, config: OFAConfig, **kwargs): + super().__init__(config) + self.disable_entangle = getattr(kwargs, 'disable_entangle', False) + + self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size + shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx) + + self.encoder = OFAEncoder(config, shared) + self.decoder = OFADecoder(config, shared) + self.use_ofasys = config.use_ofasys + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + r""" + Retrieve input embeddings. + """ + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, value): + r""" + Set values for input embeddings + """ + shared = value + self.encoder.embed_tokens = shared + self.decoder.embed_tokens = shared + + def get_encoder(self): + r""" + Retrieve the encoder + """ + return self.encoder + + def get_decoder(self): + r""" + Retrieve the decoder + """ + return self.decoder + + @add_start_docstrings_to_model_forward(OFA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + # an adaptor for fairseq generator + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, + sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + if hasattr(self, 'decoder'): + return self.decoder.get_normalized_probs(net_output, log_probs, + sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def forward(self, + input_ids=None, + patch_images=None, + patch_images_2=None, + patch_masks=None, + token_embeddings=None, + sample_patch_num=None, + decoder_input_ids=None, + code_masks=None, + attention_mask=None, + encoder_outputs=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=False): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): + indices of input sequence tokens in the vocabular, and padding will be ignored by default; + + indices can be obtained using [`~OFATokenizer`]. + + patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the resized image, which are transformed by the default operations. + patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): + the second (if it exists) image. + patch_masks (`torch.BoolTensor`): the patches to be masked. + token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings. + sample_patch_num (`int`): the number of patches to sample. + decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary. + code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation. + attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding. + encoder_outputs (`OFAEncoderOutput`): + encoder outputs with hidden states, positional embeddings, and padding masks. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of + shape `(bsz, num_heads, src_len, head_size)`. + use_cache (`bool`): whether to use cache for faster inference. + output_attentions (`bool`): whether to output attention weights. + output_hidden_states (`bool`): whether to output hidden states. + return_dict (`bool`): unused. Keep it for generation only. + + Returns: + Seq2SeqModelOutput: + last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last decoder hidden states. + past_key_values (`tuple(tuple(torch.FloatTensor)): past keys and values for faster inference. + decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers. + decoder_attentions (`tuple(torch.FloatTensor)): the decoder self attention weights of all layers. + cross_attentions (`tuple(torch.FloatTensor)): cross attention weights of all layers. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder last hidden state. + encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): + the encoder states of all layers including the embeddings. + encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`): + the encoder attention weights of all layers. + """ # noqa + + output_attentions = output_attentions if output_attentions else self.config.output_attentions + output_hidden_states = ( + output_hidden_states + if output_hidden_states else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + patch_images=patch_images, + patch_images_2=patch_images_2, + patch_masks=patch_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + token_embeddings=token_embeddings, + sample_patch_num=sample_patch_num, + ) + + if decoder_input_ids.eq(self.config.pad_token_id).any(): + attention_mask = decoder_input_ids.eq(self.padding_idx) + + encoder_hidden_states = encoder_outputs.last_hidden_state + encoder_attention_mask = _expand_mask(encoder_outputs.padding_mask, + encoder_hidden_states.dtype, + decoder_input_ids.shape[-1]) + src_pos_embed = encoder_outputs.position_embedding + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + code_masks=code_masks, + src_pos_embed=src_pos_embed, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return Seq2SeqLMOutput( + logits=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + decoder_input_ids=None, + past=None, + attention_mask=None, + code_masks=None, + use_cache=False, + encoder_outputs=None, + **kwargs): + # if attention_mask is None: + attention_mask = decoder_input_ids.new_zeros(decoder_input_ids.shape) + + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + 'input_ids': None, + 'patch_images': None, + 'patch_images_2': None, + 'patch_masks': None, + 'token_embeddings': None, + 'sample_patch_num': None, + 'attention_mask': attention_mask, + 'encoder_outputs': encoder_outputs, + 'past_key_values': past, + 'decoder_input_ids': decoder_input_ids, + 'code_masks': code_masks, + 'use_cache': use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None): + # 1. get encoder + encoder = self.get_encoder() + + # 2. prepare encoder args and encoder kwargs from model kwargs + irrelevant_prefix = [ + 'decoder_', 'cross_attn', 'use_cache', 'attention_mask' + ] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + + if encoder_kwargs.get('patch_masks') is None: + encoder_kwargs['patch_masks'] = torch.tensor([True]) + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs['encoder_outputs']: ModelOutput = encoder( + **encoder_kwargs) + model_kwargs['attention_mask'] = None + + return model_kwargs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + **model_kwargs, + ): + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat( + 1, expand_size).view(-1).to(input_ids.device)) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if 'token_type_ids' in model_kwargs: + token_type_ids = model_kwargs['token_type_ids'] + model_kwargs['token_type_ids'] = token_type_ids.index_select( + 0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs['attention_mask'] = attention_mask.index_select( + 0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError( + 'If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.' + ) + encoder_outputs[ + 'last_hidden_state'] = encoder_outputs.last_hidden_state.index_select( + 0, + expanded_return_idx.to( + encoder_outputs.last_hidden_state.device)) + encoder_outputs[ + 'position_embedding'] = encoder_outputs.position_embedding.index_select( + 0, + expanded_return_idx.to( + encoder_outputs.position_embedding.device)) + encoder_outputs[ + 'padding_mask'] = encoder_outputs.padding_mask.index_select( + 0, + expanded_return_idx.to( + encoder_outputs.padding_mask.device)) + model_kwargs['encoder_outputs'] = encoder_outputs + return input_ids, model_kwargs diff --git a/modelscope/models/multi_modal/ofa/resnet.py b/modelscope/models/multi_modal/ofa/resnet.py new file mode 100644 index 00000000..aad0f002 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/resnet.py @@ -0,0 +1,297 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a.sh different form of dropout in a.sh separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a.sh layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + assert False + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + drop_path_rate=0.0): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = identity + self.drop_path(out) + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + drop_path_rate=0.0): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate) + self.layer2 = self._make_layer( + Bottleneck, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + drop_path_rate=drop_path_rate) + self.layer3 = self._make_layer( + Bottleneck, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + drop_path_rate=drop_path_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilate=False, + drop_path_rate=0.0): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)] + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + drop_path_rate=dpr[i])) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + def forward(self, x): + return self._forward_impl(x) diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py new file mode 100644 index 00000000..fd50505c --- /dev/null +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py @@ -0,0 +1,372 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OFA.""" +import collections +import os +from typing import List, Optional, Tuple + +from transformers import PreTrainedTokenizer +from transformers.models.bart.tokenization_bart import BartTokenizer +from transformers.models.bert.tokenization_bert import (BasicTokenizer, + WordpieceTokenizer) +from transformers.utils import logging + +from modelscope.utils.constant import ModelFile + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': 'vocab.json', 'merges_file': 'merges.txt'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', + }, + 'merges_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', + }, + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'ofa-base': 1024, +} + +VOCAB_FILES_NAMES_ZH = {'vocab_file': ModelFile.VOCAB_FILE} + +PRETRAINED_VOCAB_FILES_MAP_ZH = { + 'vocab_file': { + 'bert-base-chinese': + 'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt', + } + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = { + 'ofa-base': 1024, +} + +PRETRAINED_INIT_CONFIGURATION_ZH = { + 'bert-base-chinese': { + 'do_lower_case': True + }, +} + + +class OFATokenizer(BartTokenizer): + """ + Construct a OFA tokenizer. + + [`~OFATokenizer`] is identical to [`BartTokenizer`] and runs end-to-end tokenization: punctuation splitting and + wordpiece. + + Refer to superclass [`BartTokenizer`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, 'r', encoding='utf-8') as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip('\n') + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class OFATokenizerZH(PreTrainedTokenizer): + r""" + Construct a OFA tokenizer. Based on WordPiece. + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES_ZH + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH + + def __init__(self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + bos_token='', + eos_token='', + sep_token='', + cls_token='', + unk_token='', + pad_token='', + mask_token='', + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs): + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=self.unk_token) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = ' '.join(tokens).replace(' ##', '').strip() + return out_string + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, + (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + else: + vocab_file = (filename_prefix + + '-' if filename_prefix else '') + save_directory + with open(vocab_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted( + self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.' + ' Please check that the vocabulary is not corrupted!') + index = token_index + writer.write(token + '\n') + index += 1 + return (vocab_file, ) diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py new file mode 100644 index 00000000..db11370d --- /dev/null +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa_fast.py @@ -0,0 +1,215 @@ +# Copyright 2022 OFA-Sys Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OFA.""" +from typing import List, Optional, Tuple + +import json +from tokenizers import normalizers +from transformers import PreTrainedTokenizerFast +from transformers.models.bart.tokenization_bart_fast import BartTokenizerFast +from transformers.utils import logging + +from modelscope.utils.constant import ModelFile +from .tokenization_ofa import OFATokenizer, OFATokenizerZH + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt', + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/vocab.json', + }, + 'merges_file': { + 'ofa-base': 'https://huggingface.co/ofa-base/resolve/main/merges.txt', + }, + 'tokenizer_file': { + 'ofa-base': + 'https://huggingface.co/ofa-base/resolve/main/tokenizer.json', + }, + # OFA models are implemented to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'ofa-base': 1024, +} + +VOCAB_FILES_NAMES_ZH = {'vocab_file': ModelFile.VOCAB_FILE} + +PRETRAINED_VOCAB_FILES_MAP_ZH = { + 'vocab_file': { + 'bert-base-chinese': + 'https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt', + } + # OFA models are implemeted to be compatible with both huggingface + # and modelscope frameworks. For all OFA models available on huggingface, + # please refer to https://huggingface.co/models?filter=ofa +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH = { + 'ofa-base': 1024, +} + +PRETRAINED_INIT_CONFIGURATION_ZH = { + 'bert-base-chinese': { + 'do_lower_case': True + }, +} + + +class OFATokenizerFast(BartTokenizerFast): + r""" + Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library). + + [`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting + and wordpiece. + + Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = OFATokenizer + + +class OFATokenizerZHFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" OFA tokenizer (backed by HuggingFace's *tokenizers* library). + + [`~OFATokenizerFast`] is identical to [`BartTokenizerFast`] and runs end-to-end tokenization: punctuation splitting + and wordpiece. + + Refer to superclass [`BartTokenizerFast`] for usage examples and documentation concerning parameters. + """ + vocab_files_names = VOCAB_FILES_NAMES_ZH + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_ZH + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION_ZH + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES_ZH + slow_tokenizer_class = OFATokenizerZH + + def __init__(self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + bos_token='', + eos_token='', + sep_token='', + cls_token='', + unk_token='', + pad_token='', + mask_token='', + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads( + self.backend_tokenizer.normalizer.__getstate__()) + if (normalizer_state.get('lowercase', do_lower_case) != do_lower_case + or normalizer_state.get('strip_accents', strip_accents) + != strip_accents or normalizer_state.get( + 'handle_chinese_chars', + tokenize_chinese_chars) != tokenize_chinese_chars): + normalizer_class = getattr(normalizers, + normalizer_state.pop('type')) + normalizer_state['lowercase'] = do_lower_case + normalizer_state['strip_accents'] = strip_accents + normalizer_state['handle_chinese_chars'] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class( + **normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save( + save_directory, name=filename_prefix) + return tuple(files) diff --git a/modelscope/models/multi_modal/ofa/utils/__init__.py b/modelscope/models/multi_modal/ofa/utils/__init__.py new file mode 100644 index 00000000..b937315b --- /dev/null +++ b/modelscope/models/multi_modal/ofa/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py new file mode 100644 index 00000000..b3776f8f --- /dev/null +++ b/modelscope/models/multi_modal/ofa/utils/constant.py @@ -0,0 +1,14 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +OFA_TASK_KEY_MAPPING = { + Tasks.ocr_recognition: OutputKeys.TEXT, + Tasks.image_captioning: OutputKeys.CAPTION, + Tasks.text_summarization: OutputKeys.TEXT, + Tasks.visual_question_answering: OutputKeys.TEXT, + Tasks.visual_grounding: OutputKeys.BOXES, + Tasks.text_classification: OutputKeys.LABELS, + Tasks.image_classification: OutputKeys.LABELS, + Tasks.visual_entailment: OutputKeys.LABELS, +} diff --git a/modelscope/models/multi_modal/ofa/utils/utils.py b/modelscope/models/multi_modal/ofa/utils/utils.py new file mode 100644 index 00000000..c5aa8483 --- /dev/null +++ b/modelscope/models/multi_modal/ofa/utils/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Optional + +import torch +import torch.nn as nn + + +def expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + r""" + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + return expanded_mask.masked_fill(expanded_mask.bool(), + torch.finfo(dtype).min) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + x (`nn.Modules`): input nn layers. + drop_prob (`float`): drop path ratio. + training (`bool`): whether is training or inference. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (1, x.shape[1], 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + drop_prob: drop path ratio. + """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) diff --git a/modelscope/models/multi_modal/ofa/vit.py b/modelscope/models/multi_modal/ofa/vit.py new file mode 100644 index 00000000..b6bba7ee --- /dev/null +++ b/modelscope/models/multi_modal/ofa/vit.py @@ -0,0 +1,155 @@ +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from fairseq.modules import LayerNorm +from torch import nn + +from .utils.utils import DropPath + +__all__ = [ + 'vit_base', + 'vit_large', + 'vit_large_336', + 'vit_huge', +] + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path_rate=0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([ + ('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model)), + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath(drop_path_rate) + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None else None) + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.drop_path(self.attention(self.ln_1(x))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.input_resolution = input_resolution + self.patch_size = patch_size + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.width = width + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer( + width, layers, heads, drop_path_rate=drop_path_rate) + + def forward(self, x: torch.Tensor): + resolution = x.shape[-2] + height, width = x.shape[-2] // self.patch_size, x.shape[ + -1] // self.patch_size + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + if resolution != self.input_resolution: + old_pe = self.positional_embedding[1:] + patch_num = self.input_resolution // self.patch_size + old_pe = old_pe.reshape(1, patch_num, patch_num, + -1).permute(0, 3, 1, 2) + new_pe = F.interpolate( + old_pe, size=(height, width), mode='bilinear') + new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) + x = x + new_pe.to(x.dtype) + else: + x = x + self.positional_embedding[1:].to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + bz, seq, hidden = x.shape + x = x.transpose(1, 2).reshape(bz, hidden, height, width) + + return x + + +def vit_base(drop_path_rate: float = 0.0): + return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate) + + +def vit_large(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate) + + +def vit_large_336(drop_path_rate: float = 0.0): + return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate) + + +def vit_huge(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py new file mode 100644 index 00000000..fc578b25 --- /dev/null +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -0,0 +1,332 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +import re +import string +from functools import partial +from os import path as osp +from typing import Any, Callable, Dict, List, Optional, Union + +import json +import torch.cuda +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.ofa.utils.collate import collate_tokens +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.trie import Trie +from .ofa import OFAModel, OFATokenizer, OFATokenizerZH +from .ofa.generate import sequence_generator as sg +from .ofa.generate.utils import move_to_device +from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks +from .ofa.utils.utils import expand_mask + +__all__ = ['OfaForAllTasks'] + + +@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) +@MODELS.register_module(Tasks.ocr_recognition, module_name=Models.ofa) +@MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) +@MODELS.register_module( + Tasks.visual_question_answering, module_name=Models.ofa) +@MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) +@MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) +@MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) +@MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) +class OfaForAllTasks(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir=model_dir, *args, **kwargs) + model = OFAModel.from_pretrained(model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.model = model.module if hasattr(model, 'module') else model + self.language = self.cfg.model.get('language', 'en') + if self.language == 'en': + self.tokenizer = OFATokenizer.from_pretrained(model_dir) + elif self.language in ['zh', 'cn']: + self.tokenizer = OFATokenizerZH.from_pretrained(model_dir) + else: + raise NotImplementedError + # there is some diff between here and our ofa code, + # there will be no need to use param: use_bpe + if not model.use_ofasys: + self.tokenizer.add_tokens( + [''.format(i) for i in range(8192)]) + self.tokenizer.add_tokens( + [''.format(i) for i in range(1000)]) + self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) + self.batch_size = self.cfg.model.get('batch_size', 1) + self.patch_image_size = self.cfg.model.get('patch_image_size', 480) + self.max_image_size = self.cfg.model.get('max_image_size', 512) + self.val_batch_size = self.cfg.model.get('valid_batch_size', + self.batch_size) + self.transtab = str.maketrans( + {key: None + for key in string.punctuation}) + self.gen_type = self.cfg.model.get('gen_type', 'generation') + assert self.gen_type in ['generation', 'traverse'], \ + 'model.gen_type must be in ["generation", "traverse"]' + self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id]) + self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id]) + self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id]) + self.index2ans = {} + self.ans2label_dict = {} + self.load_ans2label() + # Initialize generator + sg_args = { + 'tokenizer': self.tokenizer, + 'beam_size': 5, + 'max_len_b': 16, + 'min_len': 1, + 'no_repeat_ngram_size': 3, + 'constraint_range': None + } + if hasattr(self.cfg.model, 'beam_search'): + sg_args.update(self.cfg.model.beam_search) + if len(self.ans2label_dict) > 0: + self.constraint_trie = Trie(self.tokenizer.eos_token_id) + self.val_ans_l = [] + self.val_masks_l = [] + self.build_trie() + sg_args['constraint_trie'] = self.constraint_trie + else: + self.constraint_trie = None + self.generator = sg.SequenceGenerator(**sg_args) + inference_d = { + 'generation': self._text_gen_inference, + 'traverse': self._traverse_inference, + } + self.task_inference_mapping = { + Tasks.ocr_recognition: self._text_gen_inference, + Tasks.image_captioning: self._text_gen_inference, + Tasks.text_summarization: self._text_gen_inference, + Tasks.visual_grounding: self._visual_grounding_inference, + Tasks.visual_entailment: inference_d[self.gen_type], + Tasks.visual_question_answering: inference_d[self.gen_type], + Tasks.text_classification: inference_d[self.gen_type], + Tasks.image_classification: inference_d[self.gen_type], + } + pattern_str = '((?<=[^ a-zA-Z0-9.,:!?]) +| +(?=[^ a-zA-Z0-9.,:!?]))' + self.pattern = re.compile(pattern_str) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + input = move_to_device(input, self.model.device) + if self.model.training: + return self.model(**input['net_input']) + else: + return self.inference(input) + + def inference(self, input: Dict[str, Any]) -> Dict[str, Any]: + ret = self.task_inference_mapping[self.cfg.task](input) + if 'samples' in input: + ret['samples'] = input['samples'] + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key not in ret: + ret[key] = None + return ret + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if not self.model.training and self.cfg.task == Tasks.image_captioning: + caption = input[OutputKeys.CAPTION] + result_l = list() + for cap in caption: + if self.language == 'en': + result_l.append(cap.translate(self.transtab).strip()) + else: + result_l.append(cap) + input[OutputKeys.CAPTION] = result_l + if self.gen_type == 'generation' and self.language in [ + 'zh', 'cn' + ] and self.cfg.task != Tasks.visual_grounding: + ret_l = list() + for text in input[OFA_TASK_KEY_MAPPING[self.cfg.task]]: + ret_l.append(self.detokenizer(text)) + input[OFA_TASK_KEY_MAPPING[self.cfg.task]] = ret_l + return input + + def _text_gen_inference(self, input): + gen_outputs = self.generator.generate([self.model], + input, + prefix_tokens=input.get( + 'prefix_tokens', None)) + gen_l = list() + for idx, gen_out in enumerate(gen_outputs): + if len(gen_out) > 0: + decode_tokens = gen_out[0]['tokens'] + if 'prefix_tokens' in input: + prefix_len = input['prefix_tokens'][idx].ne( + self.pad_item.to(self.model.device)).sum() + decode_tokens = decode_tokens[prefix_len:] + gen_l.append(decode_tokens) + else: + gen_l.append('') + result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True) + result = [item.strip() for item in result] + # text generation tasks have no score + ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result} + if self.cfg.task.endswith('classification'): + ret[OutputKeys.SCORES] = [1.0] * len(result) + return ret + + def _visual_grounding_inference(self, input): + gen_output = self.generator.generate([self.model], input) + tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))] + region_coord_l = list() + for i in range(len(tokens)): + region_coord_l.append(tokens[i][:-1] + - len(self.tokenizer.get_vocab().items()) + + self.cfg.num_bins) + region_tensor = torch.stack(region_coord_l, dim=0) + region_tensor = region_tensor / (self.cfg.num_bins + - 1) * self.max_image_size + region_tensor[:, ::2] /= input['w_resize_ratios'] + region_tensor[:, 1::2] /= input['h_resize_ratios'] + return { + OutputKeys.BOXES: + move_to_device(region_tensor, torch.device('cpu')).tolist(), + OutputKeys.SCORES: [1.0] * region_tensor.shape[0] + } + + def _traverse_inference(self, input): + encoder_input = dict() + for key in input['net_input'].keys(): + encoder_input[key] = input['net_input'][key] + encoder_out = self.model.encoder(**encoder_input) + valid_result = [] + for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l): + valid_size = len(val_ans) + valid_tgt_items = [ + torch.cat([ + torch.tensor(decoder_prompt[1:]).to('cpu'), valid_answer, + self.eos_item + ]) for decoder_prompt in input['decoder_prompts'] + for valid_answer in val_ans + ] + valid_prev_items = [ + torch.cat( + [torch.tensor(decoder_prompt).to('cpu'), valid_answer]) + for decoder_prompt in input['decoder_prompts'] + for valid_answer in val_ans + ] + valid_constraint_mask_items = [ + torch.cat([ + torch.zeros( + len(decoder_prompt) - 1, + valid_constraint_mask.size(1)).bool(), + valid_constraint_mask], dim=0) # yapf: disable + for decoder_prompt in input['decoder_prompts'] # yapf: disable + for valid_constraint_mask in val_masks] # yapf: disable + valid_tgt = collate_tokens( + valid_tgt_items, + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) + valid_prev_output = collate_tokens( + valid_prev_items, + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) + val_masks = collate_tokens( + valid_constraint_mask_items, + pad_idx=self.tokenizer.pad_token_id).to(self.model.device) + new_encoder_out = { + 'last_hidden_state': + encoder_out['last_hidden_state'].repeat_interleave( + valid_size, dim=0), + 'padding_mask': + encoder_out['padding_mask'].repeat_interleave( + valid_size, dim=0), + 'position_embedding': + encoder_out['position_embedding'].repeat_interleave( + valid_size, dim=0) + } + encoder_attention_mask = expand_mask( + new_encoder_out['padding_mask'], + new_encoder_out['last_hidden_state'].dtype, + valid_prev_output.shape[-1]) + + decoder_out = self.model.decoder( + valid_prev_output, + encoder_hidden_states=new_encoder_out['last_hidden_state'], + encoder_attention_mask=encoder_attention_mask, + src_pos_embed=new_encoder_out['position_embedding']) + + decoder_out[0].masked_fill_(~val_masks, -math.inf) + lprobs = self.model.get_normalized_probs( + decoder_out, log_probs=True) + scores = lprobs.gather( + dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1) + scores = scores.masked_fill( + valid_tgt.eq(self.tokenizer.pad_token_id), 0) + scores = scores.masked_fill((~val_masks).all(2), 0) + scores = scores.sum(1) + scores = scores.view(-1, valid_size) + valid_result.append(scores) + valid_result = torch.cat(valid_result, dim=-1) + predicts = valid_result.argmax(1).tolist() + probs = F.softmax(valid_result, dim=-1) + hyps = [self.index2ans[predict_index] for predict_index in predicts] + scores = [ + float(prob[idx].cpu().detach().numpy()) + for prob, idx in zip(probs, predicts) + ] + return {OutputKeys.LABELS: hyps, OutputKeys.SCORES: scores} + + def build_trie(self): + answer_item_list = [] + + for i, answer in enumerate(self.ans2label_dict.keys()): + answer_item = self.tokenizer( + ' ' + answer, return_tensors='pt', + add_special_tokens=False).input_ids.squeeze(0) + answer_item_list.append(answer_item) + self.index2ans[i] = answer + self.constraint_trie.insert([self.tokenizer.bos_token_id] + + answer_item.tolist() + + [self.tokenizer.eos_token_id]) + + constraint_mask_list = [] + for answer_item in answer_item_list: + constraint_mask = torch.zeros( + (len(answer_item) + 1, + len(self.tokenizer.get_vocab()))).bool() + for i in range(len(answer_item) + 1): + constraint_prefix_token = [self.tokenizer.bos_token_id + ] + answer_item[:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + constraint_mask_list.append(constraint_mask) + + for i in range(0, len(answer_item_list), self.val_batch_size): + self.val_ans_l += [answer_item_list[i:i + self.val_batch_size]] + self.val_masks_l += [ + constraint_mask_list[i:i + self.val_batch_size] + ] + + def load_ans2label(self): + if self.cfg.model.get('answer2label', None): + ans2label_file = osp.join(self.model_dir, + self.cfg.model.answer2label) + with open(ans2label_file, 'r') as reader: + self.ans2label_dict = json.load(reader) + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + super(OfaForAllTasks, self). \ + save_pretrained(target_folder=target_folder, + save_checkpoint_names=save_checkpoint_names, + save_function=partial(save_function, with_meta=False), + config=config, + **kwargs) + + def detokenizer(self, text): + return self.pattern.sub('', text) diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py new file mode 100644 index 00000000..655d36d2 --- /dev/null +++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py @@ -0,0 +1,246 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from os import path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.cuda +from PIL import Image +from pkg_resources import packaging +from taming.models.vqgan import GumbelVQ, VQModel +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.mmr.models.module_clip import CLIP +from modelscope.models.multi_modal.mmr.models.tokenization_clip import \ + SimpleTokenizer as ClipTokenizer +from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer +from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg +from modelscope.models.multi_modal.ofa.generate.search import Sampling +from modelscope.models.multi_modal.ofa.generate.utils import move_to_device +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +__all__ = ['OfaForTextToImageSynthesis'] + + +def custom_to_pil(x): + x = x.detach().cpu() + x = torch.clamp(x, -1., 1.) + x = (x + 1.) / 2. + x = x.permute(1, 2, 0).numpy() + x = (255 * x).astype(np.uint8) + x = Image.fromarray(x) + if not x.mode == 'RGB': + x = x.convert('RGB') + return x + + +def load_vqgan(config, ckpt_path=None, is_gumbel=False): + if is_gumbel: + model = GumbelVQ(**config['model']['params']) + else: + model = VQModel(**config['model']['params']) + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location='cpu')['state_dict'] + missing, unexpected = model.load_state_dict(sd, strict=False) + return model.eval() + + +def build_clip_model(model_path): + state_dict = torch.load(model_path, map_location='cpu').state_dict() + vit = 'visual.proj' in state_dict + if vit: + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len( + set( + k.split('.')[2] for k in state_dict + if k.startswith(f'visual.layer{b}'))) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0] + output_width = round( + (state_dict['visual.attnpool.positional_embedding'].shape[0] + - 1)**0.5) + vision_patch_size = None + assert output_width**2 + 1 == state_dict[ + 'visual.attnpool.positional_embedding'].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP(embed_dim, image_resolution, vision_layers, vision_width, + vision_patch_size, context_length, vocab_size, + transformer_width, transformer_heads, transformer_layers) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + model.load_state_dict(state_dict) + return model.eval() + + +def _convert_image_to_rgb(image): + return image.convert('RGB') + + +def build_clip_transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + +@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa) +class OfaForTextToImageSynthesis(Model): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir=model_dir, *args, **kwargs) + # Initialize ofa + model = OFAModel.from_pretrained(model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.model = model.module if hasattr(model, 'module') else model + self.tokenizer = OFATokenizer.from_pretrained(model_dir) + self.tokenizer.add_tokens([''.format(i) for i in range(8192)]) + self.tokenizer.add_tokens([''.format(i) for i in range(1000)]) + self._device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + self.model.to(self._device) + + # Initialize vqgan + vqgan_config = json.load( + open(os.path.join(model_dir, 'vqgan_config.json'))) + self.vqgan_model = load_vqgan( + vqgan_config, + ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'), + is_gumbel=True).to(self._device) + + # Initialize OpenAI clip + + self.clip_tokenizer = ClipTokenizer(model_dir) + self.clip_model = build_clip_model( + os.path.join(model_dir, 'ViT-B-16.pt')) + self.clip_preprocess = build_clip_transform( + self.clip_model.visual.input_resolution) + + self.clip_model.to(self._device) + self.clip_model.eval() + + # Initialize generator + sampling = Sampling(self.tokenizer, sampling_topp=0.9) + sg_args = { + 'tokenizer': self.tokenizer, + 'beam_size': 2, + 'max_len_b': 1024, + 'min_len': 1024, + 'search_strategy': sampling, + 'gen_code': True, + 'constraint_range': '50265,58457' + } + if hasattr(self.cfg.model, 'beam_search'): + sg_args.update(self.cfg.model.beam_search) + self.generator = sg.SequenceGenerator(**sg_args) + + def clip_tokenize(self, texts, context_length=77, truncate=False): + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.clip_tokenizer.encoder['<|startoftext|>'] + eot_token = self.clip_tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + self.clip_tokenizer.encode(text) + + [eot_token] for text in texts] + if packaging.version.parse( + torch.__version__) < packaging.version.parse('1.8.0'): + result = torch.zeros( + len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros( + len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def forward(self, input: Dict[str, Any]): + + text = input['samples'][0]['text'] + input = move_to_device(input, self._device) + clip_text_input = self.clip_tokenize([text]).to(self._device) + + gen_output = self.generator.generate([self.model], input) + gen_tokens = torch.stack( + [item['tokens'][:-1] for item in gen_output[0]], dim=0) + codes = gen_tokens.view(-1, 32, 32) - 50265 + + quant_b = self.vqgan_model.quantize.get_codebook_entry( + codes.view(-1), + list(codes.size()) + [self.vqgan_model.quantize.embedding_dim]) + imgs = self.vqgan_model.decode(quant_b) + + sample_num = imgs.size()[0] + pil_imgs = [custom_to_pil(imgs[i]) for i in range(sample_num)] + + clip_image_input = torch.stack( + [self.clip_preprocess(img) for img in pil_imgs], + dim=0).to(self._device) + + with torch.no_grad(): + hyp_image_features = self.clip_model.encode_image(clip_image_input) + hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True) + text_features = self.clip_model.encode_text(clip_text_input) + text_features /= text_features.norm(dim=-1, keepdim=True) + ti_similarity = hyp_image_features @ text_features.T + + sorted_score, ti_indices = torch.sort( + ti_similarity.view(-1), descending=True) + + pil_imgs_orderby_ti = [pil_imgs[index] for index in ti_indices] + return pil_imgs_orderby_ti[0] diff --git a/modelscope/models/multi_modal/team/__init__.py b/modelscope/models/multi_modal/team/__init__.py new file mode 100644 index 00000000..58bbdca5 --- /dev/null +++ b/modelscope/models/multi_modal/team/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from .team_model import TEAMForMultiModalSimilarity diff --git a/modelscope/models/multi_modal/team/team_model.py b/modelscope/models/multi_modal/team/team_model.py new file mode 100644 index 00000000..8c0e288a --- /dev/null +++ b/modelscope/models/multi_modal/team/team_model.py @@ -0,0 +1,127 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from tokenizers import BertWordPieceTokenizer +from torchvision.transforms import Compose, Normalize, Resize, ToTensor + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .utils import TEAM, BertWrapper, CLIPVisionWrapper, CrossLayer + +logger = get_logger() + +__all__ = ['TEAMForMultiModalSimilarity'] + + +@MODELS.register_module(Tasks.multi_modal_similarity, module_name=Models.team) +class TEAMForMultiModalSimilarity(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + text_model = BertWrapper( + config_json='{}/text_config.json'.format(model_dir), + feat_dim=768, + token_dim=1024) + text_model.bert.cls = None + image_model = CLIPVisionWrapper() + + self.model = TEAM( + text_model, + image_model, + pretrained='{}/{}'.format(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.eval() + + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + self.text_tokenizer = BertWordPieceTokenizer( + '{}/{}'.format(model_dir, ModelFile.VOCAB_FILE), lowercase=False) + self.text_tokenizer.enable_truncation(max_length=30) + + norm_op = Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) + self.img_preprocessor = Compose([ + Resize((224, 224), interpolation=Image.BICUBIC), + ToTensor(), norm_op + ]) + + def tokenize_text(self, text_str): + tokens = self.text_tokenizer.encode(text_str) + max_tokens = 30 + text_ids_tensor = torch.zeros((1, max_tokens)).long() + text_mask_tensor = torch.zeros((1, max_tokens)) + text_ids, text_mask = tokens.ids, tokens.attention_mask + text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids) + text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask) + return text_ids_tensor, text_mask_tensor + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + if 'img' in input and input['img'] is not None: + input_img = input['img'] + input_img = LoadImage.convert_to_img(input_img) + img_tensor = self.img_preprocessor(input_img)[None, ...] + + if self.device_id >= 0: + img_tensor = img_tensor.to('cuda:{}'.format( + self.device_id)) + _, _, image_feature, image_tensors = self.model.get_feature( + None, None, img_tensor) + image_feature = image_feature.cpu().numpy() + else: + image_feature, image_tensors = None, None + + if 'text' in input and input['text'] is not None: + text_str = input['text'] + if isinstance(text_str, str): + text_ids_tensor, text_mask_tensor = self.tokenize_text( + text_str) + else: + raise TypeError( + f'text should be str, but got {type(text_str)}') + + if self.device_id >= 0: + text_ids_tensor = text_ids_tensor.to('cuda:{}'.format( + self.device_id)) + text_mask_tensor = text_mask_tensor.to('cuda:{}'.format( + self.device_id)) + text_feature, text_tensors, _, _ = self.model.get_feature( + text_ids_tensor, text_mask_tensor, None) + text_feature = text_feature.cpu().numpy() + else: + text_tensors, text_mask_tensor = None, None + + if text_tensors is not None and text_mask_tensor is not None and image_tensors is not None: + score = self.model.get_cross_score(text_tensors, + text_mask_tensor, + image_tensors)[0].item() + else: + score = None + output = { + OutputKeys.IMG_EMBEDDING: image_feature, + OutputKeys.TEXT_EMBEDDING: text_feature, + OutputKeys.SCORES: score + } + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/models/multi_modal/team/utils.py b/modelscope/models/multi_modal/team/utils.py new file mode 100644 index 00000000..73919179 --- /dev/null +++ b/modelscope/models/multi_modal/team/utils.py @@ -0,0 +1,329 @@ +# Copyright 2021 The OpenAI Team Authors. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# +# The implementation here is modified based on OpenAI CLIP, +# originally MIT License, Copyright (c) 2021 OpenAI, +# and publicly available at https://github.com/openai/CLIP/. + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch import nn +from transformers import BertConfig, BertForMaskedLM + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + use_gc=False): + super().__init__() + self.use_gc = use_gc + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + if self.use_gc: + for each_block in self.resblocks: + x = checkpoint.checkpoint(each_block, x) + return x + else: + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + use_gc=False): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, use_gc=use_gc) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_embedding = self.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_embedding, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIPVisionWrapper(nn.Module): + + def __init__(self, ): + super().__init__() + self.vision_transformer = VisionTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768) + + def forward(self, x): + x = self.vision_transformer.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_embedding, x], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.vision_transformer.positional_embedding.to(x.dtype) + x = self.vision_transformer.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.vision_transformer.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x_tensor = x.clone() + x = self.vision_transformer.ln_post(x[:, 0, :]) + + if self.vision_transformer.proj is not None: + x = x @ self.vision_transformer.proj + + return x, x_tensor + + +class BertWrapper(nn.Module): + + def __init__(self, config_json, feat_dim, token_dim): + super(BertWrapper, self).__init__() + bert_config = BertConfig.from_json_file(config_json) + self.bert = BertForMaskedLM(bert_config).bert + + self.projector = nn.Linear(768, feat_dim, bias=False) + self.projector_token_embeds = nn.Linear(768, token_dim) + + def forward(self, input_ids, attention_mask): + trans_features = { + 'input_ids': input_ids, + 'attention_mask': attention_mask + } + output_states = self.bert(**trans_features, return_dict=False) + output_tokens = output_states[0] + + cls_tokens = output_tokens[:, 0, :] # CLS token is first token + + return self.projector(cls_tokens), self.projector_token_embeds( + output_tokens) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CrossLayer(nn.Module): + + def __init__(self, feat_dim, mlp_ratio): + super(CrossLayer, self).__init__() + self.norm1 = nn.LayerNorm(feat_dim) + self.norm2 = nn.LayerNorm(feat_dim) + self.norm3 = nn.LayerNorm(feat_dim) + + self.self_attn = nn.MultiheadAttention( + embed_dim=feat_dim, num_heads=16) + self.cross_attn = nn.MultiheadAttention( + embed_dim=feat_dim, num_heads=16) + self.ffn = Mlp( + in_features=feat_dim, + hidden_features=feat_dim * mlp_ratio, + drop=0.1) + + self.dropout1 = nn.Dropout(0.1) + self.dropout2 = nn.Dropout(0.1) + self.dropout3 = nn.Dropout(0.1) + + def forward(self, text_tensors, text_masks, image_tensors, + retrieved_tensors): + retrieved_tensors_res = self.norm1(retrieved_tensors) + retrieved_tensors_res = self.self_attn( + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + retrieved_tensors_res.permute(1, 0, 2), + key_padding_mask=(text_masks == 0), + )[0].permute(1, 0, 2) + retrieved_tensors = retrieved_tensors + self.dropout1( + retrieved_tensors_res) + + retrieved_tensors_res = self.norm2(retrieved_tensors) + retrieved_tensors_res = self.cross_attn( + (text_tensors + retrieved_tensors_res).permute(1, 0, 2), + image_tensors.permute(1, 0, 2), + image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2) + retrieved_tensors = retrieved_tensors + self.dropout2( + retrieved_tensors_res) + + retrieved_tensors_res = self.norm3(retrieved_tensors) + retrieved_tensors = retrieved_tensors + self.dropout3( + self.ffn(retrieved_tensors_res)) + + return retrieved_tensors + + +class TEAM(nn.Module): + + def __init__(self, text_model, image_model, pretrained): + super(TEAM, self).__init__() + self.text_model = text_model + self.image_model = image_model + + self.cross_model = nn.ModuleList( + [CrossLayer(feat_dim=1024, mlp_ratio=2)]) + + self.image_tensor_fc = nn.Linear(1024, 768) + self.text_tensor_fc = nn.Linear(1024, 768) + + params = torch.load(pretrained, 'cpu') + self.load_state_dict(params, strict=True) + + def get_feature(self, text_data=None, text_mask=None, img_tensor=None): + if text_data is not None: + text_feature, text_tensors = self.text_model(text_data, text_mask) + text_feature = F.normalize(text_feature, p=2.0, dim=1) + else: + text_feature, text_tensors = None, None + + if img_tensor is not None: + image_feature, image_tensors = self.image_model(img_tensor) + image_feature = F.normalize(image_feature, p=2.0, dim=1) + else: + image_feature, image_tensors = None, None + + return text_feature, text_tensors, image_feature, image_tensors + + def get_cross_score(self, text_tensors, text_mask, image_tensors): + retrieved_tensors = torch.zeros_like(text_tensors) + pair_score_list = [] + text_tensors_proj = self.text_tensor_fc(text_tensors) + text_mask_float = text_mask.type(text_tensors_proj.dtype) + for each_cross_model in self.cross_model: + retrieved_tensors = each_cross_model(text_tensors, text_mask, + image_tensors, + retrieved_tensors) + retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors) + + pair_score = torch.sum( + F.normalize(retrieved_tensors_proj, p=2.0, dim=2) + * F.normalize(text_tensors_proj, p=2.0, dim=2), + dim=2) + pair_score_reduced = torch.sum( + pair_score * text_mask_float, dim=1) / torch.clamp( + torch.sum(text_mask_float, dim=1), min=1.0) + pair_score_list.append(pair_score_reduced) + return pair_score_list diff --git a/modelscope/models/nlp/T5/__init__.py b/modelscope/models/nlp/T5/__init__.py new file mode 100644 index 00000000..cb0921c6 --- /dev/null +++ b/modelscope/models/nlp/T5/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import T5Model + from .text2text_generation import T5ForConditionalGeneration + +else: + _import_structure = { + 'backbone': ['T5Model'], + 'text2text_generation': ['T5ForConditionalGeneration'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/T5/backbone.py b/modelscope/models/nlp/T5/backbone.py new file mode 100644 index 00000000..9a46d980 --- /dev/null +++ b/modelscope/models/nlp/T5/backbone.py @@ -0,0 +1,1531 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_utils import (PreTrainedModel, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, replace_return_docstrings) +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import (BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqModelOutput) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .configuration import T5Config + +logger = get_logger(__name__) + + +################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f'Converting TensorFlow checkpoint from {tf_path}') + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f'Loading TF weight {name} with shape {shape}') + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', 'global_step' + ] for n in name): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if '_slot_' in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] in ['kernel', 'scale', 'embedding']: + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'self_attention': + pointer = getattr(pointer, 'layer') + pointer = pointer[0] + elif scope_names[0] == 'enc_dec_attention': + pointer = getattr(pointer, 'layer') + pointer = pointer[1] + elif scope_names[0] == 'dense_relu_dense': + pointer = getattr(pointer, 'layer') + pointer = pointer[2] + elif scope_names[0] == 'rms_norm': + if hasattr(pointer, 'layer_norm'): + pointer = getattr(pointer, 'layer_norm') + elif hasattr(pointer, 'final_layer_norm'): + pointer = getattr(pointer, 'final_layer_norm') + elif scope_names[0] == 'scale': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + elif scope_names[0] == 'decoder' and name[1] == 'logits': + continue + elif scope_names[0] == 'logits': + pointer = getattr(pointer, 'lm_head') + elif scope_names[0] == 'wi' and len( + scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f'wi_{scope_names[1]}') + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ['kernel', 'scale', 'embedding']: + pointer = getattr(pointer, 'weight') + if scope_names[0] != 'embedding': + logger.info( + f'Transposing numpy weight of shape {array.shape} for {name}') + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f'Initialize PyTorch weight {name}') + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info( + f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}." + ) + return model + + +class T5LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean( + -1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN['gelu_new'] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + if config.feed_forward_proj == 'relu': + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == 'gated-gelu': + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f'{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`' + ) + + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relateive_pos_log = torch.log(relative_position.float() / max_exact) + max_dis_log = math.log(max_distance / max_exact) + origin_relative_position = relateive_pos_log / max_dis_log * ( + num_buckets - max_exact) + relative_postion_if_large = max_exact + origin_relative_position.to( + torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_large, num_buckets - 1)) + + relative_buckets += torch.where(is_small, relative_position, + relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange( + query_length, + dtype=torch.long, + device=self.relative_attention_bias.weight.device)[:, None] + memory_position = torch.arange( + key_length, + dtype=torch.long, + device=self.relative_attention_bias.weight.device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f'past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states' + real_seq_length += past_key_value[0].shape[ + 2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[ + 1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, + self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view( + batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, + past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], + dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q( + hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, + past_key_value[0] if past_key_value is not None else None) + value_states = project( + hidden_states, self.v, key_value_states, + past_key_value[1] if past_key_value is not None else None) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1):, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax( + scores.float(), dim=-1).type_as( + scores) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul( + attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, + value_states) if (self.is_decoder + and use_cache) else None + outputs = (attn_output, ) + (present_key_value_state, ) + ( + position_bias, ) + + if output_attentions: + outputs = outputs + (attn_weights, ) + return outputs + + +class T5LayerSelfAttention(nn.Module): + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states, + ) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention( + config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output, + ) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + '`past_key_values` is passed to the encoder. Please make sure this is intended.' + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f'There should be {expected_num_past_key_values} past states. ' + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f'Got {len(past_key_value)} past key / value states') + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf( + hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf( + hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[ + 1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf( + hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + if use_cache: + outputs = outputs + (present_key_value_state, ) + attention_outputs + else: + outputs = outputs + attention_outputs + + # hidden-states, present_key_value_states, (self-attention position + # bias), (self-attention weights), (cross-attention position bias), + # (cross-attention weights) + return outputs + + +class T5PreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface + for downloading and loading pretrained models. + """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = 'transformer' + is_parallelizable = True + supports_gradient_checkpointing = True + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + 'decoder_input_ids': input_ids, + 'input_ids': input_ids, + 'decoder_attention_mask': input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, T5Model): + # Mesh TensorFlow embeddings initialization See + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization See + # https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi, 'bias') and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff)**-0.5)) + if hasattr(module.wo, 'bias') and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi_0, 'bias') and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi_1, 'bias') and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff)**-0.5)) + if hasattr(module.wo, 'bias') and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before + # softmax See + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5)) + module.k.weight.data.normal_( + mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_( + mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model)**-0.5)) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), 'self.model.config.decoder_start_token_id has to be defined.' + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1, ), + decoder_start_token_id) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, 'self.model.config.pad_token_id has to be defined.' + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item( + ), 'Verify that `shifted_input_ids` has only positive values' + + return shifted_input_ids + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the + label information. num_labels: An optional arg to tell the + model how many classes to initialize. + Method will call utils.parse_label_mapping + if num_labels not supplied. If num_labels is + not found, the model will use the default + setting (2 classes). + + Returns: + The loaded model, which is initialized by + transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.get('model_dir', None) + if model_dir is None: + config = T5Config(**kwargs) + model = cls(config) + else: + model_kwargs = {} + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + model.model_dir = model_dir + return model + + +class T5Stack(T5PreTrainedModel): + + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList([ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ]) + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + def parallelize(self, device_map=None): + r""" + This is an experimental feature and is a subject to change at a + moment's notice. + + Uses a device map to distribute attention modules of the model + across several devices. If no device map is given, it will evenly + distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note + that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric + reasons). That means that the first device should have fewer + attention modules mapped to it than other devices. For + reference, the t5 models have the following number of + attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python # Here is an example of a device map on a machine with 4 + GPUs # using t5-3b, which has a total of 24 attention modules: model + = T5ForConditionalGeneration.from_pretrained("t5-3b") device_map = { + 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, + 15, 16], 3: [17, 18, 19, 20, 21, 22, 23], + } model.parallelize(device_map) ``` all of the parallelize methods + in this file are the same + + """ + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) + if device_map is None else device_map) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = 'cpu' if 'cpu' in self.device_map.keys( + ) else 'cuda:' + str(min(self.device_map.keys())) + self.last_device = 'cuda:' + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = 'cuda:' + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + def deparallelize(self): + r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python # On a 4 GPU machine with t5-3b: model = + T5ForConditionalGeneration.from_pretrained("t5-3b") device_map = { + 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14, + 15, 16], 3: [17, 18, 19, 20, 21, 22, 23], + } model.parallelize(device_map) # Splits the model across several + devices model.deparallelize() # Put the model back on cpu and + cleans memory by calling torch.cuda.empty_cache() ``` + + all of the deparallelize methods in this file are the same + """ + self.model_parallel = False + self.device_map = None + self.first_device = 'cpu' + self.last_device = 'cpu' + for i in range(len(self.block)): + self.block[i] = self.block[i].to('cpu') + self.embed_tokens = self.embed_tokens.to('cpu') + self.final_layer_norm = self.final_layer_norm.to('cpu') + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = 'decoder_' if self.is_decoder else '' + raise ValueError( + f'You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = 'decoder_' if self.is_decoder else '' + raise ValueError( + f'You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds' + ) + + if inputs_embeds is None: + assert self.embed_tokens is not None, 'You have to initialize the model with valid token embeddings' + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[ + 2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f'`use_cache` can only be set to `True` if {self} is used as a decoder' + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to( + inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, + self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions + and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, + past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to( + hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple( + module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: hidden-states, key-value-states, + # (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + ( + None, ) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer + # store them layer_outputs = hidden-states, key-value-states + # (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3], ) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[5], ) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and 'cuda:' + str(k) != self.last_device: + hidden_states = hidden_states.to('cuda:' + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and +`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, +but this feature is deprecated and will be removed in future versions. If you do +not want to use any `decoder_head_mask` now, please set `decoder_head_mask = +torch.ones(num_layers, num_heads)`. +""" + + +@MODELS.register_module(group_key=Tasks.backbone, module_name=Models.T5) +class T5Model(T5PreTrainedModel): + """The bare T5 Model transformer outputting raw hidden-states without any + specific head on top. + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with + a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by + Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder + transformer pre-trained in a text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass + documentation for the generic methods the library implements for all its + model (such as downloading or saving, resizing the input embeddings, pruning + heads etc.) + + This model is also a PyTorch + [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch + documentation for all matter related to general usage and behavior. + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters + of the model. + Initializing with a config file does not load the weights associated + with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + """ + _keys_to_ignore_on_load_missing = [ + r'encoder\.embed_tokens\.weight', + r'decoder\.embed_tokens\.weight', + ] + _keys_to_ignore_on_load_unexpected = [ + r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map( + len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None else device_map) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to('cpu') + self.decoder = self.decoder.to('cpu') + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of + heads to prune in this layer} See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model + with relative position embeddings so you should be able to pad the + inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a + look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for + `decoder_input_ids` generation. If `past_key_values` is used, + optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining + take a look at [T5 Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in + `decoder_input_ids`. Causal mask will also be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, + num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in + the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, + `optional`: *attentions*) `last_hidden_state` of shape `(batch_size, + sequence_length, hidden_size)` is a sequence of hidden states at the + output of the last layer of the encoder. Used in the cross-attention + of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length + `config.n_layers` with each tuple having 4 tensors of shape + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only + the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead + of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to + directly pass an embedded representation. If `past_key_values` is + used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more + control over how to convert `decoder_input_ids` indices into + associated vectors than the model's internal embedding lookup + matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, + `decoder_inputs_embeds` takes the value of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned + and can be used to speed up decoding (see `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain + tuple. + Returns: + + Example: + + ```python >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + >>> ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to( + self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/modelscope/models/nlp/T5/configuration.py b/modelscope/models/nlp/T5/configuration.py new file mode 100644 index 00000000..1f9a965e --- /dev/null +++ b/modelscope/models/nlp/T5/configuration.py @@ -0,0 +1,175 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" T5 model configuration""" +from typing import Mapping + +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxSeq2SeqConfigWithPast + +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + + +class T5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to + instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the T5 + [t5-small](https://huggingface.co/t5-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model // + num_heads`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `T5Block`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the + `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + model_type = 't5' + keys_to_ignore_at_inference = ['past_key_values'] + attribute_map = { + 'hidden_size': 'd_model', + 'num_attention_heads': 'num_heads', + 'num_hidden_layers': 'num_layers' + } + + def __init__(self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj='relu', + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + **kwargs): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = (num_decoder_layers if num_decoder_layers + is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split('-') + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == 'gated' + + if len(act_info) > 1 and act_info[0] != 'gated' or len(act_info) > 2: + raise ValueError( + f'`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer.' + 'Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. ' + "'gated-gelu' or 'relu'") + + # for backwards compatibility + if feed_forward_proj == 'gated-gelu': + self.dense_act_fn = 'gelu_new' + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + 'input_ids': { + 0: 'batch', + 1: 'encoder_sequence' + }, + 'attention_mask': { + 0: 'batch', + 1: 'encoder_sequence' + }, + } + if self.use_past: + common_inputs['attention_mask'][ + 1] = 'past_encoder_sequence + sequence' + common_inputs['decoder_input_ids'] = {0: 'batch'} + common_inputs['decoder_attention_mask'] = { + 0: 'batch', + 1: 'past_decoder_sequence + sequence' + } + else: + common_inputs['decoder_input_ids'] = { + 0: 'batch', + 1: 'decoder_sequence' + } + common_inputs['decoder_attention_mask'] = { + 0: 'batch', + 1: 'decoder_sequence' + } + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction='inputs') + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/modelscope/models/nlp/T5/text2text_generation.py b/modelscope/models/nlp/T5/text2text_generation.py new file mode 100644 index 00000000..c4dcdfdb --- /dev/null +++ b/modelscope/models/nlp/T5/text2text_generation.py @@ -0,0 +1,455 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import BaseModelOutput, Seq2SeqLMOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .backbone import T5PreTrainedModel, T5Stack +from .configuration import T5Config + +logger = get_logger(__name__) + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and +`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, +but this feature is deprecated and will be removed in future versions. If you do +not want to use any `decoder_head_mask` now, please set `decoder_head_mask = +torch.ones(num_layers, num_heads)`. +""" + + +@MODELS.register_module( + group_key=Tasks.text2text_generation, + module_name=Models.T5, +) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r'encoder\.embed_tokens\.weight', + r'decoder\.embed_tokens\.weight', + r'lm_head\.weight', + ] + _keys_to_ignore_on_load_unexpected = [ + r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map( + len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None else device_map) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to('cpu') + self.decoder = self.decoder.to('cpu') + self.lm_head = self.lm_head.to('cpu') + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward(self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model + with relative position embeddings so you should be able to pad the + inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a + look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for + `decoder_input_ids` generation. If `past_key_values` is used, + optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining + take a look at [T5 Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, + target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in + `decoder_input_ids`. Causal mask will also be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, + num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the + decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or + `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in + the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, + `optional`: *attentions*) `last_hidden_state` of shape `(batch_size, + sequence_length, hidden_size)` is a sequence of hidden states at the + output of the last layer of the encoder. Used in the cross-attention + of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length + `config.n_layers` with each tuple having 4 tensors of shape + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only + the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead + of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, + target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to + directly pass an embedded representation. If `past_key_values` is + used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more + control over how to convert `decoder_input_ids` indices into + associated vectors than the model's internal embedding lookup + matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, + `decoder_inputs_embeds` takes the value of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned + and can be used to speed up decoding (see `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain + tuple. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. + Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All + labels set to `-100` are ignored (masked), the loss is only computed + for labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python >>> from transformers import T5Tokenizer, + T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + >>> ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] + if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to( + self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab See + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct( + lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss + # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs + return ((loss, ) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'decoder_input_ids': input_ids, + 'past_key_values': past, + 'encoder_outputs': encoder_outputs, + 'attention_mask': attention_mask, + 'head_mask': head_mask, + 'decoder_head_mask': decoder_head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'use_cache': use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning( + 'You might want to consider setting `use_cache=True` to speed up decoding' + ) + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device)), ) + + assert reordered_layer_past_states[0].shape == layer_past_states[ + 0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, ) + return reordered_decoder_past diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py new file mode 100644 index 00000000..1d71469a --- /dev/null +++ b/modelscope/models/nlp/__init__.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .bart import BartForTextErrorCorrection + from .bert import ( + BertForMaskedLM, + BertForTextRanking, + BertForSentenceEmbedding, + BertForSequenceClassification, + BertForTokenClassification, + BertForDocumentSegmentation, + BertModel, + BertConfig, + ) + from .csanmt import CsanmtForTranslation + from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .gpt_neo import GPTNeoModel + from .gpt3 import GPT3ForTextGeneration + from .heads import SequenceClassificationHead + from .palm_v2 import PalmForTextGeneration + from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig + from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST + from .space_T_cn import TableQuestionAnswering + from .space_T_en import StarForTextToSql + from .structbert import ( + SbertForFaqQuestionAnswering, + SbertForMaskedLM, + SbertForSequenceClassification, + SbertForTokenClassification, + SbertTokenizer, + SbertModel, + SbertTokenizerFast, + ) + from .T5 import T5ForConditionalGeneration + from .mglm import MGLMForTextSummarization + from .task_models import ( + FeatureExtractionModel, + InformationExtractionModel, + LSTMCRFForNamedEntityRecognition, + SequenceClassificationModel, + SingleBackboneTaskModelBase, + TaskModelForTextGeneration, + TokenClassificationModel, + TransformerCRFForNamedEntityRecognition, + ) + from .veco import (VecoConfig, VecoForMaskedLM, + VecoForSequenceClassification, + VecoForTokenClassification, VecoModel, VecoTokenizer, + VecoTokenizerFast) + from .bloom import BloomModel +else: + _import_structure = { + 'backbones': ['SbertModel'], + 'bart': ['BartForTextErrorCorrection'], + 'csanmt': ['CsanmtForTranslation'], + 'heads': ['SequenceClassificationHead'], + 'gpt3': ['GPT3ForTextGeneration'], + 'structbert': [ + 'SbertForFaqQuestionAnswering', + 'SbertForMaskedLM', + 'SbertForSequenceClassification', + 'SbertForTokenClassification', + 'SbertTokenizer', + 'SbertTokenizerFast', + 'SbertModel', + ], + 'veco': [ + 'VecoConfig', + 'VecoForMaskedLM', + 'VecoForSequenceClassification', + 'VecoForTokenClassification', + 'VecoModel', + 'VecoTokenizer', + 'VecoTokenizerFast', + ], + 'bert': [ + 'BertForMaskedLM', + 'BertForTextRanking', + 'BertForSentenceEmbedding', + 'BertForSequenceClassification', + 'BertForTokenClassification', + 'BertForDocumentSegmentation', + 'BertModel', + 'BertConfig', + ], + 'ponet': ['PoNetForMaskedLM', 'PoNetModel', 'PoNetConfig'], + 'palm_v2': ['PalmForTextGeneration'], + 'deberta_v2': ['DebertaV2ForMaskedLM', 'DebertaV2Model'], + 'space_T_en': ['StarForTextToSql'], + 'space_T_cn': ['TableQuestionAnswering'], + 'space': + ['SpaceForDialogIntent', 'SpaceForDialogModeling', 'SpaceForDST'], + 'task_models': [ + 'FeatureExtractionModel', + 'InformationExtractionModel', + 'LSTMCRFForNamedEntityRecognition', + 'LSTMCRFForWordSegmentation', + 'SequenceClassificationModel', + 'SingleBackboneTaskModelBase', + 'TaskModelForTextGeneration', + 'TokenClassificationModel', + 'TransformerCRFForNamedEntityRecognition', + 'TransformerCRFForWordSegmentation', + ], + 'sentence_embedding': ['SentenceEmbedding'], + 'T5': ['T5ForConditionalGeneration'], + 'mglm': ['MGLMForTextSummarization'], + 'gpt_neo': ['GPTNeoModel'], + 'bloom': ['BloomModel'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/bart/__init__.py b/modelscope/models/nlp/bart/__init__.py new file mode 100644 index 00000000..31912efc --- /dev/null +++ b/modelscope/models/nlp/bart/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .text_error_correction import BartForTextErrorCorrection diff --git a/modelscope/models/nlp/bart/text_error_correction.py b/modelscope/models/nlp/bart/text_error_correction.py new file mode 100644 index 00000000..27abedb5 --- /dev/null +++ b/modelscope/models/nlp/bart/text_error_correction.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import torch.cuda + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['BartForTextErrorCorrection'] + + +@MODELS.register_module(Tasks.text_error_correction, module_name=Models.bart) +class BartForTextErrorCorrection(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir=model_dir, *args, **kwargs) + """initialize the text error correction model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + ckpt_name = ModelFile.TORCH_MODEL_FILE + local_model = osp.join(model_dir, ckpt_name) + + bart_vocab_dir = model_dir + # turn on cuda if GPU is available + from fairseq import checkpoint_utils, utils + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.use_fp16 = kwargs[ + 'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\ + else False + + overrides = { + 'data': bart_vocab_dir, + 'beam': 2, + } + models, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + utils.split_paths(local_model), arg_overrides=overrides) + # Move models to GPU + for model in models: + model.eval() + model.to(self._device) + if self.use_fp16: + model.half() + model.prepare_for_inference_(cfg) + self.models = models + # Initialize generator + self.generator = task.build_generator(models, 'translation') + + self.task = task + + def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + Example: + 1 sent: + {'net_input': + {'src_tokens':tensor([2478,242,24,4]), + 'src_lengths': tensor([4])} + } + + + Returns: + Dict[str, Tensor]: results + Example: + { + 'predictions': Tensor([1377, 4959, 2785, 6392...]), # tokens need to be decode by tokenizer + } + """ + import fairseq.utils + + if len(input['net_input']['src_tokens'].size()) == 1: + input['net_input']['src_tokens'] = input['net_input'][ + 'src_tokens'].view(1, -1) + + if torch.cuda.is_available(): + input = fairseq.utils.move_to_cuda(input, device=self._device) + + sample = input + + translations = self.task.inference_step(self.generator, self.models, + sample) + + # get 1-best List[Tensor] + preds = translations[0][0]['tokens'] + return {'predictions': preds} diff --git a/modelscope/models/nlp/bert/__init__.py b/modelscope/models/nlp/bert/__init__.py new file mode 100644 index 00000000..28a10f57 --- /dev/null +++ b/modelscope/models/nlp/bert/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import ( + BertLayer, + BertModel, + BertPreTrainedModel, + ) + from .configuration import BertConfig + from .fill_mask import BertForMaskedLM + from .text_ranking import BertForTextRanking + from .sentence_embedding import BertForSentenceEmbedding + from .text_classification import BertForSequenceClassification + from .token_classification import BertForTokenClassification + from .document_segmentation import BertForDocumentSegmentation +else: + _import_structure = { + 'backbone': [ + 'BertModel', + 'BertPreTrainedModel', + ], + 'configuration': ['BertConfig'], + 'fill_mask': ['BertForMaskedLM'], + 'text_ranking': ['BertForTextRanking'], + 'sentence_embedding': ['BertForSentenceEmbedding'], + 'text_classification': ['BertForSequenceClassification'], + 'token_classification': ['BertForTokenClassification'], + 'document_segmentation': ['BertForDocumentSegmentation'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/bert/backbone.py b/modelscope/models/nlp/bert/backbone.py new file mode 100755 index 00000000..df0aebd2 --- /dev/null +++ b/modelscope/models/nlp/bert/backbone.py @@ -0,0 +1,952 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import (BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions) +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.logger import get_logger +from .configuration import BertConfig + +logger = get_logger(__name__) + +_CONFIG_FOR_DOC = 'BertConfig' + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and + # exported when serialized + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse('1.6.0'): + self.register_buffer( + 'token_type_ids', + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor + # where it is all zeros, which usually occurs when its auto-generated, + # registered buffer helps users when tracing the model without passing + # token_type_ids, solves issue #5664 + if token_type_ids is None: + if hasattr(self, 'token_type_ids'): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, 'position_embedding_type', 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all + # cross attention key/value_states. Further calls to cross_attention + # layer can then reuse all cross-attention key/value_states (first + # "if" case) if uni-directional self-attention (decoder) save + # Tuple(torch.Tensor, torch.Tensor) of all previous decoder + # key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected + # key/value_states (third "elif" case) if encoder bi-directional + # self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = BertSelfAttention( + config, position_embedding_type=position_embedding_type) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError( + f'{self} should be used as a decoder model if cross attention is added' + ) + self.crossattention = BertAttention( + config, position_embedding_type='absolute') + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, 'crossattention'): + raise ValueError( + f'If `encoder_hidden_states` are passed, {self} has to be instantiated ' + f'with cross-attention layers by setting `config.add_cross_attention=True`' + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface + for downloading and loading pretrained models. + """ + + config_class = BertConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels not supplied. + If num_labels is not found, the model will use the default setting (2 classes). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.get('model_dir', None) + if model_dir is None: + config = BertConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + model.model_dir = model_dir + return model + + +@MODELS.register_module(group_key=Tasks.backbone, module_name=Models.bert) +class BertModel(BertPreTrainedModel): + """The Bert Model transformer outputting raw hidden-states without any + specific head on top. + + This model inherits from [`PreTrainedModel`]. Check the superclass + documentation for the generic methods the library implements for all its + model (such as downloading or saving, resizing the input embeddings, pruning + heads etc.) + + This model is also a PyTorch + [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch + documentation for all matter related to general usage and behavior. + + Parameters: + config ([`BertConfig`]): Model configuration class with all the + parameters of the model. + Initializing with a config file does not load the weights associated + with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + The model can behave as an encoder (with only self-attention) as well as a + decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam + Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. To be used in a + Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` + is then expected as an input to the forward pass. + + + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def _instantiate(cls, model_dir=None, add_pooling_layer=True, **config): + config = BertConfig(**config) + model = cls(config, add_pooling_layer) + return model + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs): + r""" + Args: + input_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`BertTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `((batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask + values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the + inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `((batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, + num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask + values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `((batch_size, sequence_length, hidden_size)`, + *optional*): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a + plain tuple. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the + encoder. Used in the cross-attention if the model is configured as a + decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, + sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of + the encoder input. This mask is used in the cross-attention if the + model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length + `config.n_layers` with each tuple having 4 tensors of shape + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention + blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only + the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead + of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned + and can be used to speed up decoding (see `past_key_values`). + Others (**kwargs) + some additional parameters might passed in from upstream pipeline, + which not influence the results. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, 'token_type_ids'): + buffered_token_type_ids = self.embeddings.token_type_ids[:, : + seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def extract_sequence_outputs(self, outputs): + return outputs['last_hidden_state'] + + def extract_pooled_outputs(self, outputs): + return outputs['pooler_output'] diff --git a/modelscope/models/nlp/bert/configuration.py b/modelscope/models/nlp/bert/configuration.py new file mode 100644 index 00000000..1e2cef95 --- /dev/null +++ b/modelscope/models/nlp/bert/configuration.py @@ -0,0 +1,163 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BERT model configuration """ +from collections import OrderedDict +from typing import Mapping + +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfig + +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`BertModel`] or a [`TFBertModel`]. It is used to instantiate a BERT model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the BERT + [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different + tokens that can be represented by the `inputs_ids` passed when + calling [`BertModel`] or [`TFBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) + layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, `"gelu"`, `"relu"`, `"silu"` and + `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the + embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or + 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling + [`BertModel`] or [`TFBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, + `"relative_key"`, `"relative_key_query"`. For positional embeddings + use `"absolute"`. For more information on `"relative_key"`, please + refer to [Self-Attention with Relative Position Representations + (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more + information on `"relative_key_query"`, please refer to *Method 4* in + [Improve Transformer Models with Better Relative Position Embeddings + (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python >>> from transformers import BertModel, BertConfig + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = 'bert' + + def __init__(self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type='absolute', + use_cache=True, + classifier_dropout=None, + **kwargs): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class BertOnnxConfig(OnnxConfig): + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([ + ('input_ids', { + 0: 'batch', + 1: 'sequence' + }), + ('attention_mask', { + 0: 'batch', + 1: 'sequence' + }), + ('token_type_ids', { + 0: 'batch', + 1: 'sequence' + }), + ]) diff --git a/modelscope/models/nlp/bert/document_segmentation.py b/modelscope/models/nlp/bert/document_segmentation.py new file mode 100644 index 00000000..b46c77e4 --- /dev/null +++ b/modelscope/models/nlp/bert/document_segmentation.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import TokenClassifierOutput +from transformers.models.bert.modeling_bert import (BertModel, + BertPreTrainedModel) + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['BertForDocumentSegmentation'] + + +@MODELS.register_module( + Tasks.document_segmentation, module_name=Models.bert_for_ds) +class BertForDocumentSegmentation(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + + def build_with_config(self, config): + self.bert_model = BertForDocumentSegmentationBase.from_pretrained( + self.model_dir, from_tf=False, config=config) + return self.bert_model + + def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]: + pass + + +class BertForDocumentSegmentationBase(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.sentence_pooler_type = None + self.bert = BertModel(config, add_pooling_layer=False) + + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.class_weights = None + self.init_weights() + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + sentence_attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + if self.sentence_pooler_type is not None: + raise NotImplementedError + else: + sequence_output = self.dropout(sequence_output) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(weight=self.class_weights) + if sentence_attention_mask is not None: + active_loss = sentence_attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels)) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/bert/fill_mask.py b/modelscope/models/nlp/bert/fill_mask.py new file mode 100644 index 00000000..4f81f62d --- /dev/null +++ b/modelscope/models/nlp/bert/fill_mask.py @@ -0,0 +1,299 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel +from .configuration import BertConfig + +logger = logging.get_logger(__name__) + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.bert) +class BertForMaskedLM(BertPreTrainedModel): + r"""Bert Model with a `language modeling` head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of Structbert, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config: BertConfig, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, + *optional*): + Labels for computing the masked language modeling loss. Indices + should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` + docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., + config.vocab_size]` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_bert_backbone_base_std') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_backbone_base_std') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) + + def prepare_inputs_for_generation(self, + input_ids, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError('The PAD token should be defined for generation') + + padding_mask = attention_mask.new_zeros((attention_mask.shape[0], 1)) + attention_mask = torch.cat([attention_mask, padding_mask], dim=-1) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/modelscope/models/nlp/bert/sentence_embedding.py b/modelscope/models/nlp/bert/sentence_embedding.py new file mode 100644 index 00000000..f4c2620e --- /dev/null +++ b/modelscope/models/nlp/bert/sentence_embedding.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.builder import MODELS +from modelscope.outputs import BackboneModelOutput +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel + + +@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) +class BertForSentenceEmbedding(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.config = config + setattr(self, self.base_model_prefix, + BertModel(config, add_pooling_layer=False)) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ) -> BackboneModelOutput: + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base') + >>> print(model(**preprocessor('This is a test'))) + """ + return self.base_model.forward( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + model_dir = kwargs.get('model_dir') + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + model.model_dir = model_dir + return model diff --git a/modelscope/models/nlp/bert/text_classification.py b/modelscope/models/nlp/bert/text_classification.py new file mode 100644 index 00000000..b1d18d0f --- /dev/null +++ b/modelscope/models/nlp/bert/text_classification.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel + +logger = logging.get_logger(__name__) + + +@MODELS.register_module(Tasks.text_classification, module_name=Models.bert) +@MODELS.register_module(Tasks.nli, module_name=Models.bert) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.bert) +@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert) +@MODELS.register_module( + Tasks.zero_shot_classification, module_name=Models.bert) +class BertForSequenceClassification(BertPreTrainedModel): + r"""Bert Model transformer with a sequence classification/regression head on top + (a linear layer on top of the pooled output) e.g. for GLUE tasks. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of Bert, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in ModelScope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + setattr(self, self.base_model_prefix, BertModel(config)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model.forward( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/bert/text_ranking.py b/modelscope/models/nlp/bert/text_ranking.py new file mode 100644 index 00000000..b5ac8d7e --- /dev/null +++ b/modelscope/models/nlp/bert/text_ranking.py @@ -0,0 +1,93 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.utils.checkpoint + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel +from .text_classification import BertForSequenceClassification + +logger = logging.get_logger(__name__) + + +@MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) +class BertForTextRanking(BertForSequenceClassification): + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + neg_sample = kwargs.get('neg_sample', 8) + self.neg_sample = neg_sample + setattr(self, self.base_model_prefix, + BertModel(self.config, add_pooling_layer=True)) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + *args, + **kwargs) -> AttentionTextClassificationModelOutput: + outputs = self.base_model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + # backbone model should return pooled_output as its second output + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + if self.base_model.training: + scores = logits.view(-1, self.neg_sample + 1) + batch_size = scores.size(0) + loss_fct = torch.nn.CrossEntropyLoss() + target_label = torch.zeros( + batch_size, dtype=torch.long, device=scores.device) + loss = loss_fct(scores, target_label) + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + ) + return AttentionTextClassificationModelOutput(logits=logits, ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels not supplied. + If num_labels is not found, the model will use the default setting (1 classes). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + num_labels = kwargs.get('num_labels', 1) + neg_sample = kwargs.get('neg_sample', 4) + model_args = {} if num_labels is None else {'num_labels': num_labels} + if neg_sample is not None: + model_args['neg_sample'] = neg_sample + + model_dir = kwargs.get('model_dir') + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_args) + model.model_dir = model_dir + return model diff --git a/modelscope/models/nlp/bert/token_classification.py b/modelscope/models/nlp/bert/token_classification.py new file mode 100644 index 00000000..5dc6b0ce --- /dev/null +++ b/modelscope/models/nlp/bert/token_classification.py @@ -0,0 +1,225 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import BertModel, BertPreTrainedModel + +logger = logging.get_logger(__name__) + + +@MODELS.register_module(Tasks.token_classification, module_name=Models.bert) +@MODELS.register_module(Tasks.part_of_speech, module_name=Models.bert) +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert) +class BertForTokenClassification(BertPreTrainedModel): + r"""Bert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks, word-segmentation. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of Bert, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in ModelScope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + + setattr(self, self.base_model_prefix, + BertModel(config, add_pooling_layer=False)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ): + r""" + Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using + :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and + :meth:`transformers.PreTrainedTokenizer.__call__` for details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the + inputs. Indices are selected in ``[0, 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position + embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or + :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask + values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to + directly pass an embedded representation. This is useful if you want + more control over how to convert :obj:`input_ids` indices into + associated vectors than the model's internal embedding lookup + matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention + layers. See ``attentions`` under returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See + ``hidden_states`` under returned tensors for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` + instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, + `optional`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If + :obj:`config.num_labels == 1` a regression loss is computed + (Mean-Square loss), If :obj:`config.num_labels > 1` a classification + loss is computed (Cross-Entropy). + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_bert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels)) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[2:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + offset_mapping=offset_mapping, + ) diff --git a/modelscope/models/nlp/bloom/__init__.py b/modelscope/models/nlp/bloom/__init__.py new file mode 100644 index 00000000..ad93252f --- /dev/null +++ b/modelscope/models/nlp/bloom/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import BloomModel +else: + _import_structure = { + 'backbone': ['BloomModel'], + } + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/bloom/backbone.py b/modelscope/models/nlp/bloom/backbone.py new file mode 100644 index 00000000..f8ea7b2f --- /dev/null +++ b/modelscope/models/nlp/bloom/backbone.py @@ -0,0 +1,15 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import BloomConfig +from transformers import BloomModel as BloomModelTransform + +from modelscope.metainfo import Models +from modelscope.models.builder import BACKBONES +from modelscope.utils.constant import Tasks + + +@BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom) +class BloomModel(BloomModelTransform): + + def __init__(self, **kwargs): + config = BloomConfig(**kwargs) + super().__init__(config) diff --git a/modelscope/models/nlp/csanmt/__init__.py b/modelscope/models/nlp/csanmt/__init__.py new file mode 100644 index 00000000..85531617 --- /dev/null +++ b/modelscope/models/nlp/csanmt/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .translation import CsanmtForTranslation diff --git a/modelscope/models/nlp/csanmt/translation.py b/modelscope/models/nlp/csanmt/translation.py new file mode 100644 index 00000000..4bac8e6d --- /dev/null +++ b/modelscope/models/nlp/csanmt/translation.py @@ -0,0 +1,1495 @@ +# Part of the implementation is borrowed and modified from THUMT, +# publicly available at https://github.com/THUNLP-MT/THUMT +# Copyright 2017-2022 The Alibaba MT Team Authors. All rights reserved. +import math +from collections import namedtuple +from typing import Dict + +import tensorflow as tf + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['CsanmtForTranslation'] + + +@MODELS.register_module(Tasks.translation, module_name=Models.translation) +class CsanmtForTranslation(Model): + + def __init__(self, model_dir, *args, **kwargs): + """ + Args: + params (dict): the model configuration. + """ + super().__init__(model_dir, *args, **kwargs) + self.params = kwargs + + def __call__(self, + input: Dict[str, Tensor], + label: Dict[str, Tensor] = None) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input: the preprocessed data + + Returns: + output_seqs: output sequence of target ids + """ + if label is None: + with tf.compat.v1.variable_scope('NmtModel'): + output_seqs, output_scores = self.beam_search( + input, self.params) + return { + 'output_seqs': output_seqs, + 'output_scores': output_scores, + } + else: + train_op, loss = self.transformer_model_train_fn(input, label) + return { + 'train_op': train_op, + 'loss': loss, + } + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Run the forward pass for a model. + + Args: + input (Dict[str, Tensor]): the dict of the model inputs for the forward method + + Returns: + Dict[str, Tensor]: output from the model forward pass + """ + ... + + def encoding_graph(self, features, params): + src_vocab_size = params['src_vocab_size'] + hidden_size = params['hidden_size'] + + initializer = tf.compat.v1.random_normal_initializer( + 0.0, hidden_size**-0.5, dtype=tf.float32) + + if params['shared_source_target_embedding']: + with tf.compat.v1.variable_scope( + 'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE): + src_embedding = tf.compat.v1.get_variable( + 'Weights', [src_vocab_size, hidden_size], + initializer=initializer) + else: + with tf.compat.v1.variable_scope('Source_Embedding'): + src_embedding = tf.compat.v1.get_variable( + 'Weights', [src_vocab_size, hidden_size], + initializer=initializer) + src_bias = tf.compat.v1.get_variable('encoder_input_bias', + [hidden_size]) + + eos_padding = tf.zeros([tf.shape(input=features)[0], 1], tf.int64) + src_seq = tf.concat([features, eos_padding], 1) + src_mask = tf.cast(tf.not_equal(src_seq, 0), dtype=tf.float32) + shift_src_mask = src_mask[:, :-1] + shift_src_mask = tf.pad( + tensor=shift_src_mask, + paddings=[[0, 0], [1, 0]], + constant_values=1) + + encoder_input = tf.gather(src_embedding, tf.cast(src_seq, tf.int32)) + encoder_input = encoder_input * (hidden_size**0.5) + if params['position_info_type'] == 'absolute': + encoder_input = add_timing_signal(encoder_input) + encoder_input = tf.multiply(encoder_input, + tf.expand_dims(shift_src_mask, 2)) + + encoder_input = tf.nn.bias_add(encoder_input, src_bias) + encoder_self_attention_bias = attention_bias(shift_src_mask, 'masking') + + if params['residual_dropout'] > 0.0: + encoder_input = tf.nn.dropout( + encoder_input, rate=params['residual_dropout']) + + # encode + encoder_output = transformer_encoder(encoder_input, + encoder_self_attention_bias, + shift_src_mask, params) + return encoder_output, encoder_self_attention_bias + + def semantic_encoding_graph(self, features, params, name=None): + hidden_size = params['hidden_size'] + initializer = tf.compat.v1.random_normal_initializer( + 0.0, hidden_size**-0.5, dtype=tf.float32) + scope = None + if params['shared_source_target_embedding']: + vocab_size = params['src_vocab_size'] + scope = 'Shared_Semantic_Embedding' + elif name == 'source': + vocab_size = params['src_vocab_size'] + scope = 'Source_Semantic_Embedding' + elif name == 'target': + vocab_size = params['trg_vocab_size'] + scope = 'Target_Semantic_Embedding' + else: + raise ValueError('error: no right name specified.') + + with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE): + embedding_mat = tf.compat.v1.get_variable( + 'Weights', [vocab_size, hidden_size], initializer=initializer) + + eos_padding = tf.zeros([tf.shape(input=features)[0], 1], tf.int64) + input_seq = tf.concat([features, eos_padding], 1) + input_mask = tf.cast(tf.not_equal(input_seq, 0), dtype=tf.float32) + shift_input_mask = input_mask[:, :-1] + shift_input_mask = tf.pad( + tensor=shift_input_mask, + paddings=[[0, 0], [1, 0]], + constant_values=1) + + encoder_input = tf.gather(embedding_mat, tf.cast(input_seq, tf.int32)) + encoder_input = encoder_input * (hidden_size**0.5) + encoder_input = tf.multiply(encoder_input, + tf.expand_dims(shift_input_mask, 2)) + + encoder_self_attention_bias = attention_bias(shift_input_mask, + 'masking') + + if params['residual_dropout'] > 0.0: + encoder_input = tf.nn.dropout( + encoder_input, rate=params['residual_dropout']) + + # encode + encoder_output = transformer_semantic_encoder( + encoder_input, encoder_self_attention_bias, shift_input_mask, + params) + return encoder_output + + def build_contrastive_training_graph(self, features, labels, params): + # representations + source_name = 'source' + target_name = 'target' + if params['shared_source_target_embedding']: + source_name = None + target_name = None + feature_output = self.semantic_encoding_graph( + features, params, name=source_name) + label_output = self.semantic_encoding_graph( + labels, params, name=target_name) + + return feature_output, label_output + + def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12): + K = params['num_of_samples'] + eta = params['eta'] + assert K % 2 == 0 + + def get_samples(x_vector, y_vector): + bias_vector = y_vector - x_vector + w_r = tf.math.divide( + tf.abs(bias_vector) - tf.reduce_min( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + + epsilon, + tf.reduce_max( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + - tf.reduce_min( + input_tensor=tf.abs(bias_vector), axis=2, keepdims=True) + + 2 * epsilon) + + R = [] + for i in range(K // 2): + omega = eta * tf.random.normal(tf.shape(input=bias_vector), 0.0, w_r) + \ + (1.0 - eta) * tf.random.normal(tf.shape(input=bias_vector), 0.0, 1.0) + sample = x_vector + omega * bias_vector + R.append(sample) + return R + + ALL_SAMPLES = [] + ALL_SAMPLES = get_samples(x_embedding, y_embedding) + ALL_SAMPLES.extend(get_samples(y_embedding, x_embedding)) + + assert len(ALL_SAMPLES) == K + + return tf.concat(ALL_SAMPLES, axis=0) + + def decoding_graph(self, + encoder_output, + encoder_self_attention_bias, + labels, + params={}, + embedding_augmentation=None): + trg_vocab_size = params['trg_vocab_size'] + hidden_size = params['hidden_size'] + + initializer = tf.compat.v1.random_normal_initializer( + 0.0, hidden_size**-0.5, dtype=tf.float32) + + if params['shared_source_target_embedding']: + with tf.compat.v1.variable_scope( + 'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + else: + with tf.compat.v1.variable_scope('Target_Embedding'): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + + eos_padding = tf.zeros([tf.shape(input=labels)[0], 1], tf.int64) + trg_seq = tf.concat([labels, eos_padding], 1) + trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32) + shift_trg_mask = trg_mask[:, :-1] + shift_trg_mask = tf.pad( + tensor=shift_trg_mask, + paddings=[[0, 0], [1, 0]], + constant_values=1) + + decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32)) + + decoder_input *= hidden_size**0.5 + decoder_self_attention_bias = attention_bias( + tf.shape(input=decoder_input)[1], 'causal') + decoder_input = tf.pad( + tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :] + if params['position_info_type'] == 'absolute': + decoder_input = add_timing_signal(decoder_input) + + decoder_input = tf.nn.dropout( + decoder_input, rate=1 - (1.0 - params['residual_dropout'])) + + # training + decoder_output, attention_weights = transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_self_attention_bias, + states_key=None, + states_val=None, + embedding_augmentation=embedding_augmentation, + params=params) + + logits = self.prediction(decoder_output, params) + + on_value = params['confidence'] + off_value = (1.0 - params['confidence']) / tf.cast( + trg_vocab_size - 1, dtype=tf.float32) + soft_targets = tf.one_hot( + tf.cast(trg_seq, tf.int32), + depth=trg_vocab_size, + on_value=on_value, + off_value=off_value) + mask = tf.cast(shift_trg_mask, logits.dtype) + xentropy = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=tf.stop_gradient(soft_targets)) * mask + loss = tf.reduce_sum(input_tensor=xentropy) / tf.reduce_sum( + input_tensor=mask) + + return loss + + def build_training_graph(self, + features, + labels, + params, + feature_embedding=None, + label_embedding=None): + # encode + encoder_output, encoder_self_attention_bias = self.encoding_graph( + features, params) + embedding_augmentation = None + if feature_embedding is not None and label_embedding is not None: + embedding_augmentation = self.MGMC_sampling( + feature_embedding, label_embedding, params) + + encoder_output = tf.tile(encoder_output, + [params['num_of_samples'], 1, 1]) + encoder_self_attention_bias = tf.tile( + encoder_self_attention_bias, + [params['num_of_samples'], 1, 1, 1]) + labels = tf.tile(labels, [params['num_of_samples'], 1]) + + # decode + loss = self.decoding_graph( + encoder_output, + encoder_self_attention_bias, + labels, + params, + embedding_augmentation=embedding_augmentation) + + return loss + + def transformer_model_train_fn(self, features, labels): + initializer = get_initializer(self.params) + with tf.compat.v1.variable_scope('NmtModel', initializer=initializer): + num_gpus = self.params['num_gpus'] + gradient_clip_norm = self.params['gradient_clip_norm'] + global_step = tf.compat.v1.train.get_global_step() + print(global_step) + + # learning rate + learning_rate = get_learning_rate_decay( + self.params['learning_rate'], global_step, self.params) + learning_rate = tf.convert_to_tensor( + value=learning_rate, dtype=tf.float32) + + # optimizer + if self.params['optimizer'] == 'sgd': + optimizer = tf.compat.v1.train.GradientDescentOptimizer( + learning_rate) + elif self.params['optimizer'] == 'adam': + optimizer = tf.compat.v1.train.AdamOptimizer( + learning_rate=learning_rate, + beta1=self.params['adam_beta1'], + beta2=self.params['adam_beta2'], + epsilon=self.params['adam_epsilon']) + else: + tf.compat.v1.logging.info('optimizer not supported') + sys.exit() + opt = MultiStepOptimizer(optimizer, self.params['update_cycle']) + + def fill_gpus(inputs, num_gpus): + outputs = inputs + for i in range(num_gpus): + outputs = tf.concat([outputs, inputs], axis=0) + outputs = outputs[:num_gpus, ] + return outputs + + features = tf.cond( + pred=tf.shape(input=features)[0] < num_gpus, + true_fn=lambda: fill_gpus(features, num_gpus), + false_fn=lambda: features) + labels = tf.cond( + pred=tf.shape(input=labels)[0] < num_gpus, + true_fn=lambda: fill_gpus(labels, num_gpus), + false_fn=lambda: labels) + + if num_gpus > 0: + feature_shards = shard_features(features, num_gpus) + label_shards = shard_features(labels, num_gpus) + else: + feature_shards = [features] + label_shards = [labels] + + if num_gpus > 0: + devices = ['gpu:%d' % d for d in range(num_gpus)] + else: + devices = ['cpu:0'] + multi_grads = [] + sharded_losses = [] + + for i, device in enumerate(devices): + with tf.device(device), tf.compat.v1.variable_scope( + tf.compat.v1.get_variable_scope(), + reuse=True if i > 0 else None): + with tf.name_scope('%s_%d' % ('GPU', i)): + feature_output, label_output = self.build_contrastive_training_graph( + feature_shards[i], label_shards[i], self.params) + mle_loss = self.build_training_graph( + feature_shards[i], label_shards[i], self.params, + feature_output, label_output) + sharded_losses.append(mle_loss) + tf.compat.v1.summary.scalar('mle_loss_{}'.format(i), + mle_loss) + + # Optimization + trainable_vars_list = [ + v for v in tf.compat.v1.trainable_variables() + if 'Shared_Semantic_Embedding' not in v.name + and 'mini_xlm_encoder' not in v.name + ] + grads_and_vars = opt.compute_gradients( + mle_loss, + var_list=trainable_vars_list, + colocate_gradients_with_ops=True) + multi_grads.append(grads_and_vars) + + total_loss = tf.add_n(sharded_losses) / len(sharded_losses) + + # Average gradients + grads_and_vars = average_gradients(multi_grads) + + if gradient_clip_norm > 0.0: + grads, var_list = list(zip(*grads_and_vars)) + grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm) + grads_and_vars = zip(grads, var_list) + + train_op = opt.apply_gradients( + grads_and_vars, + global_step=tf.compat.v1.train.get_global_step()) + + return train_op, total_loss + + def prediction(self, decoder_output, params): + hidden_size = params['hidden_size'] + trg_vocab_size = params['trg_vocab_size'] + + if params['shared_embedding_and_softmax_weights']: + embedding_scope = 'Shared_Embedding' if params[ + 'shared_source_target_embedding'] else 'Target_Embedding' + with tf.compat.v1.variable_scope(embedding_scope, reuse=True): + weights = tf.compat.v1.get_variable('Weights') + else: + weights = tf.compat.v1.get_variable('Softmax', + [tgt_vocab_size, hidden_size]) + shape = tf.shape(input=decoder_output)[:-1] + decoder_output = tf.reshape(decoder_output, [-1, hidden_size]) + logits = tf.matmul(decoder_output, weights, transpose_b=True) + logits = tf.reshape(logits, tf.concat([shape, [trg_vocab_size]], 0)) + return logits + + def inference_func(self, + encoder_output, + feature_output, + encoder_self_attention_bias, + trg_seq, + states_key, + states_val, + params={}): + trg_vocab_size = params['trg_vocab_size'] + hidden_size = params['hidden_size'] + + initializer = tf.compat.v1.random_normal_initializer( + 0.0, hidden_size**-0.5, dtype=tf.float32) + + if params['shared_source_target_embedding']: + with tf.compat.v1.variable_scope( + 'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + else: + with tf.compat.v1.variable_scope('Target_Embedding'): + trg_embedding = tf.compat.v1.get_variable( + 'Weights', [trg_vocab_size, hidden_size], + initializer=initializer) + + decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32)) + decoder_input *= hidden_size**0.5 + decoder_self_attention_bias = attention_bias( + tf.shape(input=decoder_input)[1], 'causal') + decoder_input = tf.pad( + tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :] + if params['position_info_type'] == 'absolute': + decoder_input = add_timing_signal(decoder_input) + + decoder_input = decoder_input[:, -1:, :] + decoder_self_attention_bias = decoder_self_attention_bias[:, :, -1:, :] + decoder_output, attention_weights = transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_self_attention_bias, + states_key=states_key, + states_val=states_val, + embedding_augmentation=feature_output, + params=params) + decoder_output_last = decoder_output[:, -1, :] + attention_weights_last = attention_weights[:, -1, :] + + if params['shared_embedding_and_softmax_weights']: + embedding_scope = \ + 'Shared_Embedding' if params['shared_source_target_embedding'] else 'Target_Embedding' + with tf.compat.v1.variable_scope(embedding_scope, reuse=True): + weights = tf.compat.v1.get_variable('Weights') + else: + weights = tf.compat.v1.get_variable('Softmax', + [trg_vocab_size, hidden_size]) + logits = tf.matmul(decoder_output_last, weights, transpose_b=True) + log_prob = tf.nn.log_softmax(logits) + return log_prob, attention_weights_last, states_key, states_val + + def beam_search(self, features, params): + beam_size = params['beam_size'] + trg_vocab_size = params['trg_vocab_size'] + hidden_size = params['hidden_size'] + num_decoder_layers = params['num_decoder_layers'] + lp_rate = params['lp_rate'] + max_decoded_trg_len = params['max_decoded_trg_len'] + batch_size = tf.shape(input=features)[0] + + features = tile_to_beam_size(features, beam_size) + features = merge_first_two_dims(features) + + encoder_output, encoder_self_attention_bias = self.encoding_graph( + features, params) + source_name = 'source' + if params['shared_source_target_embedding']: + source_name = None + feature_output = self.semantic_encoding_graph( + features, params, name=source_name) + + init_seqs = tf.fill([batch_size, beam_size, 1], 0) + init_log_probs = \ + tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)]) + init_log_probs = tf.tile(init_log_probs, [batch_size, 1]) + init_scores = tf.zeros_like(init_log_probs) + fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32) + fin_scores = tf.fill([batch_size, beam_size], tf.float32.min) + fin_flags = tf.zeros([batch_size, beam_size], tf.bool) + + states_key = [ + tf.zeros([batch_size, 0, hidden_size]) + for layer in range(num_decoder_layers) + ] + states_val = [ + tf.zeros([batch_size, 0, hidden_size]) + for layer in range(num_decoder_layers) + ] + for layer in range(num_decoder_layers): + states_key[layer].set_shape( + tf.TensorShape([None, None, hidden_size])) + states_val[layer].set_shape( + tf.TensorShape([None, None, hidden_size])) + states_key = [ + tile_to_beam_size(states_key[layer], beam_size) + for layer in range(num_decoder_layers) + ] + states_val = [ + tile_to_beam_size(states_val[layer], beam_size) + for layer in range(num_decoder_layers) + ] + + state = BeamSearchState( + inputs=(init_seqs, init_log_probs, init_scores), + state=(states_key, states_val), + finish=(fin_flags, fin_seqs, fin_scores), + ) + + def _beam_search_step(time, state): + seqs, log_probs = state.inputs[:2] + states_key, states_val = state.state + + flat_seqs = merge_first_two_dims(seqs) + flat_states_key = [ + merge_first_two_dims(states_key[layer]) + for layer in range(num_decoder_layers) + ] + flat_states_val = [ + merge_first_two_dims(states_val[layer]) + for layer in range(num_decoder_layers) + ] + + step_log_probs, step_attn_weights, step_states_key, step_states_val = self.inference_func( + encoder_output, + feature_output, + encoder_self_attention_bias, + flat_seqs, + flat_states_key, + flat_states_val, + params=params) + + step_log_probs = split_first_two_dims(step_log_probs, batch_size, + beam_size) + curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs + + next_states_key = [ + split_first_two_dims(step_states_key[layer], batch_size, + beam_size) + for layer in range(num_decoder_layers) + ] + next_states_val = [ + split_first_two_dims(step_states_val[layer], batch_size, + beam_size) + for layer in range(num_decoder_layers) + ] + + # Apply length penalty + length_penalty = tf.pow( + (5.0 + tf.cast(time + 1, dtype=tf.float32)) / 6.0, lp_rate) + curr_scores = curr_log_probs / length_penalty + + # Select top-k candidates + # [batch_size, beam_size * vocab_size] + curr_scores = tf.reshape(curr_scores, + [-1, beam_size * trg_vocab_size]) + # [batch_size, 2 * beam_size] + top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size) + # Shape: [batch_size, 2 * beam_size] + beam_indices = top_indices // trg_vocab_size + symbol_indices = top_indices % trg_vocab_size + # Expand sequences + # [batch_size, 2 * beam_size, time] + candidate_seqs = gather_2d(seqs, beam_indices) + candidate_seqs = tf.concat( + [candidate_seqs[:, :, :-1], + tf.expand_dims(symbol_indices, 2)], + axis=2) + pad_seqs = tf.fill([batch_size, 2 * beam_size, 1], + tf.constant(0, tf.int32)) + candidate_seqs = tf.concat([candidate_seqs, pad_seqs], axis=2) + + # Expand sequences + # Suppress finished sequences + flags = tf.equal(symbol_indices, 0) + # [batch, 2 * beam_size] + alive_scores = top_scores + tf.cast( + flags, dtype=tf.float32) * tf.float32.min + # [batch, beam_size] + alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size) + alive_symbols = gather_2d(symbol_indices, alive_indices) + alive_indices = gather_2d(beam_indices, alive_indices) + alive_seqs = gather_2d(seqs, alive_indices) + alive_seqs = tf.concat( + [alive_seqs[:, :, :-1], + tf.expand_dims(alive_symbols, 2)], + axis=2) + pad_seqs = tf.fill([batch_size, beam_size, 1], + tf.constant(0, tf.int32)) + alive_seqs = tf.concat([alive_seqs, pad_seqs], axis=2) + alive_states_key = [ + gather_2d(next_states_key[layer], alive_indices) + for layer in range(num_decoder_layers) + ] + alive_states_val = [ + gather_2d(next_states_val[layer], alive_indices) + for layer in range(num_decoder_layers) + ] + alive_log_probs = alive_scores * length_penalty + + # Select finished sequences + prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish + # [batch, 2 * beam_size] + step_fin_scores = top_scores + ( + 1.0 - tf.cast(flags, dtype=tf.float32)) * tf.float32.min + # [batch, 3 * beam_size] + fin_flags = tf.concat([prev_fin_flags, flags], axis=1) + fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1) + # [batch, beam_size] + fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size) + fin_flags = gather_2d(fin_flags, fin_indices) + pad_seqs = tf.fill([batch_size, beam_size, 1], + tf.constant(0, tf.int32)) + prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2) + fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1) + fin_seqs = gather_2d(fin_seqs, fin_indices) + + new_state = BeamSearchState( + inputs=(alive_seqs, alive_log_probs, alive_scores), + state=(alive_states_key, alive_states_val), + finish=(fin_flags, fin_seqs, fin_scores), + ) + + return time + 1, new_state + + def _is_finished(t, s): + log_probs = s.inputs[1] + finished_flags = s.finish[0] + finished_scores = s.finish[2] + max_lp = tf.pow( + ((5.0 + tf.cast(max_decoded_trg_len, dtype=tf.float32)) / 6.0), + lp_rate) + best_alive_score = log_probs[:, 0] / max_lp + worst_finished_score = tf.reduce_min( + input_tensor=finished_scores + * tf.cast(finished_flags, dtype=tf.float32), + axis=1) + add_mask = 1.0 - tf.cast( + tf.reduce_any(input_tensor=finished_flags, axis=1), + dtype=tf.float32) + worst_finished_score += tf.float32.min * add_mask + bound_is_met = tf.reduce_all( + input_tensor=tf.greater(worst_finished_score, + best_alive_score)) + + cond = tf.logical_and( + tf.less(t, max_decoded_trg_len), tf.logical_not(bound_is_met)) + + return cond + + def _loop_fn(t, s): + outs = _beam_search_step(t, s) + return outs + + time = tf.constant(0, name='time') + shape_invariants = BeamSearchState( + inputs=(tf.TensorShape([None, None, None]), + tf.TensorShape([None, None]), tf.TensorShape([None, + None])), + state=([ + tf.TensorShape([None, None, None, hidden_size]) + for layer in range(num_decoder_layers) + ], [ + tf.TensorShape([None, None, None, hidden_size]) + for layer in range(num_decoder_layers) + ]), + finish=(tf.TensorShape([None, + None]), tf.TensorShape([None, None, None]), + tf.TensorShape([None, None]))) + outputs = tf.while_loop( + cond=_is_finished, + body=_loop_fn, + loop_vars=[time, state], + shape_invariants=[tf.TensorShape([]), shape_invariants], + parallel_iterations=1, + back_prop=False) + + final_state = outputs[1] + alive_seqs = final_state.inputs[0] + alive_scores = final_state.inputs[2] + final_flags = final_state.finish[0] + final_seqs = final_state.finish[1] + final_scores = final_state.finish[2] + + alive_seqs.set_shape([None, beam_size, None]) + final_seqs.set_shape([None, beam_size, None]) + + final_seqs = tf.compat.v1.where( + tf.reduce_any(input_tensor=final_flags, axis=1), final_seqs, + alive_seqs) + final_scores = tf.compat.v1.where( + tf.reduce_any(input_tensor=final_flags, axis=1), final_scores, + alive_scores) + + final_seqs = final_seqs[:, :, :-1] + return final_seqs, final_scores + + +class BeamSearchState( + namedtuple('BeamSearchState', ('inputs', 'state', 'finish'))): + pass + + +def tile_to_beam_size(tensor, beam_size): + """Tiles a given tensor by beam_size. """ + tensor = tf.expand_dims(tensor, axis=1) + tile_dims = [1] * tensor.shape.ndims + tile_dims[1] = beam_size + + return tf.tile(tensor, tile_dims) + + +def infer_shape(x): + x = tf.convert_to_tensor(x) + + if x.shape.dims is None: + return tf.shape(x) + + static_shape = x.shape.as_list() + dynamic_shape = tf.shape(x) + + ret = [] + for i in range(len(static_shape)): + dim = static_shape[i] + if dim is None: + dim = dynamic_shape[i] + ret.append(dim) + + return ret + + +def split_first_two_dims(tensor, dim_0, dim_1): + shape = infer_shape(tensor) + new_shape = [dim_0] + [dim_1] + shape[1:] + return tf.reshape(tensor, new_shape) + + +def merge_first_two_dims(tensor): + shape = infer_shape(tensor) + shape[0] *= shape[1] + shape.pop(1) + return tf.reshape(tensor, shape) + + +def gather_2d(params, indices, name=None): + """ Gather the 2nd dimension given indices + :param params: A tensor with shape [batch_size, M, ...] + :param indices: A tensor with shape [batch_size, N] + :param name: An optional string + :return: A tensor with shape [batch_size, N, ...] + """ + batch_size = tf.shape(params)[0] + range_size = tf.shape(indices)[1] + batch_pos = tf.range(batch_size * range_size) // range_size + batch_pos = tf.reshape(batch_pos, [batch_size, range_size]) + indices = tf.stack([batch_pos, indices], axis=-1) + output = tf.gather_nd(params, indices, name=name) + + return output + + +def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None): + with tf.compat.v1.variable_scope( + scope, default_name='linear', values=[inputs], dtype=dtype): + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + input_size = [item.get_shape()[-1] for item in inputs] + + if len(inputs) != len(input_size): + raise RuntimeError('inputs and input_size unmatched!') + + output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]], + axis=0) + # Flatten to 2D + inputs = [tf.reshape(inp, [-1, inp.shape[-1]]) for inp in inputs] + + results = [] + if concat: + input_size = sum(input_size) + inputs = tf.concat(inputs, 1) + + shape = [input_size, output_size] + matrix = tf.compat.v1.get_variable('matrix', shape) + results.append(tf.matmul(inputs, matrix)) + else: + for i in range(len(input_size)): + shape = [input_size[i], output_size] + name = 'matrix_%d' % i + matrix = tf.compat.v1.get_variable(name, shape) + results.append(tf.matmul(inputs[i], matrix)) + + output = tf.add_n(results) + + if bias: + shape = [output_size] + bias = tf.compat.v1.get_variable('bias', shape) + output = tf.nn.bias_add(output, bias) + + output = tf.reshape(output, output_shape) + + return output + + +def layer_norm(inputs, epsilon=1e-6, name=None, reuse=None): + with tf.compat.v1.variable_scope( + name, default_name='layer_norm', values=[inputs], reuse=reuse): + channel_size = inputs.get_shape().as_list()[-1] + + scale = tf.compat.v1.get_variable( + 'layer_norm_scale', [channel_size], + initializer=tf.ones_initializer()) + + offset = tf.compat.v1.get_variable( + 'layer_norm_offset', [channel_size], + initializer=tf.zeros_initializer()) + + mean = tf.reduce_mean(inputs, -1, True) + variance = tf.reduce_mean(tf.square(inputs - mean), -1, True) + + norm_inputs = (inputs - mean) * tf.compat.v1.rsqrt(variance + epsilon) + + return norm_inputs * scale + offset + + +def _layer_process(x, mode): + if not mode or mode == 'none': + return x + elif mode == 'layer_norm': + return layer_norm(x) + else: + raise ValueError('Unknown mode %s' % mode) + + +def _residual_fn(x, y, keep_prob=None): + if keep_prob and keep_prob < 1.0: + y = tf.nn.dropout(y, rate=1 - (keep_prob)) + return x + y + + +def embedding_augmentation_layer(x, embedding_augmentation, params, name=None): + hidden_size = params['hidden_size'] + keep_prob = 1.0 - params['relu_dropout'] + with tf.compat.v1.variable_scope( + name, + default_name='embedding_augmentation_layer', + values=[x, embedding_augmentation]): + with tf.compat.v1.variable_scope('input_layer'): + hidden = linear(embedding_augmentation, hidden_size, True, True) + hidden = tf.nn.relu(hidden) + + if keep_prob and keep_prob < 1.0: + hidden = tf.nn.dropout(hidden, rate=1 - (keep_prob)) + + with tf.compat.v1.variable_scope('output_layer'): + output = linear(hidden, hidden_size, True, True) + + return x + output + + +def transformer_ffn_layer(x, params, name=None): + filter_size = params['filter_size'] + hidden_size = params['hidden_size'] + keep_prob = 1.0 - params['relu_dropout'] + with tf.compat.v1.variable_scope( + name, default_name='ffn_layer', values=[x]): + with tf.compat.v1.variable_scope('input_layer'): + hidden = linear(x, filter_size, True, True) + hidden = tf.nn.relu(hidden) + + if keep_prob and keep_prob < 1.0: + hidden = tf.nn.dropout(hidden, rate=1 - (keep_prob)) + + with tf.compat.v1.variable_scope('output_layer'): + output = linear(hidden, hidden_size, True, True) + + return output + + +def transformer_encoder(encoder_input, + encoder_self_attention_bias, + mask, + params={}, + name='encoder'): + num_encoder_layers = params['num_encoder_layers'] + hidden_size = params['hidden_size'] + num_heads = params['num_heads'] + residual_dropout = params['residual_dropout'] + attention_dropout = params['attention_dropout'] + layer_preproc = params['layer_preproc'] + layer_postproc = params['layer_postproc'] + x = encoder_input + mask = tf.expand_dims(mask, 2) + with tf.compat.v1.variable_scope(name): + for layer in range(num_encoder_layers): + with tf.compat.v1.variable_scope('layer_%d' % layer): + max_relative_dis = params['max_relative_dis'] \ + if params['position_info_type'] == 'relative' else None + o, w = multihead_attention( + _layer_process(x, layer_preproc), + None, + encoder_self_attention_bias, + hidden_size, + hidden_size, + hidden_size, + num_heads, + attention_dropout, + max_relative_dis=max_relative_dis, + name='encoder_self_attention') + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + o = transformer_ffn_layer( + _layer_process(x, layer_preproc), params) + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + x = tf.multiply(x, mask) + + return _layer_process(x, layer_preproc) + + +def transformer_semantic_encoder(encoder_input, + encoder_self_attention_bias, + mask, + params={}, + name='mini_xlm_encoder'): + num_encoder_layers = params['num_semantic_encoder_layers'] + hidden_size = params['hidden_size'] + num_heads = params['num_heads'] + residual_dropout = params['residual_dropout'] + attention_dropout = params['attention_dropout'] + layer_preproc = params['layer_preproc'] + layer_postproc = params['layer_postproc'] + x = encoder_input + mask = tf.expand_dims(mask, 2) + with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): + for layer in range(num_encoder_layers): + with tf.compat.v1.variable_scope('layer_%d' % layer): + max_relative_dis = params['max_relative_dis'] + o, w = multihead_attention( + _layer_process(x, layer_preproc), + None, + encoder_self_attention_bias, + hidden_size, + hidden_size, + hidden_size, + num_heads, + attention_dropout, + max_relative_dis=max_relative_dis, + name='encoder_self_attention') + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + o = transformer_ffn_layer( + _layer_process(x, layer_preproc), params) + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + x = tf.multiply(x, mask) + + with tf.compat.v1.variable_scope( + 'pooling_layer', reuse=tf.compat.v1.AUTO_REUSE): + output = tf.reduce_sum( + input_tensor=x, axis=1) / tf.reduce_sum( + input_tensor=mask, axis=1) + output = linear( + tf.expand_dims(output, axis=1), hidden_size, True, True) + + return _layer_process(output, layer_preproc) + + +def transformer_decoder(decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + states_key=None, + states_val=None, + embedding_augmentation=None, + params={}, + name='decoder'): + num_decoder_layers = params['num_decoder_layers'] + hidden_size = params['hidden_size'] + num_heads = params['num_heads'] + residual_dropout = params['residual_dropout'] + attention_dropout = params['attention_dropout'] + layer_preproc = params['layer_preproc'] + layer_postproc = params['layer_postproc'] + x = decoder_input + with tf.compat.v1.variable_scope(name): + for layer in range(num_decoder_layers): + with tf.compat.v1.variable_scope('layer_%d' % layer): + max_relative_dis = params['max_relative_dis'] \ + if params['position_info_type'] == 'relative' else None + # continuous semantic augmentation + if embedding_augmentation is not None: + x = embedding_augmentation_layer( + x, _layer_process(embedding_augmentation, + layer_preproc), params) + x = _layer_process(x, layer_postproc) + o, w = multihead_attention( + _layer_process(x, layer_preproc), + None, + decoder_self_attention_bias, + hidden_size, + hidden_size, + hidden_size, + num_heads, + attention_dropout, + states_key=states_key, + states_val=states_val, + layer=layer, + max_relative_dis=max_relative_dis, + name='decoder_self_attention') + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + o, w = multihead_attention( + _layer_process(x, layer_preproc), + encoder_output, + encoder_decoder_attention_bias, + hidden_size, + hidden_size, + hidden_size, + num_heads, + attention_dropout, + max_relative_dis=max_relative_dis, + name='encdec_attention') + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + o = transformer_ffn_layer( + _layer_process(x, layer_preproc), params) + x = _residual_fn(x, o, 1.0 - residual_dropout) + x = _layer_process(x, layer_postproc) + + return _layer_process(x, layer_preproc), w + + +def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4): + length = tf.shape(x)[1] + channels = tf.shape(x)[2] + position = tf.cast(tf.range(length), tf.float32) + num_timescales = channels // 2 + + log_timescale_increment = \ + (math.log(float(max_timescale) / float(min_timescale)) / (tf.cast(num_timescales, tf.float32) - 1)) + inv_timescales = min_timescale * tf.exp( + tf.cast(tf.range(num_timescales), tf.float32) + * -log_timescale_increment) + + scaled_time = \ + tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) + signal = tf.pad(signal, [[0, 0], [0, tf.compat.v1.mod(channels, 2)]]) + signal = tf.reshape(signal, [1, length, channels]) + + return x + tf.cast(signal, x.dtype) + + +def attention_bias(inputs, mode, inf=-1e9, dtype=None): + if dtype is None: + dtype = tf.float32 + + if dtype != tf.float32: + inf = dtype.min + + if mode == 'masking': + mask = inputs + ret = (1.0 - mask) * inf + ret = tf.expand_dims(tf.expand_dims(ret, 1), 1) + + elif mode == 'causal': + length = inputs + lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0) + ret = inf * (1.0 - lower_triangle) + ret = tf.reshape(ret, [1, 1, length, length]) + else: + raise ValueError('Unknown mode %s' % mode) + + return tf.cast(ret, dtype) + + +def split_heads(x, num_heads): + n = num_heads + old_shape = x.get_shape().dims + ndims = x.shape.ndims + + last = old_shape[-1] + new_shape = old_shape[:-1] + [n] + [last // n if last else None] + ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) + ret.set_shape(new_shape) + perm = [0, ndims - 1] + [i for i in range(1, ndims - 1)] + [ndims] + return tf.transpose(ret, perm) + + +def dot_product_attention(q, + k, + v, + bias, + dropout_rate=0.0, + name=None, + rpr=None): + with tf.compat.v1.variable_scope( + name, default_name='dot_product_attention', values=[q, k, v]): + q_shape = tf.shape(q) + bs, hd, lq, dk = q_shape[0], q_shape[1], q_shape[2], q_shape[3] + lk = tf.shape(k)[2] + dv = tf.shape(v)[3] + + if rpr is not None: + rpr_k, rpr_v = rpr['rpr_k'], rpr[ + 'rpr_v'] # (lq, lk, dk), (lq, lk, dv) + + if rpr is None: + logits = tf.matmul(q, k, transpose_b=True) + else: # self-attention with relative position representaion + logits_part1 = tf.matmul(q, k, transpose_b=True) # bs, hd, lq, lk + + q = tf.reshape(tf.transpose(q, [2, 0, 1, 3]), + [lq, bs * hd, dk]) # lq, bs*hd, dk + logits_part2 = tf.matmul(q, + tf.transpose(rpr_k, + [0, 2, 1])) # lq, bs*hd, lk + logits_part2 = tf.reshape( + tf.transpose(logits_part2, [1, 0, 2]), [bs, hd, lq, lk]) + + logits = logits_part1 + logits_part2 # bs, hd, lq, lk + + if bias is not None: + logits += bias + + weights = tf.nn.softmax(logits, name='attention_weights') + + if dropout_rate > 0.0: + weights = tf.nn.dropout(weights, 1.0 - dropout_rate) + + if rpr is None: + return tf.matmul(weights, v), weights + else: + outputs_part1 = tf.matmul(weights, v) # bs, hd, lq, dv + + weights = tf.reshape( + tf.transpose(weights, [2, 0, 1, 3]), + [lq, bs * hd, lk]) # lq, bs*hd, lk + outputs_part2 = tf.matmul(weights, rpr_v) # lq, bs*hd, dv + outputs_part2 = tf.reshape( + tf.transpose(outputs_part2, [1, 0, 2]), [bs, hd, lq, dv]) + + outputs = outputs_part1 + outputs_part2 # bs, hd, lq, dv + weights = tf.reshape( + tf.transpose(weights, [1, 0, 2]), + [bs, hd, lq, lk]) # bs, hd, lq, lk + + return outputs, weights + + +def combine_heads(x): + x = tf.transpose(x, [0, 2, 1, 3]) + old_shape = x.get_shape().dims + a, b = old_shape[-2:] + new_shape = old_shape[:-2] + [a * b if a and b else None] + x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) + x.set_shape(new_shape) + + return x + + +def create_rpr(orginal_var, + length_q, + length_kv, + max_relative_dis, + name='create_rpr'): + with tf.name_scope(name): + idxs = tf.reshape(tf.range(length_kv), [-1, 1]) # only self-attention + idys = tf.reshape(tf.range(length_kv), [1, -1]) + ids = idxs - idys + ids = ids + max_relative_dis + ids = tf.maximum(ids, 0) + ids = tf.minimum(ids, 2 * max_relative_dis) + ids = ids[-length_q:, :] + rpr = tf.gather(orginal_var, ids) + return rpr + + +def multihead_attention(queries, + memories, + bias, + key_depth, + value_depth, + output_depth, + num_heads, + dropout_rate, + states_key=None, + states_val=None, + layer=0, + max_relative_dis=None, + name=None): + if key_depth % num_heads != 0: + raise ValueError( + 'Key size (%d) must be divisible by the number of attention heads (%d).' + % (key_size, num_heads)) + + if value_depth % num_heads != 0: + raise ValueError( + 'Value size (%d) must be divisible by the number of attention heads (%d).' + % (value_size, num_heads)) + + with tf.compat.v1.variable_scope( + name, default_name='multihead_attention', + values=[queries, memories]): + if memories is None: + # self attention + combined = linear( + queries, + key_depth * 2 + value_depth, + True, + True, + scope='qkv_transform') + q, k, v = tf.split( + combined, [key_depth, key_depth, value_depth], axis=2) + else: + q = linear(queries, key_depth, True, True, scope='q_transform') + combined = linear( + memories, + key_depth + value_depth, + True, + True, + scope='kv_transform') + k, v = tf.split(combined, [key_depth, value_depth], axis=2) + + if states_key is not None: + k = states_key[layer] = tf.concat([states_key[layer], k], axis=1) + if states_val is not None: + v = states_val[layer] = tf.concat([states_val[layer], v], axis=1) + + q = split_heads(q, num_heads) + k = split_heads(k, num_heads) + v = split_heads(v, num_heads) + + key_depth_per_head = key_depth // num_heads + q *= key_depth_per_head**-0.5 + + length_q = tf.shape(q)[2] + length_kv = tf.shape(k)[2] + + # relative position representation (only in self-attention) + if memories is None and max_relative_dis is not None: + rpr_k = tf.compat.v1.get_variable( + 'rpr_k', [2 * max_relative_dis + 1, key_depth // num_heads]) + rpr_v = tf.compat.v1.get_variable( + 'rpr_v', [2 * max_relative_dis + 1, value_depth // num_heads]) + rpr_k = create_rpr(rpr_k, length_q, length_kv, max_relative_dis) + rpr_v = create_rpr(rpr_v, length_q, length_kv, max_relative_dis) + rpr = {'rpr_k': rpr_k, 'rpr_v': rpr_v} + x, w = dot_product_attention(q, k, v, bias, dropout_rate, rpr=rpr) + else: + x, w = dot_product_attention(q, k, v, bias, dropout_rate) + x = combine_heads(x) + w = tf.reduce_mean(w, 1) + x = linear(x, output_depth, True, True, scope='output_transform') + return x, w + + +def get_initializer(params): + if params['initializer'] == 'uniform': + max_val = params['initializer_scale'] + return tf.compat.v1.random_uniform_initializer(-max_val, max_val) + elif params['initializer'] == 'normal': + return tf.compat.v1.random_normal_initializer( + 0.0, params['initializer_scale']) + elif params['initializer'] == 'normal_unit_scaling': + return tf.compat.v1.variance_scaling_initializer( + params['initializer_scale'], mode='fan_avg', distribution='normal') + elif params['initializer'] == 'uniform_unit_scaling': + return tf.compat.v1.variance_scaling_initializer( + params['initializer_scale'], + mode='fan_avg', + distribution='uniform') + else: + raise ValueError('Unrecognized initializer: %s' + % params['initializer']) + + +def get_learning_rate_decay(learning_rate, global_step, params): + if params['learning_rate_decay'] in ['linear_warmup_rsqrt_decay', 'noam']: + step = tf.cast(global_step, dtype=tf.float32) + warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32) + multiplier = params['hidden_size']**-0.5 + decay = multiplier * tf.minimum((step + 1) * (warmup_steps**-1.5), + (step + 1)**-0.5) + return learning_rate * decay + elif params['learning_rate_decay'] == 'piecewise_constant': + return tf.compat.v1.train.piecewise_constant( + tf.cast(global_step, dtype=tf.int32), + params['learning_rate_boundaries'], params['learning_rate_values']) + elif params['learning_rate_decay'] == 'none': + return learning_rate + else: + raise ValueError('Unknown learning_rate_decay') + + +def average_gradients(tower_grads): + average_grads = [] + for grad_and_vars in zip(*tower_grads): + grads = [] + for g, _ in grad_and_vars: + expanded_g = tf.expand_dims(g, 0) + grads.append(expanded_g) + grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + return average_grads + + +_ENGINE = None + + +def all_reduce(tensor): + if _ENGINE is None: + return tensor + + return _ENGINE.allreduce(tensor, compression=_ENGINE.Compression.fp16) + + +class MultiStepOptimizer(tf.compat.v1.train.Optimizer): + + def __init__(self, + optimizer, + step=1, + use_locking=False, + name='MultiStepOptimizer'): + super(MultiStepOptimizer, self).__init__(use_locking, name) + self._optimizer = optimizer + self._step = step + self._step_t = tf.convert_to_tensor(step, name='step') + + def _all_reduce(self, tensor): + with tf.name_scope(self._name + '_Allreduce'): + if tensor is None: + return tensor + + if isinstance(tensor, tf.IndexedSlices): + tensor = tf.convert_to_tensor(tensor) + + return all_reduce(tensor) + + def compute_gradients(self, + loss, + var_list=None, + gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + grads_and_vars = self._optimizer.compute_gradients( + loss, var_list, gate_gradients, aggregation_method, + colocate_gradients_with_ops, grad_loss) + + grads, var_list = list(zip(*grads_and_vars)) + + # Do not create extra variables when step is 1 + if self._step == 1: + grads = [self._all_reduce(t) for t in grads] + return list(zip(grads, var_list)) + + first_var = min(var_list, key=lambda x: x.name) + iter_var = self._create_non_slot_variable( + initial_value=0 if self._step == 1 else 1, + name='iter', + colocate_with=first_var) + + new_grads = [] + + for grad, var in zip(grads, var_list): + grad_acc = self._zeros_slot(var, 'grad_acc', self._name) + + if isinstance(grad, tf.IndexedSlices): + grad_acc = tf.scatter_add( + grad_acc, + grad.indices, + grad.values, + use_locking=self._use_locking) + else: + grad_acc = tf.assign_add( + grad_acc, grad, use_locking=self._use_locking) + + def _acc_grad(): + return grad_acc + + def _avg_grad(): + return self._all_reduce(grad_acc / self._step) + + grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad) + new_grads.append(grad) + + return list(zip(new_grads, var_list)) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + if self._step == 1: + return self._optimizer.apply_gradients( + grads_and_vars, global_step, name=name) + + grads, var_list = list(zip(*grads_and_vars)) + + def _pass_gradients(): + return tf.group(*grads) + + def _apply_gradients(): + op = self._optimizer.apply_gradients( + zip(grads, var_list), global_step, name) + with tf.control_dependencies([op]): + zero_ops = [] + for var in var_list: + grad_acc = self.get_slot(var, 'grad_acc') + zero_ops.append( + grad_acc.assign( + tf.zeros_like(grad_acc), + use_locking=self._use_locking)) + zero_op = tf.group(*zero_ops) + return tf.group(*[op, zero_op]) + + iter_var = self._get_non_slot_variable('iter', tf.get_default_graph()) + update_op = tf.cond( + tf.equal(iter_var, 0), _apply_gradients, _pass_gradients) + + with tf.control_dependencies([update_op]): + iter_op = iter_var.assign( + tf.mod(iter_var + 1, self._step_t), + use_locking=self._use_locking) + + return tf.group(*[update_op, iter_op]) + + +def shard_features(x, num_datashards): + x = tf.convert_to_tensor(x) + batch_size = tf.shape(x)[0] + size_splits = [] + + with tf.device('/cpu:0'): + for i in range(num_datashards): + size_splits.append( + tf.cond( + tf.greater( + tf.compat.v1.mod(batch_size, num_datashards), + i), lambda: batch_size // num_datashards + 1, + lambda: batch_size // num_datashards)) + + return tf.split(x, size_splits, axis=0) diff --git a/modelscope/models/nlp/deberta_v2/__init__.py b/modelscope/models/nlp/deberta_v2/__init__.py new file mode 100644 index 00000000..08b184e5 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/__init__.py @@ -0,0 +1,53 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import DebertaV2Config + from .tokenization import DebertaV2Tokenizer + from .tokenization_fast import DebertaV2TokenizerFast + from .backbone import ( + DebertaV2Model, + DebertaV2PreTrainedModel, + ) + from .fill_mask import DebertaV2ForMaskedLM + +else: + _import_structure = { + 'configuration': ['DebertaV2Config'], + 'tokenization': ['DebertaV2Tokenizer'], + 'tokenization_fast': ['DebertaV2TokenizerFast'], + 'backbone': [ + 'DebertaV2Model', + 'DebertaV2PreTrainedModel', + ], + 'fill_mask': [ + 'DebertaV2ForMaskedLM', + ] + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/deberta_v2/backbone.py b/modelscope/models/nlp/deberta_v2/backbone.py new file mode 100644 index 00000000..cca38133 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/backbone.py @@ -0,0 +1,1224 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model.""" + +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import LayerNorm +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import softmax_backward_data + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .configuration import DebertaV2Config + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, + config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, + torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output, ) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, + output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op( + 'Cast', mask, to_i=sym_help.cast_pytorch_to_onnx['Long']) + r_mask = g.op( + 'Cast', + g.op('Sub', + g.op('Constant', value_t=torch.tensor(1, dtype=torch.int64)), + mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx['Byte'], + ) + output = masked_fill( + g, self, r_mask, + g.op( + 'Constant', + value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = softmax(g, output, dim) + return masked_fill( + g, output, r_mask, + g.op('Constant', value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to( + torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask, ) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, + local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + from torch.onnx import symbolic_opset12 + + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. + # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return symbolic_opset12.dropout(g, input, dropout_p, train) + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, 'conv_kernel_size', 3) + groups = getattr(config, 'conv_groups', 1) + self.conv_act = getattr(config, 'conv_act', 'tanh') + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size, + padding=(kernel_size - 1) // 2, + groups=groups) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute( + 0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList( + [DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, 'relative_attention', False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, + 'max_relative_positions', -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, 'position_buckets', -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, + config.hidden_size) + + self.norm_rel_ebd = [ + x.strip() + for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|') + ] + + if 'layer_norm' in self.norm_rel_ebd: + self.LayerNorm = LayerNorm( + config.hidden_size, + config.layer_norm_eps, + elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, 'conv_kernel_size', + 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze( + -2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size( + -2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, + relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states, ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, + input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len( + self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m, ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states, ) + + if not return_dict: + return tuple( + v for v in [output_states, all_hidden_states, all_attentions] + if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, + hidden_states=all_hidden_states, + attentions=all_attentions) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where( + (relative_pos < mid) & (relative_pos > -mid), + torch.tensor(mid - 1).type_as(relative_pos), + torch.abs(relative_pos), + ) + log_pos = ( + torch.ceil( + torch.log(abs_pos / mid) + / torch.log(torch.tensor( + (max_position - 1) / mid)) * (mid - 1)) + mid) + bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), + log_pos * sign) + return bucket_pos + + +def build_relative_position(query_size, + key_size, + bucket_size=-1, + max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = torch.arange(0, query_size) + k_ids = torch.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - k_ids[None, :] + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, + max_position) + rel_pos_ids = rel_pos_ids.to(torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([ + query_layer.size(0), + query_layer.size(1), + query_layer.size(2), + relative_pos.size(-1) + ]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([ + query_layer.size(0), + query_layer.size(1), + key_layer.size(-2), + key_layer.size(-2) + ]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, 'attention_head_size', + _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, 'share_att_key', False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, 'relative_attention', False) + + if self.relative_attention: + self.position_buckets = getattr(config, 'position_buckets', -1) + self.max_relative_positions = getattr(config, + 'max_relative_positions', -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if 'c2p' in self.pos_att_type: + self.pos_key_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + if 'p2c' in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, + self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), + x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores( + self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores( + self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores( + self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if 'c2p' in self.pos_att_type: + scale_factor += 1 + if 'p2c' in self.pos_att_type: + scale_factor += 1 + scale = torch.sqrt( + torch.tensor(query_layer.size(-1), dtype=torch.float) + * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose( + -1, -2)) / torch.tensor( + scale, dtype=query_layer.dtype) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias(query_layer, key_layer, + relative_pos, + rel_embeddings, + scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view(-1, self.num_attention_heads, + attention_scores.size(-2), + attention_scores.size(-1)) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), + attention_probs.size(-1)), value_layer) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, + context_layer.size(-2), + context_layer.size(-1)).permute(0, 2, 1, + 3).contiguous()) + new_context_layer_shape = context_layer.size()[:-2] + (-1, ) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, + rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError( + f'Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}' + ) + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + rel_embeddings = rel_embeddings[0:att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores( + self.key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + else: + if 'c2p' in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + if 'p2c' in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if 'c2p' in self.pos_att_type: + scale = torch.sqrt( + torch.tensor(pos_key_layer.size(-1), dtype=torch.float) + * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([ + query_layer.size(0), + query_layer.size(1), + relative_pos.size(-1) + ]), + ) + score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype) + + # position->content + if 'p2c' in self.pos_att_type: + scale = torch.sqrt( + torch.tensor(pos_query_layer.size(-1), dtype=torch.float) + * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([ + query_layer.size(0), + key_layer.size(-2), + key_layer.size(-2) + ]), + ).transpose(-1, -2) + score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype) + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, 'pad_token_id', 0) + self.embedding_size = getattr(config, 'embedding_size', + config.hidden_size) + self.word_embeddings = nn.Embedding( + config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, 'position_biased_input', + True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear( + self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + mask=None, + inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = 'deberta' + _keys_to_ignore_on_load_missing = ['position_ids'] + _keys_to_ignore_on_load_unexpected = ['position_embeddings'] + supports_gradient_checkpointing = True + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = DebertaV2Config(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model + + +@MODELS.register_module(Tasks.backbone, module_name=Models.deberta_v2) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + """The bare DeBERTa_v2 Model transformer outputting raw hidden-states without any specific head on top. + + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`DebertaV2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. + """ + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError( + 'The prune function is not implemented in DeBERTa model.') + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, AttentionBackboneModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`): + Indices of input sequence tokens in the vocabulary. + + attention_mask (`torch.FloatTensor` of shape `('batch_size, sequence_length')`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + position_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + inputs_embeds (`torch.FloatTensor` of shape `('batch_size, sequence_length', hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a dataclass instead of a plain tuple. + + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> print(model(**preprocessor('这是个测试'))) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output, ) + encoder_outputs[ + (1 if output_hidden_states else 2):] + + return AttentionBackboneModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states + if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) diff --git a/modelscope/models/nlp/deberta_v2/configuration.py b/modelscope/models/nlp/deberta_v2/configuration.py new file mode 100644 index 00000000..7921ca2f --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/configuration.py @@ -0,0 +1,128 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020, Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeBERTa-v2 model configuration, mainly copied from :class:`~transformers.DeBERTaV2Config""" + +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class DebertaV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a + DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DeBERTa + [microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 128100): + Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DebertaV2Model`]. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"` + are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 0): + The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-7): + The epsilon used by the layer normalization layers. + relative_attention (`bool`, *optional*, defaults to `True`): + Whether use relative position encoding. + max_relative_positions (`int`, *optional*, defaults to -1): + The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value + as `max_position_embeddings`. + pad_token_id (`int`, *optional*, defaults to 0): + The value used to pad input_ids. + position_biased_input (`bool`, *optional*, defaults to `False`): + Whether add absolute position embedding to content embedding. + pos_att_type (`List[str]`, *optional*): + The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`, + `["p2c", "c2p"]`, `["p2c", "c2p"]`. + layer_norm_eps (`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + """ + model_type = 'deberta_v2' + + def __init__(self, + vocab_size=128100, + hidden_size=1536, + num_hidden_layers=24, + num_attention_heads=24, + intermediate_size=6144, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=0, + initializer_range=0.02, + layer_norm_eps=1e-7, + relative_attention=False, + max_relative_positions=-1, + pad_token_id=0, + position_biased_input=True, + pos_att_type=None, + pooler_dropout=0, + pooler_hidden_act='gelu', + **kwargs): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.pad_token_id = pad_token_id + self.position_biased_input = position_biased_input + + # Backwards compatibility + if type(pos_att_type) == str: + pos_att_type = [x.strip() for x in pos_att_type.lower().split('|')] + + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get('pooler_hidden_size', hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act diff --git a/modelscope/models/nlp/deberta_v2/fill_mask.py b/modelscope/models/nlp/deberta_v2/fill_mask.py new file mode 100644 index 00000000..ed127d4c --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/fill_mask.py @@ -0,0 +1,230 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils.constant import Tasks +from .backbone import DebertaV2Model, DebertaV2PreTrainedModel + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + r"""DeBERTa_v2 Model with a `language modeling` head on top. + + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Preprocessor: + This is the fill_mask model of Deberta_v2, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config (`DebertaV2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config, **kwargs): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, AttentionFillMaskModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`): + Indices of input sequence tokens in the vocabulary. + + attention_mask (`torch.FloatTensor` of shape `('batch_size, sequence_length')`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + position_ids (`torch.LongTensor` of shape `('batch_size, sequence_length')`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + inputs_embeds (`torch.FloatTensor` of shape `('batch_size, sequence_length', hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a dataclass instead of a plain tuple. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[1:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + input_ids=input_ids, + attentions=outputs.attentions, + hidden_states=outputs.hidden_states) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores diff --git a/modelscope/models/nlp/deberta_v2/tokenization.py b/modelscope/models/nlp/deberta_v2/tokenization.py new file mode 100644 index 00000000..adb60288 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/tokenization.py @@ -0,0 +1,546 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DeBERTa. mainly copied from :module:`~transformers.tokenization_deberta`""" + +import os +import unicodedata +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as sp +from transformers.tokenization_utils import PreTrainedTokenizer + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + +PRETRAINED_INIT_CONFIGURATION = {} + +VOCAB_FILES_NAMES = {'vocab_file': 'spm.model'} + + +class DebertaV2Tokenizer(PreTrainedTokenizer): + r""" + Constructs a DeBERTa-v2 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) + and [jieba](https://github.com/fxsjy/jieba). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, + vocab_file, + do_lower_case=False, + split_by_punct=False, + split_chinese=True, + bos_token='[CLS]', + eos_token='[SEP]', + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + super().__init__( + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + split_chinese=split_chinese, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + ' model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.split_chinese = split_chinese + self.vocab_file = vocab_file + self._tokenizer = SPMTokenizer( + vocab_file, + split_by_punct=split_by_punct, + sp_model_kwargs=self.sp_model_kwargs) + self.jieba = None + if self.split_chinese: + try: + import jieba + except ImportError: + raise ImportError( + 'You need to install jieba to split chinese and use DebertaV2Tokenizer. ' + 'See https://pypi.org/project/jieba/ for installation.') + self.jieba = jieba + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._tokenizer.vocab + + def get_vocab(self): + vocab = self.vocab.copy() + vocab.update(self.get_added_vocab()) + return vocab + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + if self.do_lower_case: + text = text.lower() + if self.split_chinese: + seg_list = [x for x in self.jieba.cut(text)] + text = ' '.join(seg_list) + return self._tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self._tokenizer.spm.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self._tokenizer.spm.IdToPiece( + index) if index < self.vocab_size else self.unk_token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + return self._tokenizer.decode(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def prepare_for_tokenization(self, + text, + is_split_into_words=False, + **kwargs): + add_prefix_space = kwargs.pop('add_prefix_space', False) + if is_split_into_words or add_prefix_space: + text = ' ' + text + return (text, kwargs) + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + return self._tokenizer.save_pretrained( + save_directory, filename_prefix=filename_prefix) + + +class SPMTokenizer: + r""" + Constructs a tokenizer based on [SentencePiece](https://github.com/google/sentencepiece). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + def __init__(self, + vocab_file, + split_by_punct=False, + sp_model_kwargs: Optional[Dict[str, Any]] = None): + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + if not os.path.exists(vocab_file): + raise FileNotFoundError(f'{vocab_file} does not exist!') + spm.load(vocab_file) + bpe_vocab_size = spm.GetPieceSize() + # Token map + # 0+1 + # 1+1 + # 2+1 + self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} + self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] + # self.vocab['[PAD]'] = 0 + # self.vocab['[CLS]'] = 1 + # self.vocab['[SEP]'] = 2 + # self.vocab['[UNK]'] = 3 + + self.spm = spm + + def __getstate__(self): + state = self.__dict__.copy() + state['spm'] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, 'sp_model_kwargs'): + self.sp_model_kwargs = {} + + self.spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + self.spm.Load(self.vocab_file) + + def tokenize(self, text): + return self._encode_as_pieces(text) + + def convert_ids_to_tokens(self, ids): + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def decode(self, tokens, start=-1, end=-1, raw_text=None): + if raw_text is None: + return self.spm.decode_pieces([t for t in tokens]) + else: + words = self.split_to_words(raw_text) + word_tokens = [self.tokenize(w) for w in words] + token2words = [0] * len(tokens) + tid = 0 + for i, w in enumerate(word_tokens): + for k, t in enumerate(w): + token2words[tid] = i + tid += 1 + word_start = token2words[start] + word_end = token2words[end] if end < len(tokens) else len(words) + text = ''.join(words[word_start:word_end]) + return text + + def add_special_token(self, token): + if token not in self.special_tokens: + self.special_tokens.append(token) + if token not in self.vocab: + self.vocab[token] = len(self.vocab) - 1 + self.ids_to_tokens.append(token) + return self.id(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + if (len(token) == 1 and (_is_whitespace(list(token)[0]))): + return False + if _is_control(list(token)[0]): + return False + if _is_punctuation(list(token)[0]): + return False + if token in self.add_special_token: + return False + + word_start = b'\xe2\x96\x81'.decode('utf-8') + return not token.startswith(word_start) + + def pad(self): + return '[PAD]' + + def bos(self): + return '[CLS]' + + def eos(self): + return '[SEP]' + + def unk(self): + return '[UNK]' + + def mask(self): + return '[MASK]' + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] if sym in self.vocab else 1 + + def _encode_as_pieces(self, text): + text = convert_to_unicode(text) + if self.split_by_punct: + words = self._run_split_on_punc(text) + pieces = [self.spm.encode(w, out_type=str) for w in words] + return [p for w in pieces for p in w] + else: + return self.spm.encode(text, out_type=str) + + def split_to_words(self, text): + pieces = self._encode_as_pieces(text) + word_start = b'\xe2\x96\x81'.decode('utf-8') + words = [] + offset = 0 + prev_end = 0 + for i, p in enumerate(pieces): + if p.startswith(word_start): + if offset > prev_end: + words.append(text[prev_end:offset]) + prev_end = offset + w = p.replace(word_start, '') + else: + w = p + try: + s = text.index(w, offset) + pn = '' + k = i + 1 + while k < len(pieces): + pn = pieces[k].replace(word_start, '') + if len(pn) > 0: + break + k += 1 + + if len(pn) > 0 and pn in text[offset:s]: + offset = offset + 1 + else: + offset = s + len(w) + except Exception: + offset = offset + 1 + + if prev_end < offset: + words.append(text[prev_end:offset]) + + return words + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def save_pretrained(self, path: str, filename_prefix: str = None): + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + '-' + filename + full_path = os.path.join(path, filename) + with open(full_path, 'wb') as fs: + fs.write(self.spm.serialized_model_proto()) + return (full_path, ) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or ( + cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError(f'Unsupported string type: {type(text)}') diff --git a/modelscope/models/nlp/deberta_v2/tokenization_fast.py b/modelscope/models/nlp/deberta_v2/tokenization_fast.py new file mode 100644 index 00000000..913ea5bd --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/tokenization_fast.py @@ -0,0 +1,241 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization class for model DeBERTa.""" + +import os +from shutil import copyfile +from typing import Optional, Tuple + +from transformers.file_utils import is_sentencepiece_available +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modelscope.utils import logger as logging + +if is_sentencepiece_available(): + from .tokenization import DebertaV2Tokenizer +else: + DebertaV2Tokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': 'spm.model', + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + +PRETRAINED_INIT_CONFIGURATION = {} + + +class DebertaV2TokenizerFast(PreTrainedTokenizerFast): + r""" + Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) + and [rjieba-py](https://github.com/messense/rjieba-py). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = DebertaV2Tokenizer + + def __init__(self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=False, + split_by_punct=False, + split_chinese=True, + bos_token='[CLS]', + eos_token='[SEP]', + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + **kwargs) -> None: + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + split_chinese=split_chinese, + **kwargs, + ) + + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.split_chinese = split_chinese + self.vocab_file = vocab_file + self.can_save_slow_tokenizer = False if not self.vocab_file else True + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' + 'tokenizer.') + + if not os.path.isdir(save_directory): + logger.error( + f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file, ) diff --git a/modelscope/models/nlp/gpt3/__init__.py b/modelscope/models/nlp/gpt3/__init__.py new file mode 100644 index 00000000..051cc8f2 --- /dev/null +++ b/modelscope/models/nlp/gpt3/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import GPT3Config + from .backbone import GPT3Model + from .text_generation import GPT3ForTextGeneration + from .tokenizer import JiebaBPETokenizer +else: + _import_structure = { + 'configuration': ['GPT3Config'], + 'backbone': ['GPT3Model'], + 'text_generation': ['GPT3ForTextGeneration'], + 'tokenizer': ['JiebaBPETokenizer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/gpt3/backbone.py b/modelscope/models/nlp/gpt3/backbone.py new file mode 100644 index 00000000..4647428e --- /dev/null +++ b/modelscope/models/nlp/gpt3/backbone.py @@ -0,0 +1,355 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import Optional, Union + +import addict +import torch +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.utils.constant import ModelFile +from .configuration import GPT3Config + + +class GPT3SelfAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + # Per attention head + self.hidden_size_per_attention_head = \ + self.hidden_size // self.num_attention_heads + + self.query_key_value = nn.Linear(self.hidden_size, + 3 * self.hidden_size) + self.softmax = nn.Softmax(dim=-1) + self.attention_dropout = nn.Dropout( + config.attention_probs_dropout_prob) + + # Output. + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + ( + self.num_attention_heads, self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def _split_tensor_along_last_dim(self, + tensor, + num_partitions, + contiguous_split_chunks=False): + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward(self, hidden_states, ltor_mask, is_infer=False): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + tgt_len = hidden_states.size(1) + ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len]) + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \ + self._split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + previous_type = value_layer.type() + + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + # Apply the left to right attention mask. + if is_infer: + src_len = key_layer.size(2) + ltor_mask = torch.tril( + torch.ones((1, tgt_len, src_len), + device=hidden_states.device)).view( + 1, 1, tgt_len, src_len).type(previous_type) + converted_mask = 10000.0 * (1.0 - ltor_mask) + attention_scores = (torch.mul(attention_scores, ltor_mask) + - converted_mask).type(previous_type) + + # Attention probabilities. [b, np, s, s] + attention_probs = self.softmax(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size, ) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +class GPT3MLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super().__init__() + + hidden_size = config.hidden_size + # Project to 4h. + self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) + self.activation_func = F.gelu + # Project back to h. + self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + output = self.dropout(output) + return output + + +class GPT3TransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config): + super().__init__() + + # Layernorm on the input data. + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + # Self attention. + self.attention = GPT3SelfAttention(config) + + # Layernorm on the attention output + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GPT3MLP(config) + + def forward(self, hidden_states, ltor_mask): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.attention(layernorm_output, ltor_mask) + # Residual connection. + layernorm_input = hidden_states + attention_output + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + # Second residual connection. + output = layernorm_input + mlp_output + + return output + + +class GPT3Transformer(nn.Module): + """Transformer class.""" + + def __init__(self, config): + super().__init__() + + self.input_tensor = None + + # Number of layers. + self.num_layers = config.num_hidden_layers + + self.layers = torch.nn.ModuleList( + [GPT3TransformerLayer(config) for _ in range(self.num_layers)]) + + # Final layer norm before output. + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask): + # hidden_states: [s, b, h] + + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer(hidden_states, attention_mask) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPT3TransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config): + super().__init__() + + # Embeddings. + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) + + # Transformer. + self.transformer = GPT3Transformer(config) + + def forward(self, input_ids, attention_mask, position_ids): + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + transformer_input = self.embedding_dropout(embeddings) + transformer_output = self.transformer(transformer_input, + attention_mask) + + logits = F.linear(transformer_output, self.word_embeddings.weight) + return logits + + +class GPT3Model(PreTrainedModel): + + config_class = GPT3Config + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def __init__(self, config): + super().__init__(config) + self.language_model = GPT3TransformerLanguageModel(config) + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + labels=None, + **kwargs): + seq_length = input_ids.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=input_ids.device)) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + logits = self.language_model(input_ids, attention_mask, position_ids) + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.config.vocab_size), labels.view(-1)) + return addict.Dict(loss=loss, logits=logits) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, + os.PathLike]]): + config = cls.config_class.from_pretrained( + pretrained_model_name_or_path) + model = cls(config) + state_dict_file = os.path.join(pretrained_model_name_or_path, + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(state_dict_file) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + state_dict = { + k.replace('model.language_model', 'language_model'): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict) + return model + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + return {'input_ids': input_ids} diff --git a/modelscope/models/nlp/gpt3/configuration.py b/modelscope/models/nlp/gpt3/configuration.py new file mode 100644 index 00000000..66e8b836 --- /dev/null +++ b/modelscope/models/nlp/gpt3/configuration.py @@ -0,0 +1,111 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GPT3Config(PretrainedConfig): + + model_type = 'gpt3' + + def __init__( + self, + vocab_size=25600, + hidden_size=768, + ffn_hidden_size=None, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=2, + layernorm_epsilon=1e-12, + bias_gelu_fusion=True, + fp32_residual_connection=False, + sequence_parallel=False, + fp16=False, + bf16=False, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=False, + kv_channels=None, + masked_softmax_fusion=True, + attention_dropout=0.1, + bias_dropout_fusion=True, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.1, + init_method_std=0.02, + # generate + eod_id=7, + tokens_to_generate=100, + top_k=0, + top_p=0.9, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = 4 * hidden_size \ + if ffn_hidden_size is None else ffn_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.layernorm_epsilon = layernorm_epsilon + self.bias_gelu_fusion = bias_gelu_fusion + self.fp32_residual_connection = fp32_residual_connection + self.sequence_parallel = sequence_parallel + self.fp16 = fp16 + self.bf16 = bf16 + assert not (fp16 and bf16) + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + if kv_channels is None: + assert hidden_size % num_attention_heads == 0 + self.kv_channels = hidden_size // num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.attention_dropout = attention_dropout + self.bias_dropout_fusion = bias_dropout_fusion + self.apply_residual_connection_post_layernorm = \ + apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.init_method_std = init_method_std + self.eod_id = eod_id + self.tokens_to_generate = tokens_to_generate + self.top_k = top_k + self.top_p = top_p + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + self.no_persist_layer_norm = \ + TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) + + @property + def params_dtype(self): + if self.fp16: + return torch.half + elif self.bf16: + return torch.bfloat16 + else: + return torch.float diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py new file mode 100644 index 00000000..a0091259 --- /dev/null +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -0,0 +1,1057 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron import mpu +from megatron.global_vars import get_global_memory_buffer, set_global_variables +from megatron.model import (AttnMaskType, Float16Module, LayerNorm, + bias_gelu_impl) +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.models import TorchModel +from modelscope.models.nlp.gpt3 import GPT3Config +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.nlp.load_checkpoint import pre_load +from modelscope.utils.torch_utils import set_random_seed_mpu + + +class GPT3ParallelMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + # Project to 4h. + self.dense_h_to_4h = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + config.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + + self.bias_gelu_fusion = config.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = mpu.RowParallelLinearV3( + config, + config.ffn_hidden_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h( + hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + +class GPT3Embedding(nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, self.hidden_size, init_method=self.init_method) + + # Position embedding (serial). + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + self.hidden_size) + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + self.fp32_residual_connection = config.fp32_residual_connection + self.sequence_parallel = config.sequence_parallel + # Embeddings dropout + self.embedding_dropout = nn.Dropout(config.hidden_dropout) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + with mpu.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + return embeddings + + +class NoopTransformerLayer(nn.Module): + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class GPT3CoreAttention(nn.Module): + + def __init__(self, + config, + layer_number, + attn_mask_type=AttnMaskType.padding): + super().__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, self.attn_mask_type, + config.masked_softmax_fusion, attention_mask_func, + self.attention_softmax_in_fp32, coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, 'mpu') + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class GPT3ParallelAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + super().__init__() + self.layer_number = max(1, layer_number) + self.params_dtype = config.params_dtype + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + # Strided linear layer. + self.query_key_value = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method) + + self.core_attention = GPT3CoreAttention(config, self.layer_number) + + # Output. + self.dense = mpu.RowParallelLinearV3( + config, + projection_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def _allocate_memory(self, inference_max_sequence_len, batch_size): + return torch.empty( + inference_max_sequence_len, + batch_size, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_len + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + + # ================================== + # Adjust key and value for inference + # ================================== + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, + batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, + batch_start:batch_end, ...] + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +class nullcontext: + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class GPT3ParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # Self attention. + self.self_attention = GPT3ParallelAttention(config, init_method, + output_layer_init_method, + layer_number) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # MLP + self.mlp = GPT3ParallelMLP(config, init_method, + output_layer_init_method) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 + and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params) + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func(mlp_output, + mlp_bias.expand_as(residual), + residual, self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = mpu.make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True) + + return output + + +class GPT3ParallelTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, + config, + init_method, + output_layer_init_method, + post_layer_norm=True, + pre_process=True, + post_process=True): + super().__init__() + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + + self.sequence_parallel = config.sequence_parallel + + # Number of layers. + self.num_layers = config.num_hidden_layers + + # Transformer layers. + def build_layer(layer_number): + return GPT3ParallelTransformerLayer(config, init_method, + output_layer_init_method, + layer_number) + + if self.num_layers == 0: + self.num_layers = 1 + self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = mpu.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + if self.sequence_parallel: + rng_context = mpu.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + with rng_context: + # Forward pass. + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPT3TransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method, output_layer_init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + self.encoder_hidden_state = None + + # Embeddings. + self.embedding = GPT3Embedding(config, self.init_method) + + # Transformer. + self.encoder = GPT3ParallelTransformer( + config, + self.init_method, + output_layer_init_method, + ) + + def forward(self, + enc_input_ids, + enc_position_ids, + enc_attn_mask, + inference_params=None, + enc_hidden_states=None): + + # Encoder embedding. + encoder_input = self.embedding(enc_input_ids, enc_position_ids) + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + inference_params=inference_params) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + return encoder_output + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPT3Model(PreTrainedModel): + + config_class = GPT3Config + + def __init__(self, config, parallel_output=False): + super().__init__(config) + + self.parallel_output = parallel_output + + self.language_model = GPT3TransformerLanguageModel( + config, init_method_normal(config.init_method_std), + scaled_init_method_normal(config.init_method_std, + config.num_hidden_layers)) + + def word_embeddings_weight(self): + return self.language_model.embedding.word_embeddings.weight + + @staticmethod + def build_attention_mask_and_position_ids(tokens): + seq_length = tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=tokens.device)) + attention_mask = (attention_mask < 0.5) + + position_ids = torch.arange( + seq_length, dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0).expand_as(tokens) + + return attention_mask, position_ids + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + inference_params=None, + **kwargs): + if attention_mask is None and position_ids is None: + attention_mask, position_ids = \ + self.build_attention_mask_and_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params) + + logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + lm_output, self.word_embeddings_weight(), None, False, True, + self.config.sequence_parallel) + # Gather if needed. + + output = logits_parallel + if not self.parallel_output: + output = mpu.gather_from_model_parallel_region(logits_parallel) + return output.transpose(0, 1).contiguous() + + +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. + filter_[..., 0] = 0 + + # Fill in the filtered part + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float('-Inf')) + + +def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): + """ Sample and generate a token. + Note: logits has the dimension [b, v] where b is the batch size + and v is the vocabulary size. + If vocab_size is provided, we will make sure the sample that is + generated is in [0, vocab-size). This will avoid out of vocabulary + generations due to padding. + """ + + # Check logits for consistency. + assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' + assert logits.type() == 'torch.cuda.FloatTensor', \ + 'input logits should be floats.' + + # Greedy is just simple argmax. + if top_k == 1: + assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' + samples = torch.argmax(logits, dim=-1) + + # Top-k or top-p sampling. + else: + # Clone so we do not modify the inputs, + logits = logits.clone() + # Apply temperature in place. + if temperature != 1.0: + logits.div_(temperature) + + if top_k > 1: + assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' + assert top_k <= logits.size(1), 'top-k is larger than logit size.' + if vocab_size: + assert top_k < vocab_size, 'top-k is larger than vocab size.' + modify_logits_for_top_k_filtering(logits, top_k) + + elif top_p > 0.0: + assert top_p <= 1.0, 'top-p should be in (0, 1].' + modify_logits_for_top_p_filtering(logits, top_p) + + # After filtering, we need to recalculate the distribution. + probs = logits.softmax(dim=-1) + samples = torch.multinomial(probs, num_samples=1).view(-1) + + # If vocab size is provided, make sure the samples are in + # in the range [0, vocab-size). + if vocab_size: + samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) + + return samples + + +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_len): + """Note that offsets are set to zero and we always set the + flag to allocate memory. After the first call, make sure to + set this flag to False.""" + self.max_sequence_len = max_sequence_len + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + + def swap_key_value_dict(self, batch_idx): + 'swap between batches' + if len(self.key_value_memory_dict) == 0: + raise ValueError('should not swap when dict in empty') + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[ + layer_number] + assert len(batch_idx) == inference_key_memory.shape[ + 1] # make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, new_inference_value_memory) + + +class DistributedGPT3(TorchModel): + + def __init__(self, + model_dir, + rank, + path_load_tag='model', + *args, + **kwargs): + super().__init__(model_dir, *args, **kwargs) + initialize_distributed(rank, mpu, kwargs['world_size'], + kwargs['model_parallel_size'], + kwargs['master_ip'], kwargs['master_port']) + seed = 0 if 'seed' not in kwargs else kwargs['seed'] + set_random_seed_mpu(seed) + set_global_variables() + + self.config = GPT3Config.from_pretrained(model_dir) + # Build model. + model = GPT3Model(self.config) + + for param in model.parameters(): + mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if self.config.fp16 or self.config.bf16: + model = Float16Module(model, self.config) + + self.dist_model = model + load_model = pre_load(mpu, model_dir, tag=path_load_tag) + self.dist_model.load_state_dict(load_model) + + self.inference_params = None + + def forward_step(self, tokens, attention_mask, position_ids): + logits = self.dist_model( + tokens, + attention_mask, + position_ids, + inference_params=self.inference_params) + self.inference_params.sequence_len_offset += tokens.size(1) + return logits + + def generate(self, + tokens, + temperature=1.0, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False): + lengths = torch.tensor([tokens.size(1)], device=tokens.device) + pads = torch.ones( + 1, self.config.tokens_to_generate, + device=tokens.device).long() * self.config.eod_id + tokens = torch.cat((tokens, pads), dim=-1) + + batch_size = tokens.size(0) + min_prompt_length = lengths.min().item() + max_sequence_length = tokens.size(1) + max_sequence_length = min(max_sequence_length, + self.config.max_position_embeddings) + + # If the context is too big, this happens + if min_prompt_length >= max_sequence_length: + raise ValueError('context length + tokens_to_generate too large') + + # Initialize inference parameters. + self.inference_params = InferenceParams(batch_size, + max_sequence_length) + + # Added termination_id to support the case that we want to terminate the + # generation once that id is generated. + termination_id = self.config.eod_id + + # Whether we have reached a termination id. + is_generation_done = torch.zeros( + batch_size, dtype=torch.uint8, device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + + with torch.no_grad(): + attention_mask, position_ids = \ + GPT3Model.build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, + max_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length: + context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = self.forward_step(tokens2use, attention_mask2use, + positions2use) + + # Sample. + last_token_logits = logits[:, -1, :] + new_sample = sample( + last_token_logits, + top_k=self.config.top_k, + top_p=self.config.top_p, + temperature=temperature, + vocab_size=self.config.vocab_size) + + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] + + # Update the context length for the next token generation. + prev_context_length = context_length + + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & ( + tokens[:, context_length - 1] + == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + + if use_eod_token_for_early_termination and done: + break + + tokens = tokens[:, :(context_length + 1)] + return tokens diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py new file mode 100644 index 00000000..b8b705a5 --- /dev/null +++ b/modelscope/models/nlp/gpt3/text_generation.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['GPT3ForTextGeneration'] + + +@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3) +class GPT3ForTextGeneration(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + from modelscope.models.nlp.gpt3 import GPT3Model + from transformers import BertTokenizer + + self.model = GPT3Model.from_pretrained(model_dir) + self.tokenizer = BertTokenizer.from_pretrained(model_dir) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'logits': Tensor([[0.54, 0.32...])]), # logits + } + """ + return self.model(**input) + + def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + assert 'input_ids' in input, "generate function must accept 'input_ids' key" + input_ids = input['input_ids'] + if 'attention_mask' in input: + attention_mask = input['attention_mask'] + input_ids = input_ids[0][attention_mask[0].nonzero()] \ + .squeeze().unsqueeze(0) + # remove sep token at the end of tokenizer output + input_ids = input_ids[:, :-1] + + gen_params = dict() + gen_params['inputs'] = input_ids + gen_params['do_sample'] = input.pop('do_sample', True) + gen_params['max_length'] = input.pop('max_length', 128) + gen_params['top_k'] = input.pop('top_k', 10) + gen_params['top_p'] = input.pop('top_p', None) + sample_output = self.model.generate(**gen_params) + return {'sequences': sample_output[0]} diff --git a/modelscope/models/nlp/gpt3/tokenizer.py b/modelscope/models/nlp/gpt3/tokenizer.py new file mode 100644 index 00000000..5780ddbd --- /dev/null +++ b/modelscope/models/nlp/gpt3/tokenizer.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tokenizers import Tokenizer + + +class JiebaBPETokenizer: + """SentencePiece BPE tokenizer with Jieba integration""" + + def __init__(self, tokenizer_json_file): + self.name = 'Jieba BPE Tokenizer' + + self.tokenizer = Tokenizer.from_file(tokenizer_json_file) + self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') + try: + import jieba + except ImportError: + raise ImportError( + 'You need to install rjieba to use JiebaTokenizer. ' + 'See https://pypi.org/project/rjieba/ for installation.') + self.jieba = jieba + self.new_line = self.vocab['\n'] + self.sep_token = self.vocab[''] + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size(with_added_tokens=True) + + @property + def vocab(self): + return self.tokenizer.get_vocab(with_added_tokens=True) + + @property + def inv_vocab(self): + vocab = self.vocab + inv_vocab = dict() + for key, val in vocab.items(): + inv_vocab[val] = key + return inv_vocab + + def tokenize(self, text, is_code=False): + """ + """ + if not is_code: + seg_list = [x for x in self.jieba.cut(text)] + return self.tokenizer.encode( + seg_list, is_pretokenized=True, add_special_tokens=True).ids + else: + return self.tokenizer.encode( + text, is_pretokenized=False, add_special_tokens=True).ids + + def detokenize(self, token_ids): + text = self.tokenizer.decode(token_ids, skip_special_tokens=False) + return text + + @property + def eod(self): + return self.eod_id diff --git a/modelscope/models/nlp/gpt_neo/__init__.py b/modelscope/models/nlp/gpt_neo/__init__.py new file mode 100644 index 00000000..ef5fdee5 --- /dev/null +++ b/modelscope/models/nlp/gpt_neo/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import GPTNeoModel +else: + _import_structure = { + 'backbone': ['GPTNeoModel'], + } + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/gpt_neo/backbone.py b/modelscope/models/nlp/gpt_neo/backbone.py new file mode 100644 index 00000000..a809bcde --- /dev/null +++ b/modelscope/models/nlp/gpt_neo/backbone.py @@ -0,0 +1,16 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import GPTNeoConfig +from transformers import GPTNeoModel as GPTNeoModelTransform + +from modelscope.metainfo import Models +from modelscope.models.builder import BACKBONES +from modelscope.utils.constant import Tasks + + +@BACKBONES.register_module( + group_key=Tasks.backbone, module_name=Models.gpt_neo) +class GPTNeoModel(GPTNeoModelTransform): + + def __init__(self, **kwargs): + config = GPTNeoConfig(**kwargs) + super().__init__(config) diff --git a/modelscope/models/nlp/heads/__init__.py b/modelscope/models/nlp/heads/__init__.py new file mode 100644 index 00000000..19194d3a --- /dev/null +++ b/modelscope/models/nlp/heads/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sequence_classification_head import SequenceClassificationHead + from .torch_pretrain_head import BertMLMHead, RobertaMLMHead +else: + _import_structure = { + 'sequence_classification_head': ['SequenceClassificationHead'], + 'torch_pretrain_head': ['BertMLMHead', 'RobertaMLMHead'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/heads/fill_mask_head.py b/modelscope/models/nlp/heads/fill_mask_head.py new file mode 100644 index 00000000..6b0c5e05 --- /dev/null +++ b/modelscope/models/nlp/heads/fill_mask_head.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm) +class BertFillMaskHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cls = BertOnlyMLMHead(self.config) + + def forward(self, sequence_output): + prediction_scores = self.cls(sequence_output) + return {OutputKeys.LOGITS: prediction_scores} + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + outputs.view(-1, self.config.vocab_size), labels.view(-1)) + return {OutputKeys.LOSS: masked_lm_loss} + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores diff --git a/modelscope/models/nlp/heads/infromation_extraction_head.py b/modelscope/models/nlp/heads/infromation_extraction_head.py new file mode 100644 index 00000000..626f1b59 --- /dev/null +++ b/modelscope/models/nlp/heads/infromation_extraction_head.py @@ -0,0 +1,105 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.information_extraction, module_name=Heads.information_extraction) +@HEADS.register_module( + Tasks.relation_extraction, module_name=Heads.information_extraction) +class InformationExtractionHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + assert config.get('labels') is not None + self.labels = config.labels + self.s_layer = nn.Linear(config.hidden_size, 2) # head, tail, bce + self.o_layer = nn.Linear(2 * config.hidden_size, 2) # head, tail, bce + self.p_layer = nn.Linear(config.hidden_size, + len(self.labels)) # label, ce + self.mha = nn.MultiheadAttention(config.hidden_size, 4) + + def forward(self, sequence_output, text, offsets, threshold=0.5): + # assert batch size == 1 + spos = [] + s_head_logits, s_tail_logits = self.s_layer(sequence_output).split( + 1, dim=-1) # (b, seq_len, 2) + s_head_logits = s_head_logits[0, :, 0].sigmoid() # (seq_len) + s_tail_logits = s_tail_logits[0, :, 0].sigmoid() # (seq_len) + s_masks, subjects = self._get_masks_and_mentions( + text, offsets, s_head_logits, s_tail_logits, None, threshold) + for s_mask, subject in zip(s_masks, subjects): + masked_sequence_output = sequence_output * s_mask.unsqueeze( + 0).unsqueeze(-1) # (b, s, h) + subjected_sequence_output = self.mha( + sequence_output.permute(1, 0, 2), + masked_sequence_output.permute(1, 0, 2), + masked_sequence_output.permute(1, 0, + 2))[0].permute(1, 0, + 2) # (b, s, h) + cat_sequence_output = torch.cat( + (sequence_output, subjected_sequence_output), dim=-1) + o_head_logits, o_tail_logits = self.o_layer( + cat_sequence_output).split( + 1, dim=-1) + o_head_logits = o_head_logits[0, :, 0].sigmoid() # (seq_len) + o_tail_logits = o_tail_logits[0, :, 0].sigmoid() # (seq_len) + so_masks, objects = self._get_masks_and_mentions( + text, offsets, o_head_logits, o_tail_logits, s_mask, threshold) + for so_mask, object in zip(so_masks, objects): + masked_sequence_output = ( + sequence_output * so_mask.unsqueeze(0).unsqueeze(-1)).sum( + 1) # (b, h) + lengths = so_mask.unsqueeze(0).sum(-1, keepdim=True) # (b, 1) + pooled_subject_object = masked_sequence_output / lengths # (b, h) + label = self.p_layer(pooled_subject_object).sigmoid().squeeze( + 0) + for i in range(label.size(-1)): + if label[i] > threshold: + predicate = self.labels[i] + spos.append((subject, predicate, object)) + return spos + + def _get_masks_and_mentions(self, + text, + offsets, + heads, + tails, + init_mask=None, + threshold=0.5): + ''' + text: str + heads: tensor (len(heads)) + tails: tensor (len(tails)) + ''' + seq_len = heads.size(-1) + potential_heads = [] + for i in range(seq_len - 1): + if heads[i] > threshold: + potential_heads.append(i) + potential_heads.append(seq_len - 1) + masks = [] + mentions = [] + for i in range(len(potential_heads) - 1): + head_index = potential_heads[i] + tail_index, max_val = None, 0 + for j in range(head_index, potential_heads[i + 1]): + if tails[j] > max_val and tails[j] > threshold: + tail_index = j + max_val = tails[j] + if tail_index is not None: + mask = torch.zeros_like( + heads) if init_mask is None else init_mask.clone() + mask[head_index:tail_index + 1] = 1 + masks.append(mask) # (seq_len) + char_head = offsets[head_index][0] + char_tail = offsets[tail_index][1] + mention = text[char_head:char_tail] + mentions.append(mention) + return masks, mentions diff --git a/modelscope/models/nlp/heads/sequence_classification_head.py b/modelscope/models/nlp/heads/sequence_classification_head.py new file mode 100644 index 00000000..fb03b7ff --- /dev/null +++ b/modelscope/models/nlp/heads/sequence_classification_head.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.text_classification, module_name=Heads.text_classification) +class SequenceClassificationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.num_labels = config.num_labels + classifier_dropout = ( + config['classifier_dropout'] if config.get('classifier_dropout') + is not None else config['hidden_dropout_prob']) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config['hidden_size'], + config['num_labels']) + + def forward(self, inputs=None): + if isinstance(inputs, dict): + assert inputs.get('pooled_output') is not None + pooled_output = inputs.get('pooled_output') + else: + pooled_output = inputs + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return {OutputKeys.LOGITS: logits} + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + logits = outputs[OutputKeys.LOGITS] + return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} diff --git a/modelscope/models/nlp/heads/text_generation_head.py b/modelscope/models/nlp/heads/text_generation_head.py new file mode 100644 index 00000000..ecb02e22 --- /dev/null +++ b/modelscope/models/nlp/heads/text_generation_head.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.text_generation, module_name=Heads.text_generation) +class TextGenerationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.linear = nn.Linear( + config['hidden_size'], config['vocab_size'], bias=False) + + def get_output_embeddings(self): + return self.linear + + def forward(self, inputs=None): + logits = self.linear(inputs) + return logits + + def compute_loss(self, logits: torch.Tensor, + labels) -> Dict[str, torch.Tensor]: + return F.cross_entropy(logits, labels) diff --git a/modelscope/models/nlp/heads/token_classification_head.py b/modelscope/models/nlp/heads/token_classification_head.py new file mode 100644 index 00000000..443f93df --- /dev/null +++ b/modelscope/models/nlp/heads/token_classification_head.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.token_classification, module_name=Heads.token_classification) +@HEADS.register_module( + Tasks.part_of_speech, module_name=Heads.token_classification) +class TokenClassificationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.num_labels = config.num_labels + classifier_dropout = ( + config['classifier_dropout'] if config.get('classifier_dropout') + is not None else config['hidden_dropout_prob']) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config['hidden_size'], + config['num_labels']) + + def forward(self, inputs=None): + if isinstance(inputs, dict): + assert inputs.get('sequence_output') is not None + sequence_output = inputs.get('sequence_output') + else: + sequence_output = inputs + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + return logits + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + logits = outputs[OutputKeys.LOGITS] + return F.cross_entropy(logits, labels) diff --git a/modelscope/models/nlp/heads/torch_pretrain_head.py b/modelscope/models/nlp/heads/torch_pretrain_head.py new file mode 100644 index 00000000..e477533f --- /dev/null +++ b/modelscope/models/nlp/heads/torch_pretrain_head.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import torch +from transformers.models.bert.modeling_bert import BertOnlyMLMHead +from transformers.models.roberta.modeling_roberta import RobertaLMHead + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.utils.constant import Tasks + + +# @HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm) +class BertMLMHead(BertOnlyMLMHead, TorchHead): + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + +@HEADS.register_module(Tasks.fill_mask, module_name=Heads.roberta_mlm) +class RobertaMLMHead(RobertaLMHead, TorchHead): + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + raise NotImplementedError() diff --git a/modelscope/models/nlp/mglm/__init__.py b/modelscope/models/nlp/mglm/__init__.py new file mode 100644 index 00000000..26d1101b --- /dev/null +++ b/modelscope/models/nlp/mglm/__init__.py @@ -0,0 +1,22 @@ +# Modified by Zhipu.AI +# Original Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .mglm_for_text_summarization import mGlmForSummarization +else: + _import_structure = { + 'mglm_for_text_summarization': ['MGLMForTextSummarization'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/mglm/arguments.py b/modelscope/models/nlp/mglm/arguments.py new file mode 100755 index 00000000..13b3aeab --- /dev/null +++ b/modelscope/models/nlp/mglm/arguments.py @@ -0,0 +1,793 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""argparser configuration""" + +import argparse +import os + +import deepspeed +import json +import torch + +from .utils import get_hostname + + +def add_model_config_args(parser): + """Model arguments""" + + group = parser.add_argument_group('model', 'model configuration') + + group.add_argument( + '--transformer-xl', + action='store_true', + help='use transformer-xl for training') + group.add_argument( + '--pretrained-bert', + action='store_true', + help='use a pretrained bert-large-uncased model instead' + 'of initializing from scratch. See ' + '--tokenizer-model-type to specify which pretrained ' + 'BERT model to use') + group.add_argument( + '--encoder-decoder', + action='store_true', + help='use the encoder-decoder architecture for blocklm') + group.add_argument( + '--attention-dropout', + type=float, + default=0.1, + help='dropout probability for attention weights') + group.add_argument( + '--num-attention-heads', + type=int, + default=16, + help='num of transformer attention heads') + group.add_argument( + '--hidden-size', type=int, default=1024, help='tansformer hidden size') + group.add_argument( + '--intermediate-size', + type=int, + default=None, + help='transformer embedding dimension for FFN' + 'set to 4*`--hidden-size` if it is None') + group.add_argument( + '--num-layers', type=int, default=24, help='num decoder layers') + group.add_argument( + '--layernorm-epsilon', + type=float, + default=1e-5, + help='layer norm epsilon') + group.add_argument( + '--hidden-dropout', + type=float, + default=0.1, + help='dropout probability for hidden state transformer') + group.add_argument( + '--output-dropout', + type=float, + default=0.1, + help='dropout probability for pooled output') + group.add_argument( + '--max-position-embeddings', + type=int, + default=512, + help='maximum number of position embeddings to use') + group.add_argument( + '--vocab-size', + type=int, + default=250112, + help='vocab size to use for non-character-level ' + 'tokenization. This value will only be used when ' + 'creating a tokenizer') + group.add_argument( + '--deep-init', + action='store_true', + help='initialize bert model similar to gpt2 model.' + 'scales initialization of projection layers by a ' + 'factor of 1/sqrt(2N). Necessary to train bert ' + 'models larger than BERT-Large.') + group.add_argument( + '--make-vocab-size-divisible-by', + type=int, + default=128, + help='Pad the vocab size to be divisible by this value.' + 'This is added for computational efficieny reasons.') + group.add_argument( + '--cpu-optimizer', action='store_true', help='Run optimizer on CPU') + group.add_argument( + '--cpu_torch_adam', + action='store_true', + help='Use Torch Adam as optimizer on CPU.') + + return parser + + +def add_fp16_config_args(parser): + """Mixed precision arguments.""" + + group = parser.add_argument_group('fp16', 'fp16 configurations') + + group.add_argument( + '--fp16', action='store_true', help='Run model in fp16 mode') + group.add_argument( + '--fp32-embedding', action='store_true', help='embedding in fp32') + group.add_argument( + '--fp32-layernorm', action='store_true', help='layer norm in fp32') + group.add_argument( + '--fp32-tokentypes', + action='store_true', + help='embedding token types in fp32') + group.add_argument( + '--fp32-allreduce', action='store_true', help='all-reduce in fp32') + group.add_argument( + '--hysteresis', + type=int, + default=2, + help='hysteresis for dynamic loss scaling') + group.add_argument( + '--loss-scale', + type=float, + default=None, + help='Static loss scaling, positive power of 2 ' + 'values can improve fp16 convergence. If None, dynamic' + 'loss scaling is used.') + group.add_argument( + '--loss-scale-window', + type=float, + default=1000, + help='Window over which to raise/lower dynamic scale') + group.add_argument( + '--min-scale', + type=float, + default=1, + help='Minimum loss scale for dynamic loss scale') + group.add_argument('--attention-scale', type=float, default=1.0) + return parser + + +def add_training_args(parser): + """Training arguments.""" + + group = parser.add_argument_group('train', 'training configurations') + + group.add_argument( + '--experiment-name', + type=str, + default='gpt-345M', + help='The experiment name for summary and checkpoint') + group.add_argument( + '--batch-size', type=int, default=4, help='Data Loader batch size') + group.add_argument( + '--gradient-accumulation-steps', + type=int, + default=1, + help='Data Loader batch size') + group.add_argument( + '--weight-decay', + type=float, + default=0.01, + help='weight decay coefficient for L2 regularization') + group.add_argument( + '--checkpoint-activations', + action='store_true', + help='checkpoint activation to allow for training ' + 'with larger models and sequences') + group.add_argument( + '--checkpoint-num-layers', + type=int, + default=1, + help='chunk size (number of layers) for checkpointing') + group.add_argument( + '--deepspeed-activation-checkpointing', + action='store_true', + help='uses activation checkpointing from deepspeed') + group.add_argument( + '--epochs', + type=int, + default=None, + help='Number of finetunning epochs. Zero results in evaluation only.') + group.add_argument( + '--clip-grad', type=float, default=1.0, help='gradient clipping') + group.add_argument( + '--train-iters', + type=int, + default=0, + help='total number of iterations to train over all training runs') + group.add_argument('--label-smoothing', type=float, default=0.0) + group.add_argument( + '--log-interval', type=int, default=100, help='report interval') + group.add_argument( + '--summary-dir', + type=str, + default='', + help='The directory to store the summary') + group.add_argument('--seed', type=int, default=1234, help='random seed') + # Batch producer arguments + group.add_argument( + '--reset-position-ids', + action='store_true', + help='Reset posistion ids after end-of-document token.') + group.add_argument( + '--reset-attention-mask', + action='store_true', + help='Reset self attention maske after ' + 'end-of-document token.') + + # Learning rate. + group.add_argument( + '--lr-decay-iters', + type=int, + default=None, + help='number of iterations to decay LR over,' + ' If None defaults to `--train-iters`*`--epochs`') + group.add_argument( + '--lr-decay-style', + type=str, + default='linear', + choices=['constant', 'linear', 'cosine', 'exponential'], + help='learning rate decay function') + group.add_argument('--lr-decay-ratio', type=float, default=0.1) + group.add_argument( + '--lr', type=float, default=1.0e-4, help='initial learning rate') + group.add_argument( + '--warmup', + type=float, + default=0.01, + help='percentage of data to warmup on (.01 = 1% of all ' + 'training iters). Default 0.01') + group.add_argument( + '--switch-linear', + action='store_true', + help='Switch to linear decay for cosine decay') + # model checkpointing + group.add_argument( + '--save', + type=str, + default=None, + help='Output directory to save checkpoints to.') + group.add_argument('--new-save-directory', action='store_true') + group.add_argument( + '--save-epoch', + type=int, + default=1, + help='number of epochs between saves') + group.add_argument( + '--save-interval', + type=int, + default=5000, + help='number of iterations between saves') + group.add_argument( + '--no-save-optim', + action='store_true', + help='Do not save current optimizer.') + group.add_argument( + '--no-save-rng', + action='store_true', + help='Do not save current rng state.') + group.add_argument( + '--load', + type=str, + default=None, + help='Path to a directory containing a model checkpoint.') + group.add_argument( + '--no-load-optim', + action='store_true', + help='Do not load optimizer when loading checkpoint.') + group.add_argument( + '--no-load-rng', + action='store_true', + help='Do not load rng state when loading checkpoint.') + group.add_argument( + '--no-load-lr-scheduler', + action='store_true', + help='Do not load lr scheduler when loading checkpoint.') + group.add_argument( + '--no-deepspeed-load', + action='store_true', + help='Not use deepspeed when loading checkpoint') + group.add_argument( + '--finetune', + action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. ' + 'Assumed when loading a release checkpoint.') + group.add_argument( + '--resume-dataloader', + action='store_true', + help='Resume the dataloader when resuming training. ' + 'Does not apply to tfrecords dataloader, try resuming' + 'with a different seed in this case.') + # distributed training args + group.add_argument( + '--distributed-backend', + default='nccl', + help= + 'which backend to use for distributed training. One of [gloo, nccl]', + choices=['nccl', 'gloo']) + group.add_argument( + '--DDP-impl', + default='torch', + choices=['local', 'torch', 'none'], + help='which DistributedDataParallel implementation to use.') + + group.add_argument( + '--local_rank', + type=int, + default=None, + help='local rank passed from distributed launcher') + # BlockLM training args + group.add_argument( + '--block-lm', + action='store_true', + help='whether use the BlockLM pre-training') + group.add_argument( + '--masked-lm', + action='store_true', + help='whether to use the mlm objective') + group.add_argument('--bert-prob', type=float, default=0.5) + group.add_argument('--gpt-infill-prob', type=float, default=0.5) + group.add_argument('--gpt-min-ratio', type=float, default=0.5) + group.add_argument('--gap-sentence-prob', type=float, default=0.0) + group.add_argument('--gap-sentence-ratio', type=float, default=0.15) + group.add_argument('--avg-block-length', type=int, default=3) + group.add_argument('--short-seq-prob', type=float, default=0.0) + group.add_argument('--single-span-prob', type=float, default=0.0) + group.add_argument( + '--task-mask', + action='store_true', + help='Use different mask for generation and blank filling') + group.add_argument( + '--no-shuffle-block', + action='store_true', + help='not shuffle the blocks when filling the blank') + group.add_argument( + '--no-block-position', + action='store_true', + help='Use (rough) absolute positions instead of block positions') + group.add_argument( + '--sentinel-token', + action='store_true', + help='Use sentinel (mask) tokens to replace 2d position encoding') + group.add_argument('--block-mask-prob', type=float, default=0.0) + group.add_argument('--context-mask-ratio', type=float, default=0.0) + group.add_argument( + '--random-position', + action='store_true', + help='Use random start position to cover all the position embeddings') + return parser + + +def add_evaluation_args(parser): + """Evaluation arguments.""" + + group = parser.add_argument_group('validation', + 'validation configurations') + + group.add_argument( + '--eval-batch-size', + type=int, + default=None, + help='Data Loader batch size for evaluation datasets.' + 'Defaults to `--batch-size`') + group.add_argument( + '--eval-iters', + type=int, + default=100, + help='number of iterations to run for evaluation' + 'validation/test for') + group.add_argument( + '--eval-interval', + type=int, + default=1000, + help='interval between running evaluation on validation set') + group.add_argument( + '--eval-epoch', + type=int, + default=1, + help='epoch between running evaluation on validation set') + group.add_argument( + '--eval-seq-length', + type=int, + default=None, + help='Maximum sequence length to process for ' + 'evaluation. Defaults to `--seq-length`') + group.add_argument( + '--eval-max-preds-per-seq', + type=int, + default=None, + help='Maximum number of predictions to use for ' + 'evaluation. Defaults to ' + 'math.ceil(`--eval-seq-length`*.15/10)*10') + group.add_argument('--overlapping-eval', type=int, default=32) + + return parser + + +def add_text_generate_args(parser): + """Text generate arguments.""" + + group = parser.add_argument_group('Text generation', 'configurations') + group.add_argument('--temperature', type=float, default=1.0) + group.add_argument('--top_p', type=float, default=0.0) + group.add_argument('--top_k', type=int, default=0) + group.add_argument('--out-seq-length', type=int, default=256) + group.add_argument('--num-beams', type=int, default=1) + group.add_argument('--length-penalty', type=float, default=0.0) + group.add_argument('--no-repeat-ngram-size', type=int, default=0) + group.add_argument('--min-tgt-length', type=int, default=0) + group.add_argument('--select-topk', action='store_true') + group.add_argument('--blank-maskratio', type=float, default=0.1) + return parser + + +def add_data_args(parser): + """Train/valid/test data arguments.""" + + group = parser.add_argument_group('data', 'data configurations') + + group.add_argument( + '--model-parallel-size', + type=int, + default=1, + help='size of the model parallel.') + group.add_argument( + '--shuffle', + action='store_true', + help='Shuffle data. Shuffling is deterministic ' + 'based on seed and current epoch.') + group.add_argument('--filter-english', action='store_true') + group.add_argument( + '--train-data', + nargs='+', + default=None, + help='Whitespace separated filenames or corpora names ' + 'for training.') + group.add_argument( + '--valid-data', + nargs='*', + default=None, + help="""Filename for validation data.""") + group.add_argument( + '--test-data', + nargs='*', + default=None, + help="""Filename for testing""") + group.add_argument( + '--data-dir', + type=str, + default=None, + help='The data path to all the data files') + group.add_argument( + '--input-data-sizes-file', + type=str, + default='sizes.txt', + help='the filename containing all the shards sizes') + + group.add_argument( + '--delim', default=',', help='delimiter used to parse csv data files') + group.add_argument( + '--text-key', + default='sentence', + help='key to use to extract text from json/csv') + group.add_argument( + '--eval-text-key', + default=None, + help='key to use to extract text from ' + 'json/csv evaluation datasets') + group.add_argument( + '--split', + default='1000,1,1', + help='comma-separated list of proportions for training,' + ' validation, and test split') + + group.add_argument( + '--no-lazy-loader', + action='store_true', + help='whether to lazy read the data set') + group.add_argument('--half-lazy-loader', action='store_true') + group.add_argument( + '--loader-scatter', + type=int, + default=None, + help='Number of scatters to use for dataloaders') + group.add_argument( + '--loose-json', + action='store_true', + help='Use loose json (one json-formatted string per ' + 'newline), instead of tight json (data file is one ' + 'json string)') + group.add_argument( + '--presplit-sentences', + action='store_true', + help='Dataset content consists of documents where ' + 'each document consists of newline separated sentences') + group.add_argument( + '--num-workers', + type=int, + default=2, + help="""Number of workers to use for dataloading""") + group.add_argument( + '--tokenizer-model-type', + type=str, + default=None, + help="Model type to use for sentencepiece tokenization \ + (one of ['bpe', 'char', 'unigram', 'word']) or \ + bert vocab to use for BertWordPieceTokenizer (one of \ + ['bert-large-uncased', 'bert-large-cased', etc.])") + group.add_argument( + '--tokenizer-path', + type=str, + default='tokenizer.model', + help='path used to save/load sentencepiece tokenization ' + 'models') + group.add_argument( + '--tokenizer-type', + type=str, + default='BertWordPieceTokenizer', + choices=[ + 'CharacterLevelTokenizer', 'SentencePieceTokenizer', + 'BertWordPieceTokenizer', 'GPT2BPETokenizer', 'ChineseSPTokenizer' + ], + help='what type of tokenizer to use') + group.add_argument('--no-pre-tokenize', action='store_true') + group.add_argument( + '--cache-dir', + default=None, + type=str, + help='Where to store pre-trained BERT downloads') + group.add_argument( + '--use-tfrecords', + action='store_true', + help='load `--train-data`, `--valid-data`, ' + '`--test-data` from BERT tf records instead of ' + 'normal data pipeline') + group.add_argument( + '--seq-length', + type=int, + default=512, + help='Maximum sequence length to process') + group.add_argument( + '--mem-length', + type=int, + default=0, + help='The memory length to preserve') + group.add_argument( + '--max-preds-per-seq', + type=int, + default=None, + help='Maximum number of predictions to use per sequence.' + 'Defaults to math.ceil(`--seq-length`*.15/10)*10.' + 'MUST BE SPECIFIED IF `--use-tfrecords` is True.') + group.add_argument('--non-sentence-start', type=float, default=0.0) + group.add_argument( + '--sample-one-document', + action='store_true', + help='only sample one document in one sample') + group.add_argument( + '--load-splits', + type=str, + default=None, + help='The path to load split indices from') + group.add_argument( + '--save-splits', + type=str, + default=None, + help='The path to save split indices to') + group.add_argument( + '--save-test-data', + type=str, + default=None, + help='The path to save the test data') + group.add_argument( + '--multi-task-data', + nargs='*', + default=None, + help='Downsteam task names for multi-task pre-training') + group.add_argument( + '--multi-task-ratio', + type=float, + default=0.0, + help='Ratio for multi-task pre-training') + group.add_argument('--multi-seq-length', type=int, default=None) + group.add_argument('--multi-batch-size', type=int, default=None) + return parser + + +def add_finetune_config_args(parser): + group = parser.add_argument_group('finetune', 'finetune configurations') + group.add_argument('--task', type=str, help='Task name.') + group.add_argument( + '--load-pretrained', + type=str, + help='Load pretrained model', + default=None) + group.add_argument( + '--pool-token', + type=str, + choices=['start', 'pad', 'cls'], + help='The token to pool the sequence representation', + default='cls') + group.add_argument( + '--cloze-eval', + action='store_true', + help='Evaluation dataset with cloze task') + group.add_argument( + '--multi-token', + action='store_true', + help='Use multi token for cloze evaluation') + group.add_argument( + '--segment-length', + type=int, + default=0, + help='The maximum segment length for cloze evaluation') + group.add_argument( + '--loss-func', + type=str, + choices=['cross_entropy', 'hinge', 'generative', 'mix'], + default='cross_entropy') + group.add_argument('--block-lm-ratio', type=float, default=0.0) + group.add_argument( + '--adapet', + action='store_true', + help='Use the decoupled cross entropy loss in AdaPET') + group.add_argument('--pattern-id', type=int, default=0) + group.add_argument( + '--fast-decode', + action='store_true', + help= + 'Fast decode for multi-token cloze. Can only be used without checkpoint activation.' + ) + group.add_argument('--few-superglue', action='store_true') + group.add_argument( + '--eval-valid', + action='store_true', + help='Whether evaluate on the valid set') + group.add_argument('--validation-metric', type=str, default=None) + group.add_argument( + '--unidirectional', + action='store_true', + help='Use the left to right language model') + group.add_argument('--src-seq-length', type=int, default=None) + group.add_argument('--tgt-seq-length', type=int, default=None) + group.add_argument('--adam-beta1', type=float, default=0.9) + group.add_argument('--adam-beta2', type=float, default=0.999) + group.add_argument('--adam-eps', type=float, default=1e-8) + group.add_argument( + '--optimizer', type=str, choices=['adam', 'adafactor'], default='adam') + group.add_argument('--wsc-negative', action='store_true') + group.add_argument('--overwrite', action='store_true') + group.add_argument('--no-validation', action='store_true') + # Continuous prompt arguments + group.add_argument( + '--continuous-prompt', + action='store_true', + help='Use continuous prompt for PET') + group.add_argument('--num-prompt-tokens', type=int, default=0) + group.add_argument( + '--prompt-func', default='lstm', choices=['lstm', 'mlp', 'none']) + group.add_argument( + '--freeze-transformer', action='store_true', default=False) + group.add_argument('--tune-prefix-layers', type=int, default=None) + group.add_argument('--prefix-prompt', type=int, default=0) + group.add_argument('--prompt-init', action='store_true', default=False) + return parser + + +def get_args(): + """Parse all the args.""" + + parser = argparse.ArgumentParser(description='PyTorch BERT Model') + parser = add_model_config_args(parser) + parser = add_fp16_config_args(parser) + parser = add_training_args(parser) + parser = add_evaluation_args(parser) + parser = add_text_generate_args(parser) + parser = add_data_args(parser) + parser = add_finetune_config_args(parser) + + # Include DeepSpeed configuration arguments + parser = deepspeed.add_config_arguments(parser) + + args = parser.parse_args(args=[]) + if not args.train_data and not args.data_dir: + print('WARNING: No training data specified') + + args.cuda = torch.cuda.is_available() + + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv('WORLD_SIZE', '1')) + if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: + mpi_define_env(args) + elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): + # We are using (OpenMPI) mpirun for launching distributed data parallel processes + local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) + local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) + + # Possibly running with Slurm + num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) + nodeid = int(os.getenv('SLURM_NODEID', '0')) + + args.local_rank = local_rank + args.rank = nodeid * local_size + local_rank + args.world_size = num_nodes * local_size + + args.model_parallel_size = min(args.model_parallel_size, args.world_size) + if args.rank == 0: + print('using world size: {} and model-parallel size: {} '.format( + args.world_size, args.model_parallel_size)) + + args.dynamic_loss_scale = False + if args.loss_scale is None: + args.dynamic_loss_scale = True + if args.rank == 0: + print(' > using dynamic loss scaling') + + # The args fp32_* or fp16_* meant to be active when the + # args fp16 is set. So the default behaviour should all + # be false. + if not args.fp16: + args.fp32_embedding = False + args.fp32_tokentypes = False + args.fp32_layernorm = False + + if hasattr(args, 'deepspeed' + ) and args.deepspeed and args.deepspeed_config is not None: + with open(args.deepspeed_config) as file: + deepspeed_config = json.load(file) + if 'train_micro_batch_size_per_gpu' in deepspeed_config: + args.batch_size = deepspeed_config[ + 'train_micro_batch_size_per_gpu'] + if 'gradient_accumulation_steps' in deepspeed_config: + args.gradient_accumulation_steps = deepspeed_config[ + 'gradient_accumulation_steps'] + else: + args.gradient_accumulation_steps = 1 + if 'optimizer' in deepspeed_config: + optimizer_params_config = deepspeed_config['optimizer'].get( + 'params', {}) + args.lr = optimizer_params_config.get('lr', args.lr) + args.weight_decay = optimizer_params_config.get( + 'weight_decay', args.weight_decay) + return args + + +def mpi_define_env(args): + from mpi4py import MPI + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + master_addr = None + if rank == 0: + master_addr = get_hostname() + master_addr = comm.bcast(master_addr, root=0) + + # Determine local rank by assuming hostnames are unique + proc_name = MPI.Get_processor_name() + all_procs = comm.allgather(proc_name) + local_rank = sum([i == proc_name for i in all_procs[:rank]]) + + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + args.local_rank = local_rank + args.world_size = world_size + args.rank = rank + os.environ['MASTER_ADDR'] = master_addr + os.environ[ + 'MASTER_PORT'] = '29500' # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 + + print( + 'Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}' + .format(os.environ['RANK'], args.local_rank, os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])) diff --git a/modelscope/models/nlp/mglm/blocklm_utils.py b/modelscope/models/nlp/mglm/blocklm_utils.py new file mode 100644 index 00000000..9af83f67 --- /dev/null +++ b/modelscope/models/nlp/mglm/blocklm_utils.py @@ -0,0 +1,625 @@ +# Copyright (c) 2022 Zhipu.AI + +import copy +import math +import random + +import numpy as np +import torch +import torch.utils.data +from scipy.stats import poisson + +from . import mpu +from .utils import print_rank_0 + + +def rindex(lst, val, start=None): + if start is None: + start = len(lst) - 1 + for i in range(start, -1, -1): + if lst[i] == val: + return i + return -1 + + +def index_in_list(lst, val, start=None): + if start is None: + start = 0 + for i in range(start, len(lst)): + if lst[i] == val: + return i + return -1 + + +class ConstructBlockStrategy: + + def __init__(self, + args, + tokenizer, + max_seq_length, + bert_prob=1.0, + gap_sentence_prob=0.0, + gpt_infill_prob=0.5, + gpt_min_ratio=0.5, + bert_ratio=0.15, + gap_sentence_ratio=0.15, + average_block_length=3, + max_block_length=40, + block_mask_prob=0.0, + context_mask_ratio=0.0, + context_mask_range=3, + short_seq_prob=0.0, + single_span_prob=0.0, + block_position_encoding=True, + encoder_decoder=False, + shuffle_blocks=True, + sentinel_token=False, + task_mask=False, + random_position=False, + masked_lm=False): + self.eod_token = args.eod_token + self.tokenizer = tokenizer + self.count = 0 + self.max_seq_length = max_seq_length + self.rank = mpu.get_data_parallel_rank() + self.world_size = mpu.get_data_parallel_world_size() + # self.rank = 0 + # self.world_size = 1 + assert 0.0 <= bert_prob <= 1.0 + self.bert_prob = bert_prob + self.gap_sentence_prob = gap_sentence_prob + self.gpt_prob = 1 - bert_prob - gap_sentence_prob + assert self.gpt_prob >= -1e-10 + self.infill_prob = gpt_infill_prob + self.gpt_min_ratio = gpt_min_ratio + self.bert_ratio = bert_ratio + self.gap_sentence_ratio = gap_sentence_ratio + self.block_length_distribution = [ + poisson.pmf(i, average_block_length) + for i in range(1, max_block_length) + ] + self.block_mask_prob = block_mask_prob + self.context_mask_ratio = context_mask_ratio + self.context_mask_range = context_mask_range + self.short_seq_prob = short_seq_prob + self.single_span_prob = single_span_prob + self.block_position_encoding = block_position_encoding + self.encoder_decoder = encoder_decoder + self.shuffle_blocks = shuffle_blocks + self.sentinel_token = sentinel_token + self.generation_mask = 'gMASK' if task_mask else 'MASK' + self.generation_mask = self.tokenizer.get_command( + self.generation_mask).Id + self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK' + self.gap_sentence_mask = self.tokenizer.get_command( + self.gap_sentence_mask).Id + self.random_position = random_position + self.masked_lm = masked_lm + print_rank_0( + f'BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}' # noqa + ) + print_rank_0( + f'generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}' # noqa + ) + print_rank_0( + f'block length distribution {self.block_length_distribution}') + print_rank_0( + f'block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}' + ) + + def contains_sentence_end(self, tok): + tok = self.tokenizer.IdToToken(tok) + if '.' in tok: + return True + if '?' in tok: + return True + if '!' in tok: + return True + if ';' in tok: + return True + if ':' in tok: + return True + if '。' in tok: + return True + if '?' in tok: + return True + if '!' in tok: + return True + if ';' in tok: + return True + if '…' in tok: + return True + if '\n' in tok: + return True + return False + + @staticmethod + def sample_spans(span_lengths, total_length, rng, offset=0): + blank_length = total_length - sum(span_lengths) + m = blank_length - len(span_lengths) + 1 + places = [rng.randrange(m + 1) for _ in range(len(span_lengths))] + places.sort() + spans = [] + for place, span_length in zip(places, span_lengths): + start = offset + place + end = offset + place + span_length + spans.append((start, end)) + offset += span_length + 1 + return spans + + def sample_span_in_document(self, tokens, masked_lengths, rng): + rng.shuffle(masked_lengths) + mask_spans = [] + mask_index = 0 + indices = [-1] + np.where(tokens == self.eod_token)[0].tolist() + last_index = len(tokens) + documents = [] + for index in reversed(indices): + start_index = index + if start_index + 1 < len(tokens) and tokens[ + start_index + 1] == self.tokenizer.get_command('ENC').Id: + start_index += 1 + length = last_index - start_index - 1 + if last_index == len(tokens) and length > 0: + length -= 1 + documents.append((start_index + 1, length)) + last_index = index + documents.sort(key=lambda x: x[1]) + for i, (offset, length) in enumerate(documents): + if i == len(documents) - 1: + current_masked_length, current_count = 0, 0 + while mask_index + current_count < len( + masked_lengths + ) and masked_lengths[ + mask_index + # noqa + current_count] + current_masked_length + current_count <= length: + current_masked_length += masked_lengths[mask_index + + current_count] + current_count += 1 + if current_count > 0: + spans = self.sample_spans( + masked_lengths[mask_index:mask_index + current_count], + length, + rng, + offset=offset) + mask_spans += spans + if mask_index + current_count < len(masked_lengths) - 1: + print(length, masked_lengths[mask_index:], + masked_lengths[:mask_index], indices) + else: + current_masked_total = int(length * self.bert_ratio) + current_masked_length, current_count = 0, 0 + while mask_index + current_count < len( + masked_lengths + ) and masked_lengths[ + mask_index + # noqa + current_count] + current_masked_length <= current_masked_total: + current_masked_length += masked_lengths[mask_index + + current_count] + current_count += 1 + if current_count > 0: + spans = self.sample_spans( + masked_lengths[mask_index:mask_index + current_count], + length, + rng, + offset=offset) + mask_spans += spans + mask_index += current_count + return mask_spans + + def make_masked_data(self, + tokens, + loss_masks, + attention_mask, + block_spans, + rng, + task='bert'): + position_ids = np.arange(len(tokens), dtype=np.long) + targets = copy.deepcopy(tokens) + mask_id = self.tokenizer.get_command('MASK').Id + mlm_masks = np.zeros(len(tokens), dtype=np.long) + for start, end in block_spans: + for idx in range(start, end): + tokens[idx] = mask_id + mlm_masks[start:end] = 1 + loss_masks = loss_masks * mlm_masks + return tokens, targets, loss_masks, position_ids + + def make_block_data(self, + tokens, + loss_masks, + attention_mask, + block_spans, + rng, + task='bert'): + text_length = len(tokens) + position_ids = np.ones(len(tokens), dtype=np.long) + for start, end in block_spans: + position_ids[start + 1:end] = 0 + position_ids = np.cumsum(position_ids) - 1 + if self.random_position and position_ids[-1] < self.max_seq_length - 1: + position_bias = self.max_seq_length - position_ids[-1] + position_bias = rng.randrange(0, position_bias) + position_ids = position_ids + position_bias + if self.encoder_decoder or not self.shuffle_blocks: + block_spans.sort(key=lambda x: x[0]) + else: + rng.shuffle(block_spans) + if self.sentinel_token: + block_spans = [(start, end, idx) + for idx, (start, end) in enumerate(block_spans)] + else: + block_spans = [(start, end, 0) for start, end in block_spans] + target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], [] + for start, end, idx in block_spans: + sop_token = 'sop' if idx == 0 else f'sop{idx}' + target_tokens.append([self.tokenizer.get_command(sop_token).Id]) + span_tokens = copy.deepcopy(tokens[start:end]) + if self.block_mask_prob > 0.0 and task == 'bert': + for sub_idx in range(len(span_tokens)): + if random.random() < self.block_mask_prob: + span_tokens[sub_idx] = self.tokenizer.get_command( + 'dBLOCK').Id + target_tokens.append(span_tokens) + targets.append(tokens[start:end]) + targets.append([self.tokenizer.get_command('eop').Id]) + if not self.sentinel_token: + target_position_id = position_ids[start:end] + target_position_ids.append(target_position_id) + target_position_ids.append([target_position_id[0]]) + else: + target_position_ids.append([self.max_seq_length] * # noqa + (end - start + 1)) + if self.block_position_encoding: + target_block_position_ids.append( + np.arange(1, end - start + 2, dtype=np.long)) + else: + target_block_position_ids.append([1] * (end - start + 1)) + block_spans.sort(key=lambda x: x[0]) + source_tokens, source_position_ids, local_spans = [], [], [] + last, current_length = 0, 0 + for start, end, idx in block_spans: + if task == 'generation': + mask_id = self.generation_mask + elif task == 'gap_sentence': + mask_id = self.gap_sentence_mask + else: + mask_token = 'MASK' if idx == 0 else f'MASK{idx}' + mask_id = self.tokenizer.get_command(mask_token).Id + local_spans.append((current_length, current_length + start - last)) + source_tokens.append(tokens[last:start]) + source_tokens.append([mask_id]) + source_position_ids.append(position_ids[last:start]) + source_position_ids.append([position_ids[start]]) + current_length += start - last + 1 + last = end + if last < len(tokens): + local_spans.append( + (current_length, current_length + len(tokens) - last)) + source_tokens.append(tokens[last:]) + source_position_ids.append(position_ids[last:]) + source_length = sum(map(len, source_tokens)) + if attention_mask is not None: + assert source_length == attention_mask + if target_tokens and self.eod_token in np.concatenate( + target_tokens).tolist(): + print('Found EOS in target', self.tokenizer.DecodeIds(tokens)) + raise RuntimeError + if self.encoder_decoder: + target_tokens = target_tokens + [ + self.tokenizer.get_command('eop').Id + ] + loss_masks = np.ones(len(target_tokens), dtype=np.long) + return source_tokens, target_tokens, loss_masks + else: + tokens = np.concatenate(source_tokens + target_tokens) + if task == 'bert' and self.context_mask_ratio > 0: + mask_candidates = set() + for start, end in local_spans: + if start != 0: + local_end = min(end, start + self.context_mask_range) + mask_candidates.update(range(start, local_end)) + if end != 0: + local_start = max(start, end - self.context_mask_range) + mask_candidates.update(range(local_start, end)) + mask_pos = rng.sample( + mask_candidates, + int(self.context_mask_ratio * text_length)) + for pos in mask_pos: + tokens[pos] = self.tokenizer.get_command('dBLOCK').Id + targets = np.concatenate(source_tokens + targets) + loss_masks = np.ones(len(tokens), dtype=np.long) + loss_masks[:source_length] = 0 + position_ids = np.concatenate(source_position_ids + + target_position_ids) + block_position_ids = np.concatenate( + [np.zeros(source_length, dtype=np.long)] + + target_block_position_ids) + position_ids = np.stack([position_ids, block_position_ids], axis=0) + if attention_mask is not None: + return tokens, targets, loss_masks, position_ids + else: + return tokens, targets, loss_masks, position_ids, source_length + + def generate_blank_data(self, + sample, + masked_lengths, + attention_mask, + rng, + task='bert'): + rng.shuffle(masked_lengths) + tokens, loss_masks = sample['text'], sample['loss_mask'] + assert tokens[0] == self.tokenizer.get_command('ENC').Id + block_spans = self.sample_span_in_document(tokens, masked_lengths, rng) + if len(block_spans) < len(masked_lengths): + return None + if self.masked_lm: + data = self.make_masked_data(tokens, loss_masks, attention_mask, + block_spans, rng) + else: + data = self.make_block_data( + tokens, + loss_masks, + attention_mask, + block_spans, + rng, + task=task) + return data + + def split_samples(self, samples, rng): + target_length = rng.randrange(32, self.max_seq_length - 1) + num_splits = (self.max_seq_length - 1) // target_length + new_samples = [] + cls_id = self.tokenizer.get_command('ENC').Id + eos_id = self.tokenizer.get_command('eos').Id + for sample in samples: + tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:] + for _ in range(num_splits): + if target_length >= len(tokens): + new_tokens, new_loss_masks = tokens, loss_masks + else: + random_start = rng.randrange(0, + len(tokens) - target_length) + while random_start > 0 and ( + tokens[random_start] == eos_id or # noqa + not (self.contains_sentence_end( # noqa + tokens[random_start - 1]) or # noqa + tokens[random_start - 1] == eos_id)): # noqa + random_start -= 1 + random_end = random_start + target_length + while random_end > random_start and not ( + self.contains_sentence_end(tokens[random_end - 1]) + or tokens[random_end - 1] == eos_id): + random_end -= 1 + if random_end - random_start < target_length // 2: + random_end = random_start + target_length + new_tokens, new_loss_masks = tokens[ + random_start:random_end], loss_masks[ + random_start:random_end] + new_tokens = np.concatenate(([cls_id], new_tokens)) + new_loss_masks = np.concatenate(([0], new_loss_masks)) + new_samples.append({ + 'text': new_tokens, + 'loss_mask': new_loss_masks + }) + return new_samples + + def construct_blocks(self, samples): + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + worker_id, num_workers = worker_info.id, worker_info.num_workers + else: + worker_id, num_workers = 0, 1 + rng = random.Random((self.count * num_workers + worker_id) + * self.world_size + self.rank) + self.count += 1 + token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], [] + source_batch, target_batch = [], [] + if rng.random() < self.short_seq_prob: + samples = self.split_samples(samples, rng) + rand = rng.random() + single_span = rand < self.single_span_prob + rand = 0.0 if single_span else rng.random() + attention_mask = [] + if rand < self.bert_prob: + mode = 'bert' + for sample in samples: + if single_span: + masked_lengths = [ + rng.choices( + range(1, + len(self.block_length_distribution) + 1), + weights=self.block_length_distribution)[0] + ] + masked_count = masked_lengths[0] + else: + masked_lengths, masked_count = [], 0 + while masked_count < int( + self.bert_ratio * len(sample['text'])): + block_length = rng.choices( + range(1, + len(self.block_length_distribution) + 1), + weights=self.block_length_distribution)[0] + masked_lengths.append(block_length) + masked_count += block_length + if self.masked_lm: + sep = len(sample['text']) + else: + sep = len( + sample['text']) - masked_count + len(masked_lengths) + data = self.generate_blank_data( + sample, masked_lengths, sep, rng, task='bert') + if data is not None: + if self.encoder_decoder: + source_tokens, target_tokens, loss_masks = data + source_batch.append(source_tokens) + target_batch.append(target_tokens) + loss_mask_batch.append(loss_masks) + else: + tokens, targets, loss_masks, position_ids = data + token_batch.append(tokens) + target_batch.append(targets) + loss_mask_batch.append(loss_masks) + position_id_batch.append(position_ids) + attention_mask.append(sep) + + elif rand < self.bert_prob + self.gap_sentence_prob: + mode = 'sentence' + for sample in samples: + tokens, loss_masks = sample['text'], sample['loss_mask'] + sentence_spans = [] + last_index = 1 if tokens[0] == self.tokenizer.get_command( + 'ENC').Id else 0 + for i in range(len(tokens)): + if self.contains_sentence_end(tokens[i]): + if last_index < i + 1: + sentence_spans.append((last_index, i + 1)) + last_index = i + 1 + elif tokens[i] == self.tokenizer.get_command('eos').Id: + last_index = i + 1 + if last_index < len(tokens): + sentence_spans.append((last_index, len(tokens))) + if not sentence_spans and torch.distributed.get_rank() == 0: + try: + print(self.tokenizer.DecodeIds(tokens[1:])) + except IndexError: + print(tokens[1:]) + rng.shuffle(sentence_spans) + block_spans, block_length = [], 0 + for start, end in sentence_spans: + block_spans.append((start, end)) + block_length += end - start + if block_length >= int( + self.gap_sentence_ratio * len(tokens)): + break + data = self.make_block_data( + tokens, + loss_masks, + None, + block_spans, + rng, + task='gap_sentence') + tokens, targets, loss_masks, position_ids, sep = data + token_batch.append(tokens) + target_batch.append(targets) + loss_mask_batch.append(loss_masks) + position_id_batch.append(position_ids) + attention_mask.append(sep) + else: + # start_indices = [index_in_list(sample['loss_mask'], 1) for sample in samples] + # end_indices = [rindex(sample['loss_mask'], 1) for sample in samples] + # start_index, end_index = max(start_indices), min(end_indices) - self.min_generation_length + # if end_index < start_index + 1: + # end_index = start_index + 1 + # division = rng.randrange(start_index, end_index) + mode = 'gpt' + max_generation_length = rng.randint( + int(self.gpt_min_ratio + * min(map(lambda x: len(x['text']), samples))), + max(map(lambda x: len(x['text']), samples)) - 2) + for sample in samples: + generation_length = min(max_generation_length, + len(sample['text']) - 2) + attention_mask.append( + len(sample['text']) - generation_length + 1) + multiple_doc = index_in_list( + sample['text'], + self.tokenizer.get_command('eos').Id) not in [ + -1, len(sample['text']) - 1 + ] # noqa + if multiple_doc or rng.random() < self.infill_prob: + division = len(sample['text']) - generation_length + tokens, loss_masks = sample['text'], sample['loss_mask'] + source_tokens, target_tokens = tokens[:division], tokens[ + division:] + target_masks = loss_masks[division:] + tokens = np.concatenate((source_tokens, [ + self.generation_mask, + self.tokenizer.get_command('sop').Id + ], target_tokens[:-1])) + targets = np.concatenate( + (source_tokens, [self.generation_mask], target_tokens)) + loss_masks = np.concatenate( + (np.zeros(len(source_tokens) + 1, + dtype=np.long), target_masks)) + token_batch.append(tokens) + target_batch.append(targets) + loss_mask_batch.append(loss_masks) + position_ids = np.arange( + len(source_tokens) + len(target_tokens) + 1, + dtype=np.long) + position_ids[len(source_tokens) + 1:] = len(source_tokens) + if self.block_position_encoding: + block_position_ids = np.concatenate( + (np.zeros(len(source_tokens), dtype=np.long), + np.arange(len(target_tokens) + 1, dtype=np.long))) + else: + block_position_ids = np.concatenate( + (np.zeros(len(source_tokens) + 1, dtype=np.long), + np.ones(len(target_tokens) + 1, dtype=np.long))) + position_id_batch.append( + np.stack([position_ids, block_position_ids], axis=0)) + else: + tokens, targets, loss_masks, position_ids = self.generate_blank_data( + sample, [generation_length], + attention_mask[-1], + rng, + task='generation') + token_batch.append(tokens) + target_batch.append(targets) + loss_mask_batch.append(loss_masks) + position_id_batch.append(position_ids) + if tokens is None: + print(sample, generation_length, multiple_doc) + if self.encoder_decoder: + return { + 'text': torch.tensor(source_batch, dtype=torch.long), + 'target': torch.tensor(target_batch, dtype=torch.long), + 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long) + } + else: + token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch( + token_batch, target_batch, loss_mask_batch, position_id_batch) + return { + 'text': torch.tensor(token_batch, dtype=torch.long), + 'target': torch.tensor(target_batch, dtype=torch.long), + 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long), + 'position_id': + torch.tensor(position_id_batch, dtype=torch.long), + 'attention_mask': + torch.tensor(attention_mask, dtype=torch.long), + 'mode': mode + } + + @staticmethod + def pad_batch(token_batch, target_batch, loss_mask_batch, + position_id_batch): + seq_lengths = list(map(len, token_batch)) + if seq_lengths.count(seq_lengths[0]) != len(seq_lengths): + max_length = max(seq_lengths) + token_batch = [ + np.concatenate( + (tokens, np.zeros(max_length - len(tokens), + dtype=np.long))) + for tokens in token_batch + ] + target_batch = [ + np.concatenate( + (targets, + np.zeros(max_length - len(targets), dtype=np.long))) + for targets in target_batch + ] + loss_mask_batch = [ + np.concatenate( + (loss_masks, + np.zeros(max_length - len(loss_masks), dtype=np.long))) + for loss_masks in loss_mask_batch + ] + position_id_batch = [ + np.concatenate((position_ids, + np.zeros( + (2, max_length - position_ids.shape[1]), + dtype=np.long)), + axis=1) for position_ids in position_id_batch + ] + return token_batch, target_batch, loss_mask_batch, position_id_batch diff --git a/modelscope/models/nlp/mglm/configure_data.py b/modelscope/models/nlp/mglm/configure_data.py new file mode 100644 index 00000000..6921de08 --- /dev/null +++ b/modelscope/models/nlp/mglm/configure_data.py @@ -0,0 +1,513 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""parses arguments and preps data loader""" + +import copy +import os +import random +from bisect import bisect_right +from itertools import accumulate + +import numpy as np +import torch +import torch.utils.data + +from . import data_utils, mpu +from .blocklm_utils import ConstructBlockStrategy +from .data_utils.tokenization import make_tokenizer +from .utils import print_rank_0 + + +class MultiTaskDataset(torch.utils.data.Dataset): + + def __init__(self, + tasks, + datasets, + reweight=True, + temperature=0.8, + max_limit=200000): + super(MultiTaskDataset, self).__init__() + self.tasks = tasks + self.datasets = datasets + self.reweight = reweight + self.temperature = temperature + self.lens = [len(dataset) for dataset in datasets] + self.weights = np.array( + [min(length, max_limit)**temperature for length in self.lens]) + self.total_len = sum(self.lens) + self.cumulative_lens = list(accumulate(self.lens)) + if self.reweight: + print_rank_0(list(zip(self.tasks, self.lens, self.weights))) + else: + print_rank_0(list(zip(self.tasks, self.lens))) + self.weights /= self.weights.sum() + + def __len__(self): + return self.total_len * 1000 + + @staticmethod + def pet_wrapper(data): + text = data['text'] + loss_mask = data['logit_mask'] + target = data['target'] + attention_mask = data['mask'] + position_id = data['position'] + label = data['label'] + if len(text.shape) == 2: + text = text[label] + loss_mask = loss_mask[label] + target = target[label] + attention_mask = attention_mask[label] + position_id = position_id[label] + else: + target = target[label] + if not target.shape: + target = target.repeat(len(text)) + return { + 'text': text, + 'target': target, + 'loss_mask': loss_mask, + 'position_id': position_id, + 'attention_mask': attention_mask + } + + def __getitem__(self, idx): + if self.reweight: + rng = random.Random(idx) + rng = np.random.RandomState( + seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) + dataset_idx = rng.choice( + np.arange(len(self.datasets)), p=self.weights) + dataset = self.datasets[dataset_idx] + sample_idx = rng.choice(np.arange(len(dataset))) + item = self.datasets[dataset_idx][sample_idx] + else: + dataset_idx = bisect_right(self.cumulative_lens, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_lens[dataset_idx - 1] + item = self.datasets[dataset_idx][sample_idx] + item = self.pet_wrapper(item) + return item + + +class DataConfig: + + def __init__(self, defaults=None): + super(DataConfig, self).__init__() + if defaults is None: + defaults = {} + self.defaults = defaults + + def apply(self, args, tokenizer): + if torch.distributed.get_rank() == 0: + print('configuring data') + self.apply_defaults(args) + return make_loaders(args, tokenizer) + + def set_defaults(self, **kwargs): + for k, v in kwargs.items(): + self.defaults[k] = v + + def apply_defaults(self, args): + for k, v in self.defaults.items(): + k = k.replace('-', '_') + if not hasattr(args, k): + setattr(args, k, v) + + +def prepare_tokenizer(args): + add_sentinel_token = 0 + if args.sentinel_token: + add_sentinel_token = args.max_position_embeddings + tokenizer = make_tokenizer( + args.tokenizer_type, + None, + args.tokenizer_path, + args.vocab_size, + args.tokenizer_model_type, + add_block_symbols=args.block_lm, + cache_dir=args.cache_dir, + add_sentinel_token=add_sentinel_token, + add_task_mask=args.task_mask, + add_decoder_mask=args.block_mask_prob > 0.0 + or args.context_mask_ratio > 0.0) + if mpu.get_model_parallel_rank() == 0: + num_tokens = tokenizer.num_tokens + eod_token = tokenizer.get_command('eos').Id + assert eod_token == tokenizer.get_command('pad').Id + before = num_tokens + after = before + multiple = args.make_vocab_size_divisible_by + while (after % multiple) != 0: + after += 1 + print_rank_0('> padded vocab (size: {}) with {} dummy ' + 'tokens (new size: {})'.format(before, after - before, + after)) + print_rank_0('> found end-of-document token: {}'.format(eod_token)) + token_counts = torch.cuda.LongTensor([after, eod_token]) + else: + token_counts = torch.cuda.LongTensor([0, 0]) + # Broadcast num tokens. + torch.distributed.broadcast( + token_counts, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + num_tokens = token_counts[0].item() + eod_token = token_counts[1].item() + args.vocab_size, args.eod_token = num_tokens, eod_token + return tokenizer + + +def make_data_loader(dataset, + tokenizer, + batch_size, + num_iters, + args, + shuffle=False, + block_collate=False): + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) + if args.loader_scatter is not None: + rank = rank // args.loader_scatter + world_size = world_size // args.loader_scatter + batch_size = batch_size // args.loader_scatter + distributed = world_size > 1 + if args.transformer_xl: + batch_sampler = data_utils.samplers.DistributedSequentialSampler( + len(dataset), num_iters, batch_size, rank, world_size) + else: + if shuffle: + sampler = data_utils.samplers.RandomSampler( + dataset, + replacement=True, + num_samples=batch_size * args.train_iters + * args.gradient_accumulation_steps) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + drop_last = distributed + # the GPUs in the same model parallel group receive the same data + if distributed: + batch_sampler = data_utils.samplers.DistributedBatchSampler( + sampler, + batch_size, + drop_last, + rank, + world_size, + gradient_accumulation_steps=args.gradient_accumulation_steps) + else: + batch_sampler = torch.utils.data.BatchSampler( + sampler, batch_size, drop_last) + collate_fn = None + if block_collate: + collate_fn = ConstructBlockStrategy( + args, + tokenizer, + args.seq_length, + bert_prob=args.bert_prob, + gap_sentence_prob=args.gap_sentence_prob, + gap_sentence_ratio=args.gap_sentence_ratio, + gpt_infill_prob=args.gpt_infill_prob, + average_block_length=args.avg_block_length, + gpt_min_ratio=args.gpt_min_ratio, + block_mask_prob=args.block_mask_prob, + context_mask_ratio=args.context_mask_ratio, + short_seq_prob=args.short_seq_prob, + single_span_prob=args.single_span_prob, + shuffle_blocks=not args.no_shuffle_block, + block_position_encoding=not args.no_block_position, + sentinel_token=args.sentinel_token, + encoder_decoder=args.encoder_decoder, + task_mask=args.task_mask, + random_position=args.random_position, + masked_lm=args.masked_lm).construct_blocks + data_loader = torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True, + collate_fn=collate_fn) + + return data_loader + + +def make_tfrecord_loaders(args): + """Load train/val/test dataset from shuffled TFRecords""" + + import data_utils.tf_dl + data_set_args = { + 'batch_size': args.batch_size, + 'max_seq_len': args.seq_length, + 'max_preds_per_seq': args.max_preds_per_seq, + 'train': True, + 'num_workers': max(args.num_workers, 1), + 'seed': args.seed + args.rank + 1, + 'threaded_dl': args.num_workers > 0 + } + train = data_utils.tf_dl.TFRecordDataLoader(args.train_data, + **data_set_args) + data_set_args['train'] = False + if args.eval_seq_length is not None: + data_set_args['max_seq_len'] = args.eval_seq_length + if args.eval_max_preds_per_seq is not None: + data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq + valid = None + if args.valid_data is not None: + valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data, + **data_set_args) + test = None + if args.test_data is not None: + test = data_utils.tf_dl.TFRecordDataLoader(args.test_data, + **data_set_args) + tokenizer = data_utils.make_tokenizer( + args.tokenizer_type, + train, + args.tokenizer_path, + args.vocab_size, + args.tokenizer_model_type, + cache_dir=args.cache_dir) + + return (train, valid, test), tokenizer + + +def make_loaders(args, tokenizer): + """makes training/val/test""" + + if args.use_tfrecords: + return make_tfrecord_loaders(args) + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + if args.loader_scatter is not None: + assert world_size % args.loader_scatter == 0 + batch_size = args.batch_size * world_size + eval_batch_size = batch_size + if args.eval_batch_size is not None: + eval_batch_size = args.eval_batch_size * world_size + seq_length = args.seq_length + if seq_length < 0: + seq_length = seq_length * world_size + eval_seq_length = args.eval_seq_length + if eval_seq_length is not None and eval_seq_length < 0: + eval_seq_length = eval_seq_length * world_size + split = get_split(args) + data_set_args = { + 'path': args.train_data, + 'seq_length': seq_length, + 'mem_length': args.mem_length, + 'delim': args.delim, + 'text_key': args.text_key, + 'label_key': 'label', + 'ds_type': args.data_set_type, + 'split': split, + 'loose': args.loose_json, + 'max_preds_per_seq': args.max_preds_per_seq, + 'presplit_sentences': args.presplit_sentences, + 'sample_one_document': args.sample_one_document, + 'filter_english': args.filter_english, + 'pre_tokenize': not args.no_pre_tokenize, + 'tokenizer': tokenizer, + 'save_splits': args.save_splits, + 'load_splits': args.load_splits, + 'save_test_data': args.save_test_data, + 'no_lazy_loader': args.no_lazy_loader, + 'loader_scatter': args.loader_scatter, + 'data_parallel_rank': mpu.get_data_parallel_rank(), + 'non_sentence_start': args.non_sentence_start, + 'half_lazy_loader': args.half_lazy_loader + } + + eval_set_args = copy.copy(data_set_args) + eval_set_args['split'] = [1.] + # if optional eval args were set then replace their + # equivalent values in the arg dict + if eval_seq_length: + eval_set_args['seq_length'] = eval_seq_length + if args.eval_max_preds_per_seq: + eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq + if args.eval_text_key is not None: + eval_set_args['text_key'] = args.eval_text_key + + # make datasets splits and tokenizer + train, valid, test = None, None, None + + if args.train_data is not None: + train = data_utils.make_dataset(**data_set_args) + if data_utils.should_split(split): + train, valid, test = train + eval_set_args['tokenizer'] = tokenizer + + # make training and val dataset if necessary + if valid is None and args.valid_data is not None: + eval_set_args['path'] = args.valid_data + valid = data_utils.make_dataset(**eval_set_args) + eval_set_args['tokenizer'] = tokenizer + if test is None and args.test_data is not None: + eval_set_args['path'] = args.test_data + test = data_utils.make_dataset(**eval_set_args) + + # wrap datasets with data loader + use_block = args.block_lm or args.encoder_decoder + + if train is not None and args.batch_size > 0: + train = make_data_loader( + train, + tokenizer, + batch_size, + args.train_iters, + args, + shuffle=args.shuffle, + block_collate=use_block) + args.do_train = True + else: + args.do_train = False + eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size + if valid is not None: + valid = make_data_loader( + valid, + tokenizer, + eval_batch_size, + args.train_iters, + args, + shuffle=args.shuffle, + block_collate=use_block) + args.do_valid = True + else: + args.do_valid = False + if test is not None: + test = make_data_loader( + test, + tokenizer, + eval_batch_size, + len(test) // eval_batch_size + 1, + args, + shuffle=args.shuffle, + block_collate=use_block) + args.do_test = True + else: + args.do_test = False + + return train, valid, test + + +def build_multi_task_dataset(args, tokenizer): + task_dirs = { + 'mnli': 'MNLI', + 'cola': 'CoLA', + 'mrpc': 'MRPC', + 'qnli': 'QNLI', + 'qqp': 'QQP', + 'sst2': 'SST-2', + 'agnews': 'Agnews', + 'yelp-polarity': 'yelp_review_polarity_csv', + 'yelp-full': 'yelp_review_full_csv', + 'yahoo': 'Yahoo', + 'squad': 'SQuAD', + 'race': 'RACE' + } + train, valid = None, None + if mpu.get_model_parallel_rank() == 0: + multi_seq_length = args.seq_length + if args.multi_seq_length is not None: + multi_seq_length = args.multi_seq_length + train_datasets, valid_datasets = [], [] + for task in args.multi_task_data: + task = task.lower() + data_dir = os.path.join(args.data_dir, task_dirs[task]) + train_datasets.append( + SuperGlueDataset( + args, + task, + data_dir, + multi_seq_length, + 'train', + tokenizer, + pattern_ensemble=True)) + valid_datasets.append( + SuperGlueDataset( + args, + task, + data_dir, + multi_seq_length, + 'dev', + tokenizer, + pattern_ensemble=True)) + train = MultiTaskDataset(args.multi_task_data, train_datasets) + valid = MultiTaskDataset(args.multi_task_data, valid_datasets) + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + multi_batch_size = args.batch_size * world_size + if args.multi_batch_size is not None: + multi_batch_size = args.multi_batch_size * world_size + train = make_data_loader( + train, + tokenizer, + multi_batch_size, + args.train_iters, + args, + shuffle=True) + valid = make_data_loader( + valid, + tokenizer, + multi_batch_size, + args.train_iters, + args, + shuffle=True) + return train, valid + + +def get_split(args): + """ + Get dataset splits from comma separated string list + """ + splits = [] + if args.split.find(',') != -1: + splits = [float(s) for s in args.split.split(',')] + elif args.split.find('/') != -1: + splits = [float(s) for s in args.split.split('/')] + else: + splits = [float(args.split)] + split_total = sum(splits) + if split_total < 1.: + splits.append(1 - split_total) + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + if args.valid_data is not None: + splits[1] = 0. + if args.test_data is not None: + splits[2] = 0. + final_sum = sum(splits) + return [s / final_sum for s in splits] + + +def configure_data(): + """add cmdline flags for configuring datasets""" + # These are options that are used by data_utils, but are either + # deprecated or not meant to be exposed to the command line user. + # These options are intneded to be set in code by specific scripts. + defaults = { + 'world_size': 1, + 'rank': -1, + 'persist_state': 0, + 'lazy': False, + 'transpose': False, + 'data_set_type': 'supervised', + 'seq_length': 256, + 'eval_seq_length': 256, + 'samples_per_shard': 100 + } + + return DataConfig(defaults=defaults) diff --git a/modelscope/models/nlp/mglm/data_utils/__init__.py b/modelscope/models/nlp/mglm/data_utils/__init__.py new file mode 100644 index 00000000..fa243cb4 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/__init__.py @@ -0,0 +1,341 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""utils for creating datasets""" +import math +import os +import random +import time + +import torch + +from . import corpora +from .datasets import (BertSentencepairDataset, BlockDataset, ConcatDataset, + GPT2Dataset, ShuffleDataset, SplitDataset, XLDataset, + split_ds) +from .lazy_loader import (LazyLoader, LazyWriter, exists_lazy, exists_scatter, + get_scatter_path) +from .samplers import DistributedBatchSampler +from .tokenization import (BertWordPieceTokenizer, CharacterLevelTokenizer, + CommandToken, GPT2BPETokenizer, Tokenization, + Tokenizer, make_tokenizer) + +TRAIN_DATA = 0 +VAL_DATA = 1 +TEST_DATA = 2 + + +def should_split(split): + """ + given split proportions checks if should split + Examples: + >>> should_split([10,0,0]) + False + >>> should_split([1,.1,.2]) + True + """ + return max(split) / sum(split) != 1. + + +def get_ext(path): + """gets path extension""" + return os.path.splitext(path)[1] + + +def get_dataset(name, + tokenizer, + pre_tokenize, + data_parallel_rank, + loader_scatter=None, + no_lazy_loader=False, + half_lazy_loader=False): + """gets dataset object based on keyword args and file at `path`""" + global_rank = torch.distributed.get_rank() + if not supported_corpus(name): + raise NotImplementedError('dataset %s is not supported' % name) + dataset = corpora.NAMED_CORPORA[name] + path = dataset.PATH + if issubclass(dataset, corpora.PromptReader): + if not (exists_lazy(path, data_type='prompt') + and exists_lazy(path, data_type='text')) and not ( + loader_scatter is not None and exists_scatter( + path, data_type='prompt', scatter_num=loader_scatter) + and exists_scatter( + path, data_type='text', scatter_num=loader_scatter)): + # create cached version of dataset for lazy loading if it doesn't exist + if global_rank == 0: + print(f'Creating lazy loader for dataset {name}') + prompt_writer = LazyWriter( + path, data_type='prompt', is_array=pre_tokenize) + text_writer = LazyWriter( + path, data_type='text', is_array=pre_tokenize) + writers = {'prompt': prompt_writer, 'text': text_writer} + reader = dataset( + writers=writers, + tokenizer=tokenizer, + tokenize=pre_tokenize) + reader.process() + prompt_writer.close() + text_writer.close() + else: + while not os.path.exists( + LazyWriter.get_len_path(path, data_type='prompt')): + time.sleep(1) + map_fn = (lambda x: x.tolist()) if pre_tokenize else None + if loader_scatter is not None: + if not (exists_scatter( + path, data_type='prompt', scatter_num=loader_scatter) + and exists_scatter( + path, data_type='text', scatter_num=loader_scatter)): + if global_rank == 0: + print(f'Creating scatter loader for dataset {name}') + prompts = LazyLoader( + path, + data_type='prompt', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize) + texts = LazyLoader( + path, + data_type='text', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize) + indices = list(range(len(texts))) + random.shuffle(indices) + segment_length = (len(indices) - 1) // loader_scatter + 1 + for i in range(loader_scatter): + scatter_path = get_scatter_path(path, scatter_rank=i) + prompt_writer = LazyWriter( + scatter_path, + data_type='prompt', + is_array=pre_tokenize) + text_writer = LazyWriter( + scatter_path, + data_type='text', + is_array=pre_tokenize) + for idx in indices[i * segment_length:(i + 1) + * segment_length]: + prompt_writer.write(prompts[idx]) + text_writer.write(texts[idx]) + prompt_writer.close() + text_writer.close() + else: + while not (exists_scatter( + path, data_type='prompt', + scatter_num=loader_scatter) and exists_scatter( + path, + data_type='text', + scatter_num=loader_scatter)): + time.sleep(1) + scatter_path = get_scatter_path( + path, scatter_rank=data_parallel_rank % loader_scatter) + print(f'Rank {global_rank} is using scatter from {scatter_path}') + prompts = LazyLoader( + scatter_path, + data_type='prompt', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize, + load_memory=no_lazy_loader, + half_load=half_lazy_loader) + texts = LazyLoader( + scatter_path, + data_type='text', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize, + load_memory=no_lazy_loader, + half_load=half_lazy_loader) + else: + prompts = LazyLoader( + path, + data_type='prompt', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize, + load_memory=no_lazy_loader, + half_load=half_lazy_loader) + texts = LazyLoader( + path, + data_type='text', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize, + load_memory=no_lazy_loader, + half_load=half_lazy_loader) + text = corpora.PromptDataset( + prompt_loader=prompts, + text_loader=texts, + tokenizer=tokenizer, + to_tokenize=not pre_tokenize) + if loader_scatter is None: + if global_rank == 0: + print(f'Create dataset {name} with {len(text)} documents') + for i in range(10): + rand_id = i if i < 5 else random.randrange(len(text)) + sample_tokens = text[rand_id]['tokens'][:1024] + print(sample_tokens) + print(tokenizer.DecodeIds(sample_tokens).encode('utf-8')) + else: + for scatter_id in range(loader_scatter): + if data_parallel_rank % loader_scatter == scatter_id and data_parallel_rank // loader_scatter == 0: + print( + f'Create dataset {name} at scatter {scatter_id} with {len(text)} documents' + ) + for i in range(10): + sample_tokens = text[i]['tokens'][:1024] + print(sample_tokens) + print(tokenizer.DecodeIds(sample_tokens)) + torch.distributed.barrier() + return text + elif issubclass(dataset, corpora.KeyReader): + if not (exists_lazy(path, data_type='text') + and exists_lazy(path, data_type='mask')): + # create cached version of dataset for lazy loading if it doesn't exist + if global_rank == 0: + text_writer = LazyWriter( + path, data_type='text', is_array=pre_tokenize) + mask_writer = LazyWriter(path, data_type='mask', is_array=True) + writers = {'mask': mask_writer, 'text': text_writer} + dataset( + writers=writers, + tokenizer=tokenizer, + tokenize=pre_tokenize) + mask_writer.close() + text_writer.close() + else: + while not os.path.exists( + LazyWriter.get_len_path(path, data_type='mask')): + time.sleep(1) + map_fn = (lambda x: x.tolist()) if pre_tokenize else None + masks = LazyLoader( + path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True) + texts = LazyLoader( + path, + data_type='text', + map_fn=map_fn, + mem_map=True, + is_array=pre_tokenize) + text = corpora.KeyDataset( + mask_loader=masks, + text_loader=texts, + tokenizer=tokenizer, + to_tokenize=not pre_tokenize) + return text + + +def supported_corpus(corpus_name): + """checks if corpus name is defined in `corpora.py`""" + return corpus_name in corpora.NAMED_CORPORA + + +def make_dataset(path, + seq_length, + mem_length, + shuffle=True, + split=None, + tokenizer=None, + sample_one_document=False, + pre_tokenize=False, + ds_type='', + save_splits=None, + load_splits=None, + save_test_data=None, + no_lazy_loader=False, + loader_scatter=None, + data_parallel_rank=None, + filter_english=False, + non_sentence_start=0.0, + half_lazy_loader=False, + **kwargs): + """function to create datasets+tokenizers for common options""" + if split is None: + split = [1.] + + # get one or multiple datasets and concatenate + if isinstance(path, str): + ds = get_dataset( + path, + tokenizer=tokenizer, + pre_tokenize=pre_tokenize, + no_lazy_loader=no_lazy_loader, + loader_scatter=loader_scatter, + data_parallel_rank=data_parallel_rank, + half_lazy_loader=half_lazy_loader) + else: + ds = [ + get_dataset( + p, + tokenizer=tokenizer, + pre_tokenize=pre_tokenize, + no_lazy_loader=no_lazy_loader, + loader_scatter=loader_scatter, + data_parallel_rank=data_parallel_rank, + half_lazy_loader=half_lazy_loader) for p in path + ] + ds = ConcatDataset(ds) + + # Split dataset into train/val/test (and wrap bert dataset) + def wrap_dataset(dataset): + if ds_type.lower() == 'bert': + presplit_sentences = kwargs[ + 'presplit_sentences'] if 'presplit_sentences' in kwargs else False + dataset = BertSentencepairDataset( + dataset, + max_seq_len=seq_length, + presplit_sentences=presplit_sentences) + elif ds_type.lower() == 'gpt-xl': + assert pre_tokenize + dataset = XLDataset( + dataset, + tokenizer, + max_seq_len=seq_length, + mem_len=mem_length, + sample_across_doc=not sample_one_document) + elif ds_type.lower() == 'gpt2': + dataset = GPT2Dataset( + dataset, + tokenizer, + max_seq_len=seq_length, + sample_across_doc=not sample_one_document) + elif ds_type.lower() == 'block': + dataset = BlockDataset( + dataset, + tokenizer, + max_seq_len=seq_length, + sample_across_doc=not sample_one_document, + filter_english=filter_english, + non_sentence_start=non_sentence_start) + return dataset + + if should_split(split): + ds = split_ds( + ds, + split, + shuffle=shuffle, + save_splits=save_splits, + load_splits=load_splits) + if save_test_data is not None and torch.distributed.get_rank() == 0: + test_ds = ds[-1] + with open(save_test_data, 'w', encoding='utf-8') as output: + for data in test_ds: + text = data['tokens'] + text = tokenizer.DecodeIds(text) + output.write(text) + output.write('\n') + print(f'Write test data to {save_test_data}') + ds = [wrap_dataset(d) if d is not None else None for d in ds] + else: + ds = wrap_dataset(ds) + return ds diff --git a/modelscope/models/nlp/mglm/data_utils/corpora.py b/modelscope/models/nlp/mglm/data_utils/corpora.py new file mode 100755 index 00000000..7c6f58f8 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/corpora.py @@ -0,0 +1,583 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""several datasets with preset arguments""" +import os +import random +from collections import defaultdict +from multiprocessing import Process, Queue +from queue import Empty + +import json +import tqdm +from torch.utils import data + +from modelscope.models.nlp.mglm.utils import print_rank_0 +from .datasets import csv_dataset, json_dataset +from .lazy_loader import LazyLoader + +NUM_PROCESSES = 100 + + +def punctuation_standardization(string: str): + punctuation_dict = { + '\u201c': "\"", + '\u201d': "\"", + '\u2019': "'", + '\u2018': "'", + '\u2013': '-' + } + for key, value in punctuation_dict.items(): + string = string.replace(key, value) + return string + + +class KeyDataset(data.Dataset): + + def __init__(self, text_loader, mask_loader, **kwargs): + self.texts = text_loader + self.masks = mask_loader + self.is_lazy = False + if isinstance(self.texts, LazyLoader) and isinstance( + self.masks, LazyLoader): + self.text_lens = self.texts.lens + self.is_lazy = True + + def get_text_len(self, idx): + return self.text_lens[idx] + + def __getitem__(self, index): + text = self.texts[index] + mask_length = self.masks[index] + mask = [] + for i, length in enumerate(mask_length): + if i % 2 == 0: + mask += [0] * length + else: + mask += [1] * length + assert len(text) == len(mask) + return {'tokens': text, 'loss_masks': mask} + + def __len__(self): + return len(self.texts) + + +class PromptDataset(data.Dataset): + + def __init__(self, + prompt_loader, + text_loader, + tokenizer=None, + to_tokenize=False, + **kwargs): + self.prompts = prompt_loader + self.texts = text_loader + self.tokenizer = tokenizer + self.to_tokenize = to_tokenize + if isinstance(self.prompts, LazyLoader) and isinstance( + self.texts, LazyLoader): + self.prompt_lens = self.prompts.lens + self.text_lens = self.texts.lens + self.is_lazy = True + + def get_text_len(self, idx): + return self.prompt_lens[idx] + self.text_lens[idx] + + def __getitem__(self, index): + prompt = self.prompts[index] + text = self.texts[index] + if self.to_tokenize: + prompt = self.tokenizer.EncodeAsIds(prompt).tokenization + text = self.tokenizer.EncodeAsIds(text).tokenization + return { + 'tokens': prompt + text, + 'loss_masks': [0] * len(prompt) + [1] * len(text) + } + + def __len__(self): + return len(self.prompts) + + +class DataReader: + PATH = None + assert_str = None + reserve_punct = False + split_row = True + TASK_QUEUE_LIMIT = 10000000 + DONE_QUEUE_LIMIT = 10000000 + + def tokenize_worker(self, input, output, info, tokenizer, tokenize): + raise NotImplementedError + + def print_info(self, info): + pass + + def __init__(self, writers, tokenizer=None, tokenize=False, **kwargs): + print(self.PATH) + print(self.assert_str) + assert os.path.exists(self.PATH), self.assert_str + print_rank_0(f'Creating dataset from {self.PATH}') + self.tokenizer = tokenizer + self.tokenize = tokenize + self.writers = writers + + def process(self): + if os.path.isdir(self.PATH): + paths = [ + os.path.join(top, name) for top, _, names in os.walk(self.PATH) + for name in names + ] + # paths = [entry.path for entry in os.scandir(self.PATH) if + # not entry.is_dir() and not entry.name.endswith("bz2")] + else: + paths = [self.PATH] + task_queue, done_queue, info_queue = Queue( + maxsize=self.TASK_QUEUE_LIMIT), Queue( + maxsize=self.DONE_QUEUE_LIMIT), Queue() + processes = [] + for i in range(NUM_PROCESSES): + process = Process( + target=self.tokenize_worker, + args=(task_queue, done_queue, info_queue, self.tokenizer, + self.tokenize)) + process.start() + processes.append(process) + + def read_input_to_queue(): + for path in paths: + print_rank_0(f'Start reading {path}') + with open(path) as file: + items = json.load(file) + for item in items: + task_queue.put(item) + # if self.split_row: + # for row in file: + # task_queue.put(row) + # else: + # items = json.load(file) + # for item in items["RECORDS"]: + # task_queue.put(item) + print_rank_0('Read input complete') + for i in range(len(processes)): + task_queue.put('STOP') + + process = Process(target=read_input_to_queue) + process.start() + count = len(processes) + progress_bar = tqdm.tqdm() + while True: + data = done_queue.get() + if data == 'COMPLETE': + count -= 1 + if count == 0: + break + else: + self.write_result(data, self.writers) + progress_bar.update() + progress_bar.close() + self.print_info(info_queue) + + @staticmethod + def write_result(data, writers): + raise NotImplementedError + + @staticmethod + def get_token_count(contents): + return sum(map(len, contents)) + + @classmethod + def process_sample(cls, text, tokenizer, tokenize): + if isinstance(text, str) and tokenize: + if not cls.reserve_punct: + text = punctuation_standardization(text) + text = tokenizer.EncodeAsIds(text).tokenization if text else [] + return text + + @staticmethod + def trim_field(content, max_length): + if len(content) > max_length: + content = content[:max_length] + content += '......' + return content + + def process_line(self, data, tokenizer, tokenize): + raise NotImplementedError + + +class PromptReader(DataReader): + is_json = True + + def tokenize_worker(self, input, output, info, tokenizer, tokenize): + for row in iter(input.get, 'STOP'): + if row: + if self.is_json: + row = row.rstrip() + row = json.loads(row) + prompts, texts = self.process_line(row, tokenizer, tokenize) + for prompt, text in zip(prompts, texts): + output.put((prompt, text)) + output.put('COMPLETE') + + @staticmethod + def write_result(data, writers): + prompt, text = data + writers['prompt'].write(prompt) + writers['text'].write(text) + + +class KeyReader(DataReader): + PATH = '/root/data/wikipedia/wiki-key.txt' + assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' + + def process_line(self, data, tokenizer, tokenize): + keys, contents = data['key'], data['content'] + assert len(keys) == len(contents) + for i in range(1, len(keys)): + keys[i] = ' ' + keys[i] + contents = [' ' + content for content in contents] + keys = [tokenizer.EncodeAsIds(key).tokenization for key in keys] + contents = [ + tokenizer.EncodeAsIds(content).tokenization for content in contents + ] + summary = sum(keys, []) + summary_prefix = self.process_sample('Summary: ', tokenizer, tokenize) + summary_mask = [len(summary_prefix), len(summary)] + summary = summary_prefix + summary + text, text_mask = [], [] + for key, content in zip(keys, contents): + content = content + [tokenizer.get_command('eop').Id] + text += key + text += content + text_mask.append(len(key)) + text_mask.append(len(content)) + return (summary, summary_mask), (text, text_mask) + + def tokenize_worker(self, input, output, info, tokenizer, tokenize): + for row in iter(input.get, 'STOP'): + data = json.loads(row) + summary, content = self.process_line(data, tokenizer, tokenize) + output.put((summary, content)) + output.put('COMPLETE') + + @staticmethod + def write_result(data, writers): + summary, content = data + writers['text'].write(summary[0]) + writers['mask'].write(summary[1]) + writers['text'].write(content[0]) + writers['mask'].write(content[1]) + + +class zhihu(PromptReader): + PATH = '/dataset/fd5061f6/data/tokenize_data/zhihu.lazy' + reserve_punct = True + assert_str = 'make sure to set PATH for zhihu data_utils/corpora.py' + qtitle_prefix = '问题:' + qcontent_prefix = '问题描述:' + user_prefix = '回答用户:' + answer_prefix = ' 回答:' + + # qtitle_prefix = [] + # qcontent_prefix = [] + # user_prefix = [] + # answer_prefix = [] + + def process_line(self, data, tokenizer, tokenize): + prompts, texts = [], [] + ans_length = len(data.get('ans-content', '')) + ans_up = data.get('ans-up-num', '') + ans_up = int(ans_up) if ans_up else 0 + if ans_length > 100 or ans_up > 1000: + qtitle = data['q_title'] + qcontent = data['q-content'] + if qcontent is None: + qcontent = '' + qcontent = self.trim_field(qcontent, max_length=100) + user = data.get('user-signature', '') + prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.user_prefix + user + self.answer_prefix # noqa + text = data['ans-content'] + prompt, text = self.process_sample(prompt, tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + prompts.append(prompt) + texts.append(text) + # prompt = data["q_title"] + data["q-content"] + data["user-signature"] + # text = data["ans-content"] + # prompts.append(prompt) + # texts.append(text) + return prompts, texts + + +class zhidao(PromptReader): + PATH = '/root/data/zhidao/zhidao' + reserve_punct = True + assert_str = 'make sure to set PATH for zhidao data_utils/corpora.py' + qtitle_prefix = '问题:' + qcontent_prefix = '问题描述:' + answer_prefix = '回答:' + + def process_line(self, data, tokenizer, tokenize): + if 'title' not in data: + return [], [] + prompts, texts = [], [] + qtitle = data['title'] + qcontent = data.get('content', '') + qcontent = self.trim_field(qcontent, max_length=100) + prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.answer_prefix + prompt = self.process_sample(prompt, tokenizer, tokenize) + if 'best_answer' in data: + text = data['best_answer']['content'] + if len(text) > 10: + text = self.process_sample(text, tokenizer, tokenize) + prompts.append(prompt) + texts.append(text) + for answer in data.get('other_answers', []): + text = answer['content'] + if len(text) > 100: + text = self.process_sample(text, tokenizer, tokenize) + prompts.append(prompt) + texts.append(text) + return prompts, texts + + +class baike(PromptReader): + PATH = '/dataset/fd5061f6/data/tokenize_data/baike.lazy' + reserve_punct = True + assert_str = 'make sure to set PATH for baike data_utils/corpora.py' + + def process_line(self, data, tokenizer, tokenize): + prompts, texts = [], [] + text = data.get('title', '') + data.get('abstract', '') + data.get( + 'content', '') + if text: + p, t = self.process_sample('', tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + prompts.append(p) + texts.append(t) + return prompts, texts + + +class wikipedia(PromptReader): + """ + dataset for wikipedia with arguments configured for convenience + + command line usage: `--train-data wikipedia` + """ + # PATH = '/dataset/data/wiki.txt' + PATH = '/root/data/bert_data/wiki.txt' + assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' + + def process_line(self, data, tokenizer, tokenize): + text = data['text'] + prompt, text = self.process_sample('', tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + return [prompt], [text] + + +class TestDataset(PromptReader): + PATH = '/root/data/test.json' + assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' + + def process_line(self, data, tokenizer, tokenize): + prompt, text = data['prompt'], data['text'] + prompt, text = self.process_sample(prompt, tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + return [prompt], [text] + + +class OpenWebText(PromptReader): + PATH = '/dataset/fd5061f6/english_data/openwebtext2' + assert_str = 'make sure to set PATH for openwebtext data_utils/corpora.py' + + def __init__(self, *args, **kwargs): + import fasttext + super().__init__(*args, **kwargs) + self.model = fasttext.load_model( + '/dataset/fd5061f6/english_data/lid.176.bin') + print_rank_0('Load language detection model') + + def process_line(self, data, tokenizer, tokenize): + text = data['text'] + if len(text) > 100: + lang = self.model.predict(text.replace('\n', ''))[0][0] + if lang == '__label__en': + prompt, text = self.process_sample( + '', tokenizer, + tokenize), self.process_sample(text, tokenizer, tokenize) + return [prompt], [text] + return [], [] + + +class CCNews(PromptReader): + PATH = '/mnt/cc_news.json' + assert_str = 'make sure to set PATH for cc-news data_utils/corpora.py' + + def process_line(self, data, tokenizer, tokenize): + text = '' + title = data.get('title', None) + description = data.get('description', None) + maintext = data.get('maintext', None) + if title: + text += title.strip() + ' ' + if description and (not maintext + or not maintext.startswith(description)): + text += description.strip() + ' ' + if maintext: + text += maintext + if len(text) > 100: + prompt, text = self.process_sample('', tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + return [prompt], [text] + else: + return [], [] + + +class BertData(PromptReader): + is_json = False + PATH = '/dataset/fd5061f6/english_data/wikibook' + + def process_line(self, data, tokenizer, tokenize): + if data: + prompt, text = '', data + prompt, text = self.process_sample(prompt, tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + return [prompt], [text] + else: + return [], [] + + +class Pile(PromptReader): + is_json = True + PATH = '/mnt/train' + filtered_sources = [ + 'Github', 'StackExchange', 'DM Mathematics', 'Ubuntu IRC', 'EuroParl', + 'YoutubeSubtitles', 'Enron Emails' + ] + downsample_sources = {'PubMed Central': 0.3, 'ArXiv': 0.3, 'FreeLaw': 0.3} + + def print_info(self, info): + total_dict = defaultdict(int) + while True: + try: + source_dict = info.get(block=False) + for source, length in source_dict.items(): + total_dict[source] += length + except Empty: + break + print_rank_0(total_dict) + + def tokenize_worker(self, input, output, info, tokenizer, tokenize): + source_dict = defaultdict(int) + for row in iter(input.get, 'STOP'): + row = row.rstrip() + if row: + if self.is_json: + row = json.loads(row) + prompts, texts, source = self.process_line( + row, tokenizer, tokenize) + length = 0 + for prompt, text in zip(prompts, texts): + length += len(text) + output.put((prompt, text)) + if source: + source_dict[source] += length + output.put('COMPLETE') + info.put(source_dict) + + def process_line(self, data, tokenizer, tokenize): + source = data['meta'].get('pile_set_name', None) + text = data.get('text', None) + if source and text: + if source in self.filtered_sources: + return [], [], None + elif source in self.downsample_sources and random.random( + ) > self.downsample_sources[source]: + return [], [], None + else: + prompt, text = self.process_sample( + '', tokenizer, + tokenize), self.process_sample(text, tokenizer, tokenize) + return [prompt], [text], source + else: + return [], [], None + + +class Stories(PromptReader): + is_json = True + PATH = '/dataset/fd5061f6/english_data/stories_31G.jsonl' + + def process_line(self, data, tokenizer, tokenize): + text = data.get('text', None) + if text: + prompt, text = self.process_sample('', tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + return [prompt], [text] + else: + return [], [] + + +class BertBaseData(BertData): + PATH = '/root/data/formatted_one_article_per_line' + + +class BertLargeData(BertData): + PATH = '/dataset/c07bd62b/cognitive/zhengxiao/formatted_one_article_per_line_large' + + +class WuDaoCorpus(PromptReader): + # PATH = "/dataset/fd5061f6/chinese_data/WuDao" + PATH = '/wudao' + is_json = False + reserve_punct = True + split_row = False + + def process_line(self, item, tokenizer, tokenize): + prompts, texts = [], [] + text = '' + title = item.get('title', None) + content = item.get('content', None) + if title: + text += title.strip() + ' ' + if content: + text += content + if len(text) > 100: + prompt, text = self.process_sample('', tokenizer, + tokenize), self.process_sample( + text, tokenizer, tokenize) + prompts.append(prompt) + texts.append(text) + return prompts, texts + + +NAMED_CORPORA = { + 'wikipedia': wikipedia, + 'wikipedia-key': KeyReader, + 'openwebtext': OpenWebText, + 'zhihu': zhihu, + 'zhidao': zhidao, + 'baike': baike, + 'test': TestDataset, + 'wikibook': BertData, + 'bert-base': BertBaseData, + 'bert-large': BertLargeData, + 'cc-news': CCNews, + 'pile': Pile, + 'stories': Stories, + 'wudao': WuDaoCorpus +} diff --git a/modelscope/models/nlp/mglm/data_utils/datasets.py b/modelscope/models/nlp/mglm/data_utils/datasets.py new file mode 100644 index 00000000..777b7d43 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/datasets.py @@ -0,0 +1,1244 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""dataset objects for jsons, csvs, and BERT datasets""" + +import csv +import math +import os +import random +import time +from bisect import bisect_right +from itertools import accumulate +from operator import itemgetter + +import json +import nltk +import numpy as np +import pandas as pd +import torch +import tqdm +from nltk import tokenize +from torch.utils import data + +from modelscope.models.nlp.mglm.utils import print_rank_0 +from .lazy_loader import LazyLoader, exists_lazy + + +class ShuffleDataset(data.Dataset): + + def __init__(self, ds): + self.ds = ds + self.shuffle_ids = list(range(len(self.ds))) + random.shuffle(self.shuffle_ids) + self.is_lazy = hasattr(ds, 'is_lazy') and ds.is_lazy + if self.is_lazy: + self.prompt_lens = [ + self.ds.prompt_lens[idx] for idx in self.shuffle_ids + ] + self.text_lens = [ + self.ds.text_lens[idx] for idx in self.shuffle_ids + ] + + def __getitem__(self, idx): + return self.ds[self.shuffle_ids[idx]] + + def __len__(self): + return len(self.ds) + + +class ConcatDataset(data.Dataset): + """ + Dataset to concatenate multiple datasets. + Purpose: useful to assemble different existing datasets, possibly + large-scale datasets as the concatenation operation is done in an + on-the-fly manner. + Arguments: + datasets (sequence): List of datasets to be concatenated. + """ + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) # noqa + r.append(l + s) + s += l + return r + + def __init__(self, datasets, **kwargs): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + self.datasets = list(datasets) + self.is_lazy = sum([ + isinstance(ds, LazyLoader) + or (hasattr(ds, 'is_lazy') and ds.is_lazy) for ds in self.datasets + ]) == len(self.datasets) + self.cumulative_sizes = self.cumsum(self.datasets) + self._X = None + self._Y = None + self._lens = None + + def get_text_len(self, idx): + dataset_idx = bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx].get_text_len(sample_idx) + + def SetTokenizer(self, tokenizer): + for ds in self.datasets: + ds.SetTokenizer(tokenizer) + + def GetTokenizer(self): + return self.datasets[0].GetTokenizer() + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def lens(self): + if self._lens is None: + self._lens = [] + if self.is_lazy: + for data in self.datasets: # noqa + self._lens.extend(data.lens) + else: + for data in self.datasets: # noqa + self._lens.extend([ + len(d['text']) if isinstance(d, dict) else len(d) + for d in data + ]) + return self._lens + + @property + def X(self): + if self._X is None: + self._X = [] + for data in self.datasets: # noqa + self._X.extend(data.X) + return self._X + + @property + def Y(self): + if self._Y is None: + self._Y = [] + for data in self.datasets: # noqa + self._Y.extend(list(data.Y)) + self._Y = np.array(self._Y) + return self._Y + + +class SplitDataset(data.Dataset): + """ + Dataset wrapper to access a subset of another dataset. + Purpose: useful to index into existing datasets, possibly + large-scale datasets as the subindexing operation is done in an + on-the-fly manner. + Arguments: + ds (Dataset or array-like): List of datasets to be subindexed + split_inds (1D array-like): List of indices part of subset + """ + + def __init__(self, ds, split_inds, **kwargs): + self.split_inds = list(split_inds) + self.wrapped_data = ds + self.is_lazy = isinstance(ds, LazyLoader) or (hasattr(ds, 'is_lazy') + and ds.is_lazy) + self._X = None + self._Y = None + + def __len__(self): + return len(self.split_inds) + + def get_text_len(self, idx): + return self.wrapped_data.get_text_len(self.split_inds[idx]) + + def __getitem__(self, index): + return self.wrapped_data[self.split_inds[index]] + + def SetTokenizer(self, tokenizer): + self.wrapped_data.SetTokenizer(tokenizer) + + def GetTokenizer(self): + return self.wrapped_data.GetTokenizer() + + @property + def X(self): + if self._X is None: + self._X = itemgetter(*self.split_inds)(self.wrapped_data.X) + return self._X + + @property + def Y(self): + if self._Y is None: + self._Y = np.array( + itemgetter(*self.split_inds)(self.wrapped_data.Y)) + return self._Y + + def __iter__(self): + for idx in self.split_inds: + yield self.wrapped_data[idx] + + +def split_ds(ds, split=None, shuffle=True, save_splits=None, load_splits=None): + """ + Split a dataset into subsets given proportions of how + much to allocate per split. If a split is 0% returns None for that split. + Purpose: Useful for creating train/val/test splits + Arguments: + ds (Dataset or array-like): Data to be split. + split (1D array-like): proportions to split `ds`. `sum(splits) != 0` + shuffle (boolean): Randomly split dataset. Default: True + save_splits: save split indices to file + load_splits: load split indices from file + """ + if split is None: + split = [.8, .2, .0] + split_sum = sum(split) + if split_sum == 0: + raise Exception('Split cannot sum to 0.') + split = np.array(split) + split /= split_sum + ds_len = len(ds) + inds = np.arange(ds_len) + if shuffle: + rng = np.random.RandomState(1234) + rng.shuffle(inds) + if load_splits is not None: + inds = np.load(load_splits) + assert len(inds) == ds_len + print_rank_0(f'Load split indices from {load_splits}') + elif save_splits is not None: + if torch.distributed.get_rank() == 0: + np.save(save_splits, inds) + print(f'Save split indices to {save_splits}') + start_idx = 0 + residual_idx = 0 + rtn_ds = [None] * len(split) + for i, f in enumerate(split): + if f != 0: + proportion = ds_len * split[i] + residual_idx += proportion % 1 + split_ = int(int(proportion) + residual_idx) + split_inds = inds[start_idx:start_idx + max(split_, 1)] + rtn_ds[i] = SplitDataset(ds, split_inds) + start_idx += split_ + residual_idx %= 1 + return rtn_ds + + +class csv_dataset(data.Dataset): + """ + Class for loading datasets from csv files. + Purpose: Useful for loading data for unsupervised modeling or transfer tasks + Arguments: + path (str): Path to csv file with dataset. + tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None + preprocess_fn (callable): Callable that process a string into desired format. + delim (str): delimiter for csv. Default: ',' + binarize_sent (bool): binarize label values to 0 or 1 if they\'re on a different scale. Default: False + drop_unlabeled (bool): drop rows with unlabelled values. Always fills remaining empty + columns with -1 (regardless if rows are dropped based on value) Default: False + text_key (str): key to get text from csv. Default: 'sentence' + label_key (str): key to get label from json dictionary. Default: 'label' + Attributes: + X (list): all strings from the csv file + Y (np.ndarray): labels to train with + """ + + def __init__(self, + path, + tokenizer=None, + preprocess_fn=None, + delim=',', + binarize_sent=False, + drop_unlabeled=False, + text_key='sentence', + label_key='label', + **kwargs): + self.is_lazy = False + self.preprocess_fn = preprocess_fn + self.SetTokenizer(tokenizer) + self.path = path + self.delim = delim + self.text_key = text_key + self.label_key = label_key + self.drop_unlabeled = drop_unlabeled + + if '.tsv' in self.path: + self.delim = '\t' + + self.X = [] + self.Y = [] + try: + cols = [text_key] + if isinstance(label_key, list): + cols += label_key + else: + cols += [label_key] + data = pd.read_csv( + self.path, sep=self.delim, usecols=cols, encoding='latin-1') + except: # noqa + data = pd.read_csv( + self.path, + sep=self.delim, + usecols=[text_key], + encoding='latin-1') + + data = data.dropna(axis=0) + + self.X = data[text_key].values.tolist() + try: + self.Y = data[label_key].values + except Exception as e: # noqa + self.Y = np.ones(len(self.X)) * -1 + + if binarize_sent: + self.Y = binarize_labels(self.Y, hard=binarize_sent) + + def SetTokenizer(self, tokenizer): + if tokenizer is None: + self.using_tokenizer = False + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self.using_tokenizer = True + self._tokenizer = tokenizer + + def GetTokenizer(self): + return self._tokenizer + + @property + def tokenizer(self): + if self.using_tokenizer: + return self._tokenizer + return None + + def __len__(self): + return len(self.X) + + def __getitem__(self, index): + """process+tokenize string and return string,label,and stringlen""" + x = self.X[index] + if self.tokenizer is not None: + x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn) + elif self.preprocess_fn is not None: + x = self.preprocess_fn(x) + y = self.Y[index] + if isinstance(y, str): + if self.tokenizer is not None: + y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn) + elif self.preprocess_fn is not None: + y = self.preprocess_fn(y) + return {'text': x, 'length': len(x), 'label': y} + + def write(self, writer_gen=None, path=None, skip_header=False): + """ + given a generator of metrics for each of the data points X_i, + write the metrics, text, and labels to a csv file + """ + if path is None: + path = self.path + '.results' + print('generating csv at ' + path) + with open(path, 'w') as csvfile: + c = csv.writer(csvfile, delimiter=self.delim) + if writer_gen is not None: + # if first item of generator is a header of what the metrics mean then write header to csv file + if not skip_header: + header = (self.label_key, ) + tuple( + next(writer_gen)) + (self.text_key, ) + c.writerow(header) + for i, row in enumerate(writer_gen): + row = (self.Y[i], ) + tuple(row) + (self.X[i], ) + c.writerow(row) + else: + c.writerow([self.label_key, self.text_key]) + for row in zip(self.Y, self.X): + c.writerow(row) + + +class json_dataset(data.Dataset): + """ + Class for loading datasets from a json dump. + Purpose: Useful for loading data for unsupervised modeling or transfer tasks + Arguments: + path (str): path to json file with dataset. + tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None + preprocess_fn (callable): callable function that process a string into desired format. + Takes string, maxlen=None, encode=None as arguments. Default: process_str + text_key (str): key to get text from json dictionary. Default: 'sentence' + label_key (str): key to get label from json dictionary. Default: 'label' + Attributes: + all_strs (list): list of all strings from the dataset + all_labels (list): list of all labels from the dataset (if they have it) + """ + + def __init__(self, + path, + tokenizer=None, + preprocess_fn=None, + binarize_sent=False, + text_key='sentence', + label_key='label', + loose_json=False, + **kwargs): + self.is_lazy = False + self.preprocess_fn = preprocess_fn + self.path = path + self.SetTokenizer(tokenizer) + self.X = [] + self.Y = [] + self.text_key = text_key + self.label_key = label_key + self.loose_json = loose_json + + for j in self.load_json_stream(self.path): + s = j[text_key] + self.X.append(s) + self.Y.append(j[label_key]) + + if binarize_sent: + self.Y = binarize_labels(self.Y, hard=binarize_sent) + + def SetTokenizer(self, tokenizer): + if tokenizer is None: + self.using_tokenizer = False + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self.using_tokenizer = True + self._tokenizer = tokenizer + + def GetTokenizer(self): + return self._tokenizer + + @property + def tokenizer(self): + if self.using_tokenizer: + return self._tokenizer + return None + + def __getitem__(self, index): + """gets the index'th string from the dataset""" + x = self.X[index] + if self.tokenizer is not None: + x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn) + elif self.preprocess_fn is not None: + x = self.preprocess_fn(x) + y = self.Y[index] + if isinstance(y, str): + if self.tokenizer is not None: + y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn) + elif self.preprocess_fn is not None: + y = self.preprocess_fn(y) + return {'text': x, 'length': len(x), 'label': y} + + def __len__(self): + return len(self.X) + + def write(self, writer_gen=None, path=None, skip_header=False): + """ + given a generator of metrics for each of the data points X_i, + write the metrics, text, and labels to a json file + """ + if path is None: + path = self.path + '.results' + + if writer_gen is not None: + # if first item of generator is a header of what the metrics mean then write header to csv file + def gen_helper(): + keys = {} + keys[0] = self.label_key + if not skip_header: + for idx, k in enumerate(tuple(next(writer_gen))): + keys[idx + 1] = k + for i, row in enumerate(writer_gen): + if i == 0 and skip_header: + for idx, _ in enumerate(row): + keys[idx + 1] = 'metric_%d' % (idx, ) + j = {} + for idx, v in enumerate((self.Y[i], ) + tuple(row)): + k = keys[idx] + j[k] = v + yield j + else: + + def gen_helper(): + for y in self.Y: + j = {} + j[self.label_key] = y + yield j + + def out_stream(): + for i, j in enumerate(gen_helper()): + j[self.text_key] = self.X[i] + yield j + + self.save_json_stream(path, out_stream()) + + def save_json_stream(self, save_path, json_stream): + if self.loose_json: + with open(save_path, 'w') as f: + for i, j in enumerate(json_stream): + write_string = '' + if i != 0: + write_string = '\n' + write_string += json.dumps(j) + f.write(write_string) + else: + jsons = [j for j in json_stream] + json.dump(jsons, open(save_path, 'w'), separators=(',', ':')) + + def load_json_stream(self, load_path): + if not self.loose_json: + jsons = json.load(open(load_path, 'r')) + generator = iter(jsons) + else: + + def gen_helper(): + with open(load_path, 'r') as f: + for row in f: + yield json.loads(row) + + generator = gen_helper() + + for j in generator: + if self.label_key not in j: + j[self.label_key] = -1 + yield j + + +class XLDataset(data.Dataset): + + def __init__(self, + ds, + tokenizer, + max_seq_len=1024, + mem_len=None, + sample_across_doc=True, + **kwargs): + self.ds = ds + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + if mem_len is None: + mem_len = max_seq_len + self.mem_len = mem_len + self.sample_across_doc = sample_across_doc + self.indices, self.num_samples = None, None + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + self.is_lazy = True + self.init_indices() + + def init_indices(self): + if self.is_lazy: + lens = np.array( + [self.ds.get_text_len(idx) for idx in range(len(self.ds))]) + else: + lens = np.array([ + len(d['prompt']) + + len(d['text']) if isinstance(d, dict) else len(d) + for d in self.ds + ]) + self.indices = list(accumulate(lens)) + print_rank_0( + f'Dataset document count {len(lens)}, token count {self.indices[-1]}' + ) + self.num_samples = self.indices[-1] // self.max_seq_len + 1 + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + tokens, targets, loss_mask, attention_mask = self.getidx(idx) + tokens = self.pad_seq(tokens) + targets = self.pad_seq(targets) + loss_mask = self.pad_seq(loss_mask, pad_id=0) + return { + 'text': np.array(tokens), + 'target': np.array(targets), + 'loss_mask': np.array(loss_mask), + 'attention_mask': np.array(attention_mask) + } + + def getidx(self, idx): + tokens, targets, loss_masks = [], [], [] + attention_mask = np.concatenate( + (np.zeros((self.max_seq_len, self.mem_len), dtype=np.long), + np.ones((self.max_seq_len, self.max_seq_len), dtype=np.long)), + axis=1) + sample_idx = bisect_right(self.indices, idx * self.max_seq_len) + last_end = 0 if sample_idx == 0 else self.indices[sample_idx - 1] + token_offset = idx * self.max_seq_len - last_end + if token_offset != 0: + history = min(self.mem_len, token_offset) + attention_mask[:, + -self.max_seq_len - history:-self.max_seq_len] = 1 + count = 0 + while len(tokens) < self.max_seq_len and sample_idx < len(self.ds): + item = self.ds[sample_idx] + text, masks = item['tokens'], item['loss_masks'] + text = text + [self.tokenizer.get_command('eos').Id] + end = min( + len(text) - 1, token_offset + self.max_seq_len - len(tokens)) + masks = masks + [1] + if count > 0: + current = len(tokens) + attention_mask[current:, :current + self.mem_len] = 0 + tokens += text[token_offset:end] + targets += text[token_offset + 1:end + 1] + loss_masks += masks[token_offset + 1:end + 1] + count += 1 + sample_idx += 1 + token_offset = 0 + return tokens, targets, loss_masks, attention_mask + + def pad_seq(self, seq, pad_id=None): + total_tokens = self.max_seq_len + num_pad_tokens = max(0, total_tokens - len(seq)) + seq += [ + self.tokenizer.get_command('pad').Id if pad_id is None else pad_id + ] * ( + num_pad_tokens) + return seq + + +class BlockDataset(data.Dataset): + + def __init__(self, + ds, + tokenizer, + max_seq_len=1024, + sample_across_doc=True, + non_sentence_start=0.0, + filter_english=False, + **kwargs): + """ + sentence_start: the stripped article must start with a complete sentence + """ + self.ds = ds + self.ds_len = len(self.ds) + self.num_samples = 1000 * self.ds_len + self.max_seq_len = max_seq_len + self.tokenizer = tokenizer + self.sample_across_doc = sample_across_doc + self.non_sentence_start = non_sentence_start + self.filter_english = filter_english + self.weighting, self.total_len = None, None + self.is_lazy = False + if self.filter_english: + import fasttext + self.model = fasttext.load_model('/mnt/lid.176.bin') + print_rank_0('Load language detection model') + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + self.is_lazy = True + self.init_weighting() + + def init_weighting(self): + if self.is_lazy: + lens = np.array( + [self.ds.get_text_len(idx) for idx in range(len(self.ds))]) + else: + lens = np.array([ + len(d['text']) if isinstance(d, dict) else len(d) + for d in self.ds + ]) + self.total_len = np.sum(lens) + print_rank_0( + f'Dataset document count {len(lens)}, token count {self.total_len}, non sentence start{self.non_sentence_start}' # noqa + ) + self.weighting = list(accumulate(lens)) + + def get_weighted_samples(self, np_rng): + while True: + idx = np_rng.randint(self.total_len) + data_idx = bisect_right(self.weighting, idx) + tokens, loss_mask = self.getidx(data_idx) + if self.filter_english: + text = self.tokenizer.DecodeIds(tokens[:1024]) + lang = self.model.predict(text.replace('\n', ''))[0][0] + if lang == '__label__en': + break + else: + break + return tokens, loss_mask + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + # init rng + rng = random.Random(idx) + rng = np.random.RandomState( + seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) + + # get possibly weighted random index from dataset + tokens, loss_mask = self.get_weighted_samples(rng) + # truncate or pad tokens + num_tokens = len(tokens) + tokens_to_strip = num_tokens - self.max_seq_len + 1 + + # randomly choose a position for start + if tokens_to_strip > 0: + move_count = 0 + strip_left_tokens = rng.randint(tokens_to_strip) + if rng.random() > self.non_sentence_start: + if rng.random() < 0.5: + while move_count < self.max_seq_len // 2 and strip_left_tokens > 0 and not self.contains_sentence_end( # noqa + tokens[strip_left_tokens - 1]): # noqa + strip_left_tokens -= 1 + move_count += 1 + else: + while move_count < self.max_seq_len // 2 and strip_left_tokens < len( + tokens) and not self.contains_sentence_end( + tokens[strip_left_tokens - 1]): + strip_left_tokens += 1 + move_count += 1 + tokens = [self.tokenizer.get_command('ENC').Id + ] + tokens[strip_left_tokens:] + loss_mask = [0] + loss_mask[strip_left_tokens:] + if len(tokens) == 2 and tokens[1] == self.tokenizer.get_command( + 'eos').Id: + tokens, loss_mask = [], [] + tokens, loss_mask = self.right_strip_seq(tokens, loss_mask, + self.max_seq_len) + else: + tokens = [self.tokenizer.get_command('ENC').Id] + tokens + loss_mask = [0] + loss_mask + # Sample multiple documents + if self.sample_across_doc: + while len(tokens) < self.max_seq_len: + new_tokens, new_loss_mask = self.get_weighted_samples(rng) + new_tokens = [self.tokenizer.get_command('ENC').Id + ] + new_tokens + new_loss_mask = [0] + new_loss_mask + is_last = len(new_tokens) >= self.max_seq_len - len(tokens) + new_tokens, new_loss_mask = self.right_strip_seq( + new_tokens, new_loss_mask, + self.max_seq_len - len(tokens)) + tokens += new_tokens + loss_mask += new_loss_mask + if is_last: + break + return {'text': np.array(tokens), 'loss_mask': np.array(loss_mask)} + + def right_strip_seq(self, tokens, loss_mask, seq_length): + strip_right_tokens = len(tokens) - seq_length + if strip_right_tokens > 0: + while strip_right_tokens < len( + tokens) - 1 and not self.contains_sentence_end( + tokens[-strip_right_tokens - 1]): + strip_right_tokens += 1 + if len(tokens) - strip_right_tokens < seq_length // 2: + strip_right_tokens = len(tokens) - seq_length + tokens = tokens[:-strip_right_tokens] + loss_mask = loss_mask[:-strip_right_tokens] + return tokens, loss_mask + + def getidx(self, data_idx): + data = self.ds[data_idx] + tokens, loss_masks = data['tokens'], data['loss_masks'] + tokens = tokens + [self.tokenizer.get_command('eos').Id] + loss_masks = loss_masks + [1] + return tokens, loss_masks + + def pad_seq(self, seq, pad_id=None): + total_tokens = self.max_seq_len + num_pad_tokens = max(0, total_tokens - len(seq)) + seq += [ + self.tokenizer.get_command('pad').Id if pad_id is None else pad_id + ] * ( + num_pad_tokens) + return seq + + # TODO: rewrite this function for chinese + def contains_sentence_end(self, tok): + tok = self.tokenizer.IdToToken(tok) + if '.' in tok: + return True + if '?' in tok: + return True + if '!' in tok: + return True + if ';' in tok: + return True + if ':' in tok: + return True + if '\n' in tok: + return True + return False + + +class GPT2Dataset(data.Dataset): + + def __init__(self, + ds, + tokenizer, + max_seq_len=1024, + num_samples=None, + weighted=True, + sample_across_doc=True, + random_across_doc_sampling=True, + sentence_start=False, + **kwargs): + """ + sentence_start: the stripped article must start with a complete sentence + """ + self.ds = ds + self.ds_len = len(self.ds) + self.num_samples = num_samples + if num_samples is None: + self.num_samples = 1000 * self.ds_len + self.max_seq_len = max_seq_len + self.tokenizer = tokenizer + self.weighted = weighted + self.sample_across_doc = sample_across_doc + self.random_across_doc_sampling = random_across_doc_sampling + self.sentence_start = sentence_start + self.weighting, self.total_len = None, None + self.is_lazy = False + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + self.is_lazy = True + self.init_weighting() + + def init_weighting(self): + if self.weighted: + if self.is_lazy: + lens = np.array( + [self.ds.get_text_len(idx) for idx in range(len(self.ds))]) + else: + lens = np.array([ + len(d['text']) if isinstance(d, dict) else len(d) + for d in self.ds + ]) + self.total_len = np.sum(lens) + print_rank_0( + f'Dataset document count {len(lens)}, token count {self.total_len}' + ) + self.weighting = list(accumulate(lens)) + else: + self.weighting = None + + def get_weighted_samples(self, np_rng): + if self.weighting is not None: + idx = np_rng.randint(self.total_len) + return bisect_right(self.weighting, idx) + else: + return np_rng.randint(self.ds_len) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + # init rng + rng = random.Random(idx) + rng = np.random.RandomState( + seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) + + # get possibly weighted random index from dataset + data_idx = self.get_weighted_samples(rng) + # data_idx = rng.choice(self.ds_len, p=self.weighting) + tokens, loss_mask = self.getidx(data_idx) + + # truncate or pad tokens + num_tokens = len(tokens) + tokens_to_strip = num_tokens - self.max_seq_len - 1 + + # randomly choose a position for start + if tokens_to_strip > 0: + strip_left_tokens = rng.randint(tokens_to_strip + 1) + tokens = tokens[strip_left_tokens:] + loss_mask = loss_mask[strip_left_tokens:] + # if self.sentence_start: + # token_copy = list(tokens) + # not_done = True + # while (len(token_copy) > 0) and not_done: + # tok = token_copy.pop(0) + # if self.contains_sentence_end(tok): + # tokens = token_copy + # not_done = False + strip_right_rokens = len(tokens) - self.max_seq_len - 1 + if strip_right_rokens > 0: + tokens = tokens[:-strip_right_rokens] + loss_mask = loss_mask[:-strip_right_rokens] + # Sample multiple documents + if self.sample_across_doc: + while (len(tokens) < (self.max_seq_len + 1)): + if self.random_across_doc_sampling: + data_idx = self.get_weighted_samples(rng) + else: + data_idx = (data_idx + 1) % self.ds_len + new_tokens, new_loss_mask = self.getidx(data_idx) + tokens += new_tokens + loss_mask += new_loss_mask + tokens = tokens[:(self.max_seq_len + 1)] + loss_mask = loss_mask[:(self.max_seq_len + 1)] + + tokens = self.pad_seq(tokens) + loss_mask = self.pad_seq(loss_mask, pad_id=0) + return {'text': np.array(tokens), 'loss_mask': np.array(loss_mask)} + + def getidx(self, data_idx): + data = self.ds[data_idx] + tokens, loss_masks = data['tokens'], data['loss_masks'] + tokens = tokens + [self.tokenizer.get_command('eos').Id] + loss_masks = loss_masks + [1] + return tokens, loss_masks + + def pad_seq(self, seq, pad_id=None): + total_tokens = self.max_seq_len + 1 + num_pad_tokens = max(0, total_tokens - len(seq)) + seq += [ + self.tokenizer.get_command('pad').Id if pad_id is None else pad_id + ] * ( + num_pad_tokens) + return seq + + # TODO: rewrite this function for chinese + def contains_sentence_end(self, tok): + tok = self.tokenizer.IdToToken(tok) + if '.' in tok: + return True + if '?' in tok: + return True + if '!' in tok: + return True + return False + + +class BertSentencepairDataset(data.Dataset): + """ + Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair. + Arguments: + ds (Dataset or array-like): data corpus to use for training + max_seq_len (int): maximum sequence length to use for a sentence pair + mask_lm_prob (float): proportion of tokens to mask for masked LM + max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10 + short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len + dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1) + + """ # noqa + + def __init__(self, + ds, + max_seq_len=512, + mask_lm_prob=.15, + max_preds_per_seq=None, + short_seq_prob=.01, + dataset_size=None, + presplit_sentences=False, + weighted=True, + **kwargs): + self.ds = ds + self.ds_len = len(self.ds) + self.tokenizer = self.ds.GetTokenizer() + self.vocab_words = list(self.tokenizer.text_token_vocab.values()) + self.ds.SetTokenizer(None) + self.max_seq_len = max_seq_len + self.mask_lm_prob = mask_lm_prob + if max_preds_per_seq is None: + max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10 + self.max_preds_per_seq = max_preds_per_seq + self.short_seq_prob = short_seq_prob + self.dataset_size = dataset_size + if self.dataset_size is None: + self.dataset_size = self.ds_len * (self.ds_len - 1) + self.presplit_sentences = presplit_sentences + if not self.presplit_sentences: + nltk.download('punkt', download_dir='./nltk') + self.weighted = weighted + self.get_weighting() + + def get_weighting(self): + if self.weighted: + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + lens = np.array(self.ds.lens) + else: + lens = np.array([ + len(d['text']) if isinstance(d, dict) else len(d) + for d in self.ds + ]) + self.total_len = np.sum(lens) + self.weighting = list(accumulate(lens)) + else: + self.weighting = None + + def get_weighted_samples(self, np_rng): + if self.weighting is not None: + idx = np_rng.randint(self.total_len) + return bisect_right(self.weighting, idx) + else: + return np_rng.randint(self.ds_len) + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + # get rng state corresponding to index (allows deterministic random pair) + rng = random.Random(idx) + np_rng = np.random.RandomState( + seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) + # get seq length + target_seq_length = self.max_seq_len + short_seq = False # noqa + if rng.random() < self.short_seq_prob: + target_seq_length = rng.randint(2, target_seq_length) + short_seq = True # noqa + + # get sentence pair and label + is_random_next = None + lena = 0 + lenb = 0 + while (is_random_next is None) or (lena < 1) or (lenb < 1): + tokensa, tokensb, is_random_next = self.create_random_sentencepair( + target_seq_length, rng, np_rng) + lena = len(tokensa[0]) + lenb = len(tokensb[0]) + + # truncate sentence pair to max_seq_len + tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, + self.max_seq_len, rng) + # join sentence pair, mask, and pad + tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions( + tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, + self.vocab_words, rng) + sample = { + 'text': np.array(tokens[0]), + 'types': np.array(tokens[1]), + 'is_random': int(is_random_next), + 'mask': np.array(mask), + 'mask_labels': np.array(mask_labels), + 'pad_mask': np.array(pad_mask) + } + return sample + + def sentence_split(self, document): + """split document into sentences""" + lines = document.split('\n') + if self.presplit_sentences: + return [line for line in lines if line] + rtn = [] + for line in lines: + if line != '': + rtn.extend(tokenize.sent_tokenize(line)) + return rtn + + def sentence_tokenize(self, + sent, + sentence_num=0, + beginning=False, + ending=False): + """tokenize sentence and get token types""" + tokens = self.tokenizer.EncodeAsIds(sent).tokenization + str_type = 'str' + str(sentence_num) + token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens) + return tokens, token_types + + def get_doc(self, idx): + """gets text of document corresponding to idx""" + rtn = self.ds[idx] + if isinstance(rtn, dict): + rtn = rtn['text'] + return rtn + + def create_random_sentencepair(self, target_seq_length, rng, np_rng): + """ + fetches a random sentencepair corresponding to rng state similar to + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L248-L294 + """ + is_random_next = None + + curr_strs = [] + curr_str_types = [] + curr_len = 0 + + while curr_len < 1: + curr_len = 0 + doc_a = None + while doc_a is None: + if self.weighted: + # doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting) + doc_a_idx = self.get_weighted_samples(np_rng) + else: + doc_a_idx = rng.randint(0, self.ds_len - 1) + doc_a = self.sentence_split(self.get_doc(doc_a_idx)) + if not doc_a: + doc_a = None + + random_start_a = rng.randint(0, len(doc_a) - 1) + while random_start_a < len(doc_a): + sentence = doc_a[random_start_a] + sentence, sentence_types = self.sentence_tokenize( + sentence, 0, random_start_a == 0, + random_start_a == len(doc_a)) + curr_strs.append(sentence) + curr_str_types.append(sentence_types) + curr_len += len(sentence) + if random_start_a == len( + doc_a) - 1 or curr_len >= target_seq_length: + break + random_start_a = (random_start_a + 1) + + if curr_strs: + num_a = 1 + if len(curr_strs) >= 2: + num_a = rng.randint(0, len(curr_strs)) + + tokens_a = [] + token_types_a = [] + for j in range(num_a): + tokens_a.extend(curr_strs[j]) + token_types_a.extend(curr_str_types[j]) + + tokens_b = [] + token_types_b = [] + is_random_next = False + if len(curr_strs) == 1 or rng.random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + b_len = 0 + while b_len < 1: + doc_b = None + while doc_b is None: + doc_b_idx = rng.randint(0, self.ds_len - 2) + doc_b_idx += int(doc_b_idx >= doc_a_idx) + + doc_b = self.sentence_split(self.get_doc(doc_b_idx)) + if not doc_b: + doc_b = None + + random_start_b = rng.randint(0, len(doc_b) - 1) + while random_start_b < len(doc_b): + sentence_b = doc_b[random_start_b] + new_b_tokens, new_b_types = self.sentence_tokenize( + sentence_b, 1, random_start_b == 0, + random_start_b == len(doc_b)) + b_len += len(new_b_tokens) + tokens_b.extend(new_b_tokens) + token_types_b.extend(new_b_types) + if len(tokens_b) >= target_b_length: + break + random_start_b = (random_start_b + 1) + else: + is_random_next = False + for j in range(num_a, len(curr_strs)): + tokens_b.extend(curr_strs[j]) + token_types_b.extend(curr_str_types[j]) + + return (tokens_a, token_types_a), (tokens_b, + token_types_b), is_random_next + + def truncate_seq_pair(self, a, b, max_seq_len, rng): + """ + Truncate sequence pair according to original BERT implementation: + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391 + """ + tokens_a, token_types_a = a + tokens_b, token_types_b = b + max_num_tokens = max_seq_len - 3 + while True: + len_a = len(tokens_a) + len_b = len(tokens_b) + total_length = len_a + len_b + if total_length <= max_num_tokens: + break + if len(tokens_a) > len(tokens_b): + trunc_tokens = tokens_a + trunc_types = token_types_a + else: + trunc_tokens = tokens_b + trunc_types = token_types_b + + assert len(trunc_tokens) >= 1 + + if rng.random() < 0.5: + trunc_tokens.pop(0) + trunc_types.pop(0) + else: + trunc_tokens.pop() + trunc_types.pop() + return (tokens_a, token_types_a), (tokens_b, token_types_b) + + def mask_token(self, idx, tokens, types, vocab_words, rng): + """ + helper function to mask `idx` token from `tokens` according to + section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf + """ + label = tokens[idx] + if rng.random() < 0.8: + new_label = self.tokenizer.get_command('MASK').Id + else: + if rng.random() < 0.5: + new_label = label + else: + new_label = rng.choice(vocab_words) + + tokens[idx] = new_label + + return label + + def pad_seq(self, seq): + """helper function to pad sequence pair""" + num_pad = max(0, self.max_seq_len - len(seq)) + pad_mask = [0] * len(seq) + [1] * num_pad + seq += [self.tokenizer.get_command('pad').Id] * num_pad + return seq, pad_mask + + def create_masked_lm_predictions(self, a, b, mask_lm_prob, + max_preds_per_seq, vocab_words, rng): + """ + Mask sequence pair for BERT training according to: + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L338 + """ + tokens_a, token_types_a = a + tokens_b, token_types_b = b + tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [ + self.tokenizer.get_command('sep').Id + ] + tokens_b + [self.tokenizer.get_command('sep').Id] + token_types = [token_types_a[0]] + token_types_a + [ + token_types_a[0] + ] + token_types_b + [token_types_b[0]] + + len_a = len(tokens_a) + len_b = len(tokens_b) + + cand_indices = [idx + 1 for idx in range(len_a) + ] + [idx + 2 + len_a for idx in range(len_b)] + + rng.shuffle(cand_indices) + + output_tokens, pad_mask = self.pad_seq(list(tokens)) + output_types, _ = self.pad_seq(list(token_types)) + + num_to_predict = min(max_preds_per_seq, + max(1, int(round(len(tokens) * mask_lm_prob)))) + + mask = [0] * len(output_tokens) + mask_labels = [-1] * len(output_tokens) + + for idx in sorted(cand_indices[:num_to_predict]): + mask[idx] = 1 + label = self.mask_token(idx, output_tokens, output_types, + vocab_words, rng) + mask_labels[idx] = label + + return (output_tokens, output_types), mask, mask_labels, pad_mask diff --git a/modelscope/models/nlp/mglm/data_utils/extraction.py b/modelscope/models/nlp/mglm/data_utils/extraction.py new file mode 100644 index 00000000..53027e4f --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/extraction.py @@ -0,0 +1,71 @@ +# Copyright (c) 2022 Zhipu.AI + +import glob +import os + +import json +import nltk + +nltk.download('punkt') + + +class NLTKSegmenter: + + def __init(self): + pass + + @staticmethod + def segment_string(article): + return nltk.tokenize.sent_tokenize(article) + + +wiki_path = 'data/extracted' +output_path = 'formatted/wiki-key.txt' +segmenter = NLTKSegmenter() +with open(output_path, 'w') as output: + for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False): + for filename in glob.glob( + os.path.join(dirname, 'wiki_*'), recursive=True): + print(filename) + article_lines = [] + article_open = False + with open(filename, mode='r', newline='\n') as file: + for line in file: + line = line.rstrip() + if '' in line: + key_sentences, contents = [], [] + key, content = None, [] + for sentences in article_lines[1:]: + if len(sentences) > 1: + if key: + if len(content) > 0 or len(contents) == 0: + key_sentences.append(key) + contents.append(content) + else: + contents[-1].append(key) + key, content = None, [] + key_sentences.append(sentences[0]) + contents.append(sentences[1:]) + elif len(sentences) > 0: + if key: + content.append(sentences[0]) + else: + key = sentences[0] + if key: + if len(content) > 0 or len(contents) == 0: + key_sentences.append(key) + contents.append(content) + else: + contents[-1].append(key) + contents = [' '.join(content) for content in contents] + article = {'key': key_sentences, 'content': contents} + output.write(json.dumps(article)) + output.write('\n') + article_open = False + article_lines = [] + else: + if article_open and line: + sentences = segmenter.segment_string(line) + article_lines.append(sentences) diff --git a/modelscope/models/nlp/mglm/data_utils/file_utils.py b/modelscope/models/nlp/mglm/data_utils/file_utils.py new file mode 100755 index 00000000..794e127a --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/file_utils.py @@ -0,0 +1,256 @@ +# Modified by Zhipu.AI +# This file is provided as is from: +# https://github.com/huggingface/pytorch-pretrained-BERT +# Please refer to their repository for copyright. +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging +import os +import shutil +import sys +import tempfile +from functools import wraps +from hashlib import sha256 +from io import open +from urllib.parse import urlparse + +import boto3 +import json +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path( + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( + 'PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser('~'), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError('file {} not found'.format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError('file {} not found'.format(meta_path)) + + with open(meta_path, encoding='utf-8') as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError('file {} not found'.format(url_or_filename)) + else: + # Something unknown + raise ValueError( + 'unable to parse {} as a URL or as a local path'.format( + url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError('bad s3 path {}'.format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith('/'): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response['Error']['Code']) == 404: + raise EnvironmentError('file {} not found'.format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource('s3') + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource('s3') + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit='B', total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith('s3://'): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError( + 'HEAD request failed for url {} with status code {}'.format( + url, response.status_code)) + etag = response.headers.get('ETag') + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info('%s not found in cache, downloading to %s', url, + temp_file.name) + + # GET file object + if url.startswith('s3://'): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info('copying %s to cache at %s', temp_file.name, + cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info('creating metadata file for %s', cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding='utf-8') as meta_file: + json.dump(meta, meta_file) + + logger.info('removing temp file %s', temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/modelscope/models/nlp/mglm/data_utils/lazy_loader.py b/modelscope/models/nlp/mglm/data_utils/lazy_loader.py new file mode 100644 index 00000000..77a77a8a --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/lazy_loader.py @@ -0,0 +1,286 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""utils for loading text from disk""" +import mmap +import os +import pickle as pkl +import time +from itertools import accumulate + +import numpy as np +import torch +from torch.multiprocessing import Lock + + +def get_lazy_path(path): + """ + Gets directory path where lazy files are stored. + """ + return os.path.splitext(path)[0] + '.lazy' + + +def exists_lazy(path, data_type='data'): + """ + Check if we've already made a lazy version of this file for the `data_type` field. + """ + if not os.path.exists(get_lazy_path(path)): + return False + contents = os.listdir(get_lazy_path(path)) + if data_type not in contents: + return False + if data_type + '.len.pkl' not in contents: + return False + return True + + +def get_scatter_path(path, scatter_rank): + path = os.path.splitext(path)[0] + '.scatter' + scatter_path = os.path.join(path, str(scatter_rank)) + return scatter_path + + +def exists_scatter(path, scatter_num=64, data_type='data'): + for i in range(scatter_num): + scatter_path = get_scatter_path(path, scatter_rank=i) + if not exists_lazy(scatter_path, data_type=data_type): + return False + return True + + +class LazyWriter: + + def __init__(self, + path, + data_type, + is_array=False, + array_data_type=np.int32): + lazypath = get_lazy_path(path) + if not os.path.exists(lazypath): + os.makedirs(lazypath) + self.datapath = os.path.join(lazypath, data_type) + self.lenpath = os.path.join(lazypath, data_type + '.len.pkl') + self.array_data_type = array_data_type + self.output = open(self.datapath, 'wb') + self.lengths = [] + self.is_array = is_array + + @staticmethod + def get_len_path(path, data_type): + lazypath = get_lazy_path(path) + return os.path.join(lazypath, data_type + '.len.pkl') + + def write(self, s): + if isinstance(s, dict): + s = s['text'] + if self.is_array: + encoded = np.array( + s, dtype=self.array_data_type).tobytes(order='C') + self.output.write(encoded) + self.lengths.append(len(s)) + else: + encoded = s.encode('utf-8') + self.output.write(encoded) + self.lengths.append(len(encoded)) + + def close(self): + self.output.close() + with open(self.lenpath, 'wb') as f: + pkl.dump(self.lengths, f) + + +def split_strings(strings, start, chr_lens): + """ + Split strings based on string lengths and given start. + """ + return [ + strings[i - start:j - start] + for i, j in zip([start] + chr_lens[:-1], chr_lens) + ] + + +class ProcessorTokenizer: + """ + callable class that runs a preprocessing, as well as tokenization step, + on input text. + """ + + def __init__(self, tokenizer, process_fn=None): + self.tokenizer = tokenizer + self.process_fn = process_fn + + def __call__(self, string): + if self.tokenizer is not None: + string = self.tokenizer(string, process_fn=self.process_fn) + elif self.process_fn is not None: + string = self.process_fn(string) + return string + + +class LazyLoader(object): + """ + Arguments: + path: path to directory where array entries are concatenated into one big string file + and the .len file are located + data_type (str): Some datsets have multiple fields that are stored in different paths. + `data_type` specifies which of these fields to load in this class + mem_map (boolean): Specifies whether to memory map file `path` + map_fn (callable): Fetched strings are passed through map_fn before being returned. + + Example of lazy loader directory structure: + file.json + file.lazy/ + data_type1 + data_type1.len.pkl + data_type2 + data_type2.len.pkl + """ + + def __init__(self, + path, + data_type='data', + mem_map=False, + map_fn=None, + is_array=False, + array_data_type=np.int32, + load_memory=False, + half_load=False): + lazypath = get_lazy_path(path) + datapath = os.path.join(lazypath, data_type) + # get file where array entries are concatenated into one big string + self._file = open(datapath, 'rb') + self.file = self._file + self.is_array = is_array + self.array_data_type = array_data_type + # memory map file if necessary + lenpath = os.path.join(lazypath, data_type + '.len.pkl') + self.lens = pkl.load(open(lenpath, 'rb')) + if half_load: + self.lens = self.lens[:2 * len(self.lens) // 3] + self.ends = list(accumulate(self.lens)) + self.dumb_ends = list(self.ends) + self.mem_map = mem_map + self.load_memory = load_memory + if self.load_memory: + data_type_size = np.dtype(self.array_data_type).itemsize + if half_load: + self.file = self.file.read(sum(self.lens) * data_type_size) + else: + self.file = self.file.read() + self.file = np.ndarray( + shape=(len(self.file) // data_type_size, ), + dtype=array_data_type, + buffer=self.file, + order='C') + elif self.mem_map: + if is_array: + if self.ends[-1] == 0: + self.file = np.array([], dtype=array_data_type) + else: + self.file = np.memmap( + self.file, dtype=array_data_type, mode='r', order='C') + else: + if self.ends[-1] == 0: + self.file = bytearray() + else: + self.file = mmap.mmap( + self.file.fileno(), 0, prot=mmap.PROT_READ) + self.read_lock = Lock() + self.process_fn = map_fn + self.map_fn = map_fn + self._tokenizer = None + self.is_lazy = True + + def SetTokenizer(self, tokenizer): + """ + logic to set and remove (set to None) tokenizer. + combines preprocessing/tokenization into one callable. + """ + if tokenizer is None: + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self._tokenizer = tokenizer + self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) + + def GetTokenizer(self): + return self._tokenizer + + def __getitem__(self, index): + """ + read file and splice strings based on string ending array `self.ends` + """ + if not isinstance(index, slice): + if index == 0: + start = 0 + else: + start = self.ends[index - 1] + end = self.ends[index] + rtn = self.file_read(start, end) + if self.map_fn is not None: + rtn = self.map_fn(rtn) + else: + # if slice, fetch strings with 1 diskread and then splice in memory + chr_lens = self.ends[index] + if index.start == 0 or index.start is None: + start = 0 + else: + start = self.ends[index.start - 1] + stop = chr_lens[-1] + strings = self.file_read(start, stop) + rtn = split_strings(strings, start, chr_lens) + if self.map_fn is not None: + rtn = [self.map_fn(s) for s in rtn] + return rtn + + def __len__(self): + return len(self.ends) + + def file_read(self, start=0, end=None): + """read specified portion of file""" + data_type_size = np.dtype(self.array_data_type).itemsize + # atomic reads to avoid race conditions with multiprocess dataloader + self.read_lock.acquire() + if not self.mem_map and not self.load_memory: + # seek to start of file read + if self.is_array: + start = start * data_type_size + end = end * data_type_size if end is not None else None + self.file.seek(start) + # read to end of file if no end point provided + if end is None: + rtn = self.file.read() + # else read amount needed to reach end point + else: + rtn = self.file.read(end - start) + if self.is_array: + rtn = np.ndarray( + shape=(len(rtn) // data_type_size, ), + dtype=self.array_data_type, + buffer=rtn, + order='C') + else: + rtn = rtn.decode('utf-8', 'ignore') + else: + rtn = self.file[start:end] + if self.is_array: + rtn = rtn.copy() + else: + rtn = rtn.decode('utf-8', 'strict') + self.read_lock.release() + # TODO: @raulp figure out mem map byte string bug + # if mem map'd need to decode byte string to string + # # rtn = str(rtn) + # if self.mem_map: + # rtn = rtn.decode('unicode_escape') + return rtn diff --git a/modelscope/models/nlp/mglm/data_utils/samplers.py b/modelscope/models/nlp/mglm/data_utils/samplers.py new file mode 100644 index 00000000..c0f6e1ab --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/samplers.py @@ -0,0 +1,190 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""batch samplers that work with either random or sequential data samplers""" +import math +import os +import sys + +import numpy as np +import torch +from torch.utils import data + + +class RandomSampler(data.sampler.Sampler): + r""" + Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, + but this class lets the user set an epoch like DistributedSampler + Samples elements randomly. If without replacement, then sample from a shuffled dataset. + If with replacement, then user can specify ``num_samples`` to draw. + Arguments: + data_source (Dataset): dataset to sample from + num_samples (int): number of samples to draw, default=len(dataset) + replacement (bool): samples are drawn with replacement if ``True``, default=False + """ + + def __init__(self, data_source, replacement=False, num_samples=None): + super(RandomSampler, self).__init__(data_source) + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.epoch = -1 + + if self._num_samples is not None and replacement is False: + raise ValueError( + 'With replacement=False, num_samples should not be specified, ' + 'since a random permute will be performed.') + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError('num_samples should be a positive integer ' + 'value, but got num_samples={}'.format( + self.num_samples)) + if not isinstance(self.replacement, bool): + raise ValueError('replacement should be a boolean value, but got ' + 'replacement={}'.format(self.replacement)) + + @property + def num_samples(self): + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + g = torch.Generator() + if self.epoch >= 0: + g.manual_seed(self.epoch) + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint( + high=n, size=(32, ), dtype=torch.int64, + generator=g).tolist() + yield from torch.randint( + high=n, + size=(self.num_samples % 32, ), + dtype=torch.int64, + generator=g).tolist() + else: + yield from torch.randperm(n, generator=self.generator).tolist() + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class DistributedSequentialSampler(data.sampler.Sampler): + + def __init__(self, + num_samples, + train_iters, + batch_size, + rank=-1, + world_size=2): + super().__init__(num_samples) + if rank == -1: + rank = 0 + world_size = 1 + self.num_samples = num_samples + self.rank = rank + self.world_size = world_size + self.start_iter = 0 + self.train_iters = train_iters + self.batch_size = batch_size + self.batch_bias = [ + i * (num_samples // batch_size) for i in range(batch_size) + ] + + def __iter__(self): + for idx in range(self.start_iter, self.train_iters * 10): + batch = [(idx + bias) % self.num_samples + for bias in self.batch_bias] + tbatch = self._batch(batch) + yield tbatch + + def __len__(self): + return self.train_iters + + def _batch(self, batch): + """extracts samples only pertaining to this worker's batch""" + start = self.rank * self.batch_size // self.world_size + end = (self.rank + 1) * self.batch_size // self.world_size + return batch[start:end] + + +class DistributedBatchSampler(data.sampler.BatchSampler): + """ + similar to normal implementation of distributed sampler, except implementation is at the + batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary + data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. + """ + + def __init__(self, + sampler, + batch_size, + drop_last, + rank=-1, + world_size=2, + wrap_last=False, + gradient_accumulation_steps=None): + super(DistributedBatchSampler, self).__init__(sampler, batch_size, + drop_last) + if rank == -1: + assert False, 'should not be here' + self.rank = rank + self.world_size = world_size + self.sampler.wrap_around = 0 + self.wrap_around = 0 + self.wrap_last = wrap_last + self.start_iter = 0 + self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps # noqa + + def __iter__(self): + batch = [] + i = 0 + for idx in self.data_iterator(self.sampler, wrap_around=False): + batch.append(idx) + if len(batch) == self.batch_size: + tbatch = self._batch(batch) + if i >= self.start_iter * self.effective_batch_size: + yield tbatch + self.start_iter = 0 + i += len(batch) + batch = [] + batch_len = len(batch) + if batch_len > 0 and not self.drop_last: + if self.wrap_last: + self.sampler.wrap_around -= (self.batch_size) + self.wrap_around += (len(batch)) + self.wrap_around %= self.batch_size + yield self._batch(batch) + if self.wrap_last: + self.sampler.wrap_around += self.batch_size + + def data_iterator(self, _iter, wrap_around=False): + """iterates through data and handles wrap around""" + for i, idx in enumerate(_iter): + if i < self.wrap_around % self.batch_size: + continue + if wrap_around: + self.wrap_around += 1 + self.wrap_around %= self.batch_size + yield idx + + def _batch(self, batch): + """extracts samples only pertaining to this worker's batch""" + start = self.rank * self.batch_size // self.world_size + end = (self.rank + 1) * self.batch_size // self.world_size + return batch[start:end] diff --git a/modelscope/models/nlp/mglm/data_utils/sp_tokenizer.py b/modelscope/models/nlp/mglm/data_utils/sp_tokenizer.py new file mode 100644 index 00000000..b4d1afe3 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/sp_tokenizer.py @@ -0,0 +1,158 @@ +# Modified by Zhipu.AI +""" +from https://github.com/openai/gpt-2/, changed for chinese +""" +import os # yapf: disable + + +""" +SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation +systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements +subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the +extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end +system that does not depend on language-specific pre/postprocessing. +https://github.com/google/sentencepiece + +pip install sentencepiece + +or git clone https://github.com/google/sentencepiece.git +python setup.py install + +""" + + +def get_pairs(word): + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + + def __init__(self, encoder, bpe_merges): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.max_len = 0 + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + if not pairs: + return token + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: # noqa + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + return [self.encoder.get(token, 1) for token in self.tokenize(text)] + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + return text + + def tokenize(self, text): + bpe_tokens = [] + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder.get(token, 1) for token in tokens] + + +class Encoder_SP: + + def __init__(self, model_path): + import sentencepiece as spm + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + + def encode(self, text): + """ + text="...." + """ + return self.sp.EncodeAsIds(text) + + def decode(self, tokens): + """ + tokens=[x1,x2,...] + """ + text = [int(token) for token in tokens] + # print(text) + return self.sp.DecodeIds(text) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + +def get_encoder(encoder_file, bpe_file): + import json + filepath, filename = os.path.split(encoder_file) + shotname, extension = os.path.splitext(filename) + + if ('.model' == extension) and (bpe_file == ''): + return Encoder_SP(encoder_file) + else: + with open(encoder_file, 'r', encoding='utf-8') as f: + encoder = json.load(f) + with open(bpe_file, 'r', encoding='utf-8') as f: + bpe_data = f.read() + bpe_merges = [ + tuple(merge_str.split()) + for merge_str in bpe_data.split('\n')[1:-1] + ] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) + + +def from_pretrained(model_path): + return get_encoder(model_path + '/tokenizer/mglm250k/mglm250k-uni.model', + '') diff --git a/modelscope/models/nlp/mglm/data_utils/tokenization.py b/modelscope/models/nlp/mglm/data_utils/tokenization.py new file mode 100755 index 00000000..c89cc371 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/tokenization.py @@ -0,0 +1,1396 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)""" +import csv +import itertools +import os +import random +from collections import namedtuple + +import nltk +import regex as re +import sentencepiece as spm +import torch +from nltk import tokenize as nltk_tokenize + +from . import sp_tokenizer +from .tokenization_gpt2 import GPT2Tokenizer +from .wordpiece import PRETRAINED_VOCAB_ARCHIVE_MAP, BertTokenizer + + +def make_tokenizer(tokenizer_type, + corpus, + model_path=None, + vocab_size=None, + model_type=None, + pad_token=0, + character_coverage=1.0, + command_tokens=None, + type_tokens=None, + **kwargs): + """ + Helper function to instantiate a tokenizer given common combinations of options. + """ + tokenizer_class = tokenizer_type + if isinstance(tokenizer_class, str): + tokenizer_class = eval(tokenizer_class) + if tokenizer_class is BertWordPieceTokenizer: + return BertWordPieceTokenizer(model_type, **kwargs) + elif tokenizer_class is GPT2BPETokenizer: + if model_type is None: + model_type = 'gpt2' + return GPT2BPETokenizer(model_type, **kwargs) + elif tokenizer_class is ChineseSPTokenizer: + return ChineseSPTokenizer(model_path, **kwargs) + text_tokenizer = tokenizer_class( + corpus=corpus, + vocab_size=vocab_size, + model_path=model_path, + model_type=model_type, + pad_token=pad_token, + character_coverage=character_coverage) + return Tokenizer(text_tokenizer, command_tokens, type_tokens) + + +class Tokenization(object): + """ + Tokenization object to hold tokenization, (processed text),and original + text. Can hold tokenization as Ids or tokens. + + It also holds command tokens (pad, unk, etc.) for the tokenization. + This allows functions to pad/operate on tokenizations without having + access to the full tokenizer, just the tokenization. + + Several standard array operations are implemented (insert, append, extend). + """ + + def __init__(self, + tokenization, + text=None, + original_text=None, + command_tokens=None, + asIds=True): + self.tokenization = tokenization + self.text = text + if self.text is None: + self.text = self.tokenization + self.original_text = original_text + if self.original_text is None: + self.original_text = self.text + self.command_tokens = command_tokens + self.asIds = asIds + self.parse_command_tokens() + + def set_command_tokens(self, command_tokens): + self.command_tokens = command_tokens + return self.parse_command_tokens() + + def parse_command_tokens(self): + if self.command_tokens is None: + return + for command_token in self.command_tokens: + if self.asIds: + setattr(self, command_token.name, command_token.Id) + else: + setattr(self, command_token.name, command_token.token) + + def __getitem__(self, index): + return self.tokenization[index] + + def __len__(self): + return len(self.tokenization) + + def insert(self, idx, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.insert(idx, other.Id) + if idx == 0: + self.text = other.token + self.text + self.original_text = other.token + self.original_text + elif idx == len(self.tokenization) - 1: + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization = self.tokenization[: + idx] + other.tokenization + self.tokenization[ + idx:] + else: + self.tokenization = self.tokenization[: + idx] + other.tokenization + self.tokenization[ + idx:] + + def append(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.append(other) + return self + + def extend(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, list) and isinstance(other[0], + (CommandToken, TypeToken)): + self.tokenization.extend([o.Id for o in other]) + self.text += [o.token for o in other] + self.original_text += [o.token for o in other] + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.extend(other) + return self + + +"""define some default command tokens for the tokenizer to use""" +token_format = '<{0}>' + +COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id')) + + +def prep_command_tokens(tokenlist, token_format=token_format): + return [ + CommandToken(tok[0], token_format.format(tok[0]), tok[1]) + for tok in tokenlist + ] + + +class CommandToken(object): + + def __init__(self, name, token, Id, lstrip=False, rstrip=False): + self.name = name + self.token = token + self.Id = Id + self.lstrip = lstrip + self.rstrip = rstrip + + def __str__(self): + return str(COMMAND_TUPLE(self.name, self.token, self.Id)) + + +DEFAULT_COMMAND_TOKENS = [ + ('pad', 0), + ('eos', 1), + ('bos', 2), + ('unk', 3), + ('sep', 4), + ('L2R', 5), + ('ENC', 6), + ('MASK', 7), +] +DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS) +"""define some default type tokens for bert training""" + +TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id')) + + +def prep_type_tokens(tokenlist, token_format=token_format): + return [ + TypeToken(tok[0], token_format.format(tok[0]), tok[1]) + for tok in tokenlist + ] + + +class TypeToken(object): + + def __init__(self, name, token, Id): + self.name = name + self.token = token + self.Id = Id + + def __str__(self): + return str(TYPE_TUPLE(self.name, self.token, self.Id)) + + +DEFAULT_TYPE_TOKENS = [ + ('function', 0), + ('command', 1), + ('str0', 2), + ('str1', 3), + ('str2', 4), + ('embedding0', 5), + ('embedding1', 6), + ('embedding2', 7), + ('arg0', 8), + ('arg1', 9), + ('arg2', 10), +] +DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS) + + +class Tokenizer(object): + """ + Tokenizer object that handles text tokenization, command tokens, and type tokens. + + Command tokens and text tokens are stored together in one mapping of size + `len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first + `len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`. + + Token types are stored in a separate mapping of size `len(type_tokens)`. + """ + + def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None): + # set text tokenizer + self.text_tokenizer = text_tokenizer + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = len(self.text_tokenizer) + + # set command tokens + if command_tokens is None: + command_tokens = DEFAULT_COMMAND_TOKENS + self._command_tokens = command_tokens + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = { + tok.token: tok + for tok in self._command_tokens + } + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + if not hasattr(self, 'num_command_tokens'): + self.num_command_tokens = len(self._command_tokens) + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_command_tokens + self.num_text_tokens + + # set type tokens + if type_tokens is None: + type_tokens = DEFAULT_TYPE_TOKENS + self.type_tokens = type_tokens + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + if not hasattr(self, 'num_type_tokens'): + self.num_type_tokens = len(self.type_tokens) + + # parse tokens and vocabs from tokenizer + self._tokens = list(self.command_token_map.keys()) + list( + self.text_tokenizer.tokens) + self._vocab = {t: Id for Id, t in self.command_id_map.items()} + self._vocab.update({ + t: Id + self.num_command_tokens + for t, Id in self.text_tokenizer.vocab.items() + }) + + self._text_tokens = list(self.text_tokenizer.tokens) + self._text_token_vocab = { + t: Id + self.num_command_tokens + for t, Id in self.text_tokenizer.vocab.items() + } + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = { + t: Id + for Id, t in self.command_id_map.items() + } + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def __call__(self, text, process_fn=None): + """run preprocessing and encode text as Ids""" + return self.EncodeAsIds(text, process_fn=process_fn) + + def __len__(self): + """total number of tokens""" + return self.num_tokens + + def get_command(self, name): + """get command token corresponding to `name`""" + return self.command_name_map[name] + + def get_type(self, name): + """get type token corresponding to `name`""" + return self.type_name_map[name] + + @property + def tokens(self): + """list (or iterable) of all tokens for tokenizer""" + return self._tokens + + @property + def vocab(self): + """dictionary mapping tokens to ids for tokenizer""" + return self._vocab + + @property + def token_types(self): + """list (or iterable) of all token types for tokenizer""" + return self._token_types + + @property + def token_type_vocab(self): + """dictionary mapping token types to ids for tokenizer""" + return self._token_type_vocab + + @property + def command_tokens(self): + """list (or iterable) of all command tokens for tokenizer""" + return self._command_token_tokens + + @property + def command_token_vocab(self): + """dictionary mapping command tokens to ids for tokenizer""" + return self._command_token_vocab + + @property + def text_tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + return self._text_tokens + + @property + def text_token_vocab(self): + """dictionary mapping text tokens to ids for text tokenizer""" + return self._text_token_vocab + + def EncodeAsIds(self, text, process_fn=None): + """ + encode text using text tokenizer and shift Id values for command tokens + """ + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + + def split_on_token(tok_extended: CommandToken, text): + result = [] + tok = tok_extended.token + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # CommandToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + # Strip white spaces on the right + if tok_extended.rstrip and i > 0: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + sub_text = sub_text.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and i < len(split_text) - 1: + sub_text = sub_text.rstrip() # Opposite here + + if i == 0 and not sub_text: + result.append(tok) + elif i == len(split_text) - 1: + if sub_text: + result.append(sub_text) + else: + pass + else: + if sub_text: + result.append(sub_text) + result.append(tok) + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + if not tok_list: + return self.text_tokenizer.encode(text) + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self._command_token_tokens: + tokenized_text.extend(split_on_token(tok, sub_text)) + else: + tokenized_text.append(sub_text) + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + (self._encode(token) + if token not in self._command_token_tokens else + [self.command_token_map[token].Id] + for token in tokenized_text))) + + no_split_tokens = self._command_tokens + Ids = split_on_tokens(no_split_tokens, processed_text) + tokenization = Tokenization(Ids, processed_text, text) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def _encode(self, text): + raise NotImplementedError + + def EncodeAsTokens(self, text, process_fn=None): + """ + encode text as tokens using text tokenizer + """ + tokenization = self.text_tokenizer.EncodeAsTokens( + text, process_fn=process_fn) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def IdToToken(self, Id, type_token=False): + """convert Id to token accounting for command and type tokens""" + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id < self.num_command_tokens: + return self.command_id_map[Id].token + return self.text_tokenizer.IdToToken(Id - self.num_command_tokens) + + def TokenToId(self, token, type_token=False): + """convert token to Id accounting for command and type tokens""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + if token in self.command_token_map: + return self.command_token_map[token].Id + return self.text_tokenizer.TokenToId(token) + self.num_command_tokens + + def DecodeIds(self, Ids, type_token=False): + """ + convert Ids to tokens accounting for command and type tokens, tokens + are joined and returned as a string. + """ + if type_token: + return ' '.join( + Id.token if isinstance(Id, TypeToken) else self. + type_id_map[Id].token for Id in Ids) + rtn_strs = [] + current_str = [] + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + for Id in Ids: + if isinstance(Id, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(Id.token) + elif Id < self.num_command_tokens: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(self.command_id_map[Id].token) + else: + current_str.append(Id - self.num_command_tokens) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + return ' '.join(rtn_strs) + + def DecodeTokens(self, Tokens, type_token=False): + """ + convert tokens to a string accounting for command and type tokens. + """ + if type_token: + return ' '.join( + t.token if isinstance(t, TypeToken) else t for t in Tokens) + rtn_strs = [] + current_str = [] + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + for t in Tokens: + if isinstance(t, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t.token) + elif t in self.command_token_map: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t) + else: + current_str.append(t) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + return ' '.join(rtn_strs) + + +class TextTokenizer(object): + """ + Interface for text tokenizer + """ + + def __init__(self): + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = 0 + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_text_tokens + + def __call__(self, text, process_fn=None): + return self.EncodeAsIds(text, process_fn) + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + raise NotImplementedError( + 'TextTokenizer tokens property not implemented') + + @property + def vocab(self): + """dictionary mapping tokens to ids""" + raise NotImplementedError( + 'TextTokenizer vocab property not implemented') + + @staticmethod + def exists(model_path): + """check if the filepath for a text tokenizer exists""" + raise NotImplementedError( + 'TextTokenizer exists method not implemented') + + def Train(self, corpus): + """train a tokenizer on a data corpus and save model for future use""" + raise NotImplementedError('TextTokenizer Train not implemented') + + def EncodeAsIds(self, text, process_fn=None): + """ + Preprocess text and encode as ids. Return a tokenization object with + original text, processed text, and id tokenization. + """ + raise NotImplementedError('TextTokenizer EncodeAsIds not implemented') + + def EncodeAsTokens(self, text, process_fn=None): + """ + Preprocess text and encode as tokens. Return a tokenization object with + original text, processed text, and token tokenization. + """ + raise NotImplementedError( + 'TextTokenizer EncodeAsTokens not implemented') + + def IdToToken(self, Id): + """Convert an Id to Token. Reverse lookup of self.vocab""" + raise NotImplementedError('TextTokenizer IdToToken not implemented') + + def TokenToId(self, token): + """Convert a Token to Id. Lookup of self.vocab""" + raise NotImplementedError('TextTokenizer TokenToId not implemented') + + def DecodeIds(self, Ids): + """Convert a list or tokenization object of Ids to a text string""" + raise NotImplementedError('TextTokenizer DecodeIds not implemented') + + def DecodeTokens(self, Tokens): + """Convert a list or tokenization object of tokens to a text string""" + raise NotImplementedError('TextTokenizer DecodeTokens not implemented') + + +class CharacterLevelTokenizer(TextTokenizer): + """ + Text tokenizer for ASCII-256 Character Level Tokenization. + """ + + def __init__(self, **kwargs): + self.num_text_tokens = 256 + super(CharacterLevelTokenizer, self).__init__() + self._tokens = [ + self.IdToToken(Id) for Id in range(self.num_text_tokens) + ] + self._vocab = {t: i for i, t in enumerate(self._tokens)} + + def __len__(self): + return 256 + + @staticmethod + def exists(model_path): + return True + + def Train(self, corpus): + pass + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + def EncodeAsIds(self, text, process_fn=None): + """convert text to ascii 256 Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [self.TokenToId(c) for c in processed_text] + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to ascii 256 characters""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [c for c in processed_text] + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """ascii index to character""" + return chr(Id) + + def TokenToId(self, token): + """ascii character to index""" + return ord(token) + + def DecodeIds(self, Ids): + """converts ascii ids to tokens before joining them into text""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return ''.join([self.IdToToken(tok) for tok in Ids]) + + def DecodeTokens(self, Tokens): + """just concatenates ascii tokens into text""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ''.join(Tokens) + + +MAX_SENTENCEPIECE_SENTENCES = 100000000 + + +def get_corpus_freq(dataset, filepath, filetype='tsv'): + """ + Take corpus, split it into sentences, and extract word frequencies. + Write frequencies to `filepath` as a tsv. Only write the first + MAX_SENTENCEPIECE_SENTENCES most common words to the file. + """ + nltk.download('punkt', download_dir='./nltk') + if filetype == 'tsv': + delimiter = '\t' + else: + delimiter = ',' + + print('compute corpus frequency\n', flush=True) + + total_sentence_count = 0 + maxlen = 0 + freqs = {} + for entry in dataset: + if isinstance(entry, dict): + entry = entry['text'] + lines = entry.strip().split('\n') + for line in lines: + sentences = nltk_tokenize.sent_tokenize(line) + total_sentence_count += len(sentences) + for sentence in sentences: + maxlen = max(len(line), maxlen) + for word in sentence.split(): + if word not in freqs: + freqs[word] = 0 + freqs[word] += 1 + + print('length of freqs before truncating ' + str(len(freqs)), flush=True) + print('file path for freq ' + str(filepath), flush=True) + + freqs_sorted = {} + counter = 0 + for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True): + if counter >= MAX_SENTENCEPIECE_SENTENCES: + break + counter += 1 + freqs_sorted[word] = count + + print( + 'length of freqs after trancating ' + str(len(freqs_sorted)), + flush=True) + + with open(filepath, 'w') as f: + writer = csv.writer(f, delimiter=delimiter) + for k, v in freqs_sorted.items(): + writer.writerow([str(k), str(v)]) + + return total_sentence_count, maxlen + + +class SentencePieceTokenizer(TextTokenizer): + """Trains and uses sentencepiece for text tokenization""" + + def __init__(self, + model_type='bpe', + vocab_size=None, + corpus=None, + model_path=None, + character_coverage=1.0, + **kwargs): + self.character_coverage = character_coverage + self.model_type = model_type.lower() + self.spm_model = model_path + self.num_text_tokens = vocab_size + make_train = not SentencePieceTokenizer.exists(self.spm_model) + if make_train: + assert corpus is not None and self.num_text_tokens is not None + self.Train(corpus, self.num_text_tokens) + self._tokens = [] + self._vocab = {} + self.load_spm_model() + super(SentencePieceTokenizer, self).__init__() + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + @staticmethod + def exists(model_path): + if model_path is None: + return False + # check if path exists + dne = not os.path.exists(model_path) + # check if path.model exists + if dne and not model_path.endswith('.model'): + dne = not os.path.exists(model_path + '.model') + return not dne + + def load_spm_model(self): + """load sentencepiece model and parse vocab""" + if not os.path.exists( + self.spm_model) and not self.spm_model.endswith('.model'): + self.spm_model = self.spm_model + '.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(self.spm_model) + self.vocab_size = self.num_text_tokens = len(self.sp) + self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)] + self._vocab = {t: i for i, t in enumerate(self._tokens)} + + def Train(self, corpus, num_text_tokens): + """train sentencepiece model on corpus using word frequencies""" + self.num_text_tokens = num_text_tokens + use_model_path = self.spm_model + random_hash = str(random.randint(0, 2147483647)) + if use_model_path is None: + use_model_path = random_hash + if use_model_path.endswith('.model'): + use_model_path = use_model_path[:use_model_path.rfind('.model')] + input_path = use_model_path + '.tsv.' + random_hash + line_count, maxlenline = get_corpus_freq(corpus, input_path) + line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES) + print( + 'line count used as input_sentence_size ', line_count, flush=True) + print('training sentencepiece model', flush=True) + train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \ + + ' --model_type={model_type} --character_coverage={character_coverage} ' \ + + '--input_sentence_size={input_sentence_size} ' \ + + '--input_format=tsv' + train_string = train_string.format( + file_path=input_path, + model_prefix=use_model_path, + vocab_size=num_text_tokens, + model_type=self.model_type, + character_coverage=self.character_coverage, + input_sentence_size=int(line_count)) # , #)#, + print( + 'calling spm.SentencePieceTrainer.Train(%s)' % (train_string), + flush=True) + spm.SentencePieceTrainer.Train(train_string) + os.remove(input_path) + self.spm_model = use_model_path + '.model' + print('sentencepiece model written to ' + self.spm_model, flush=True) + + def EncodeAsIds(self, text, process_fn=None): + """convert text to sentencepiece Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsIds(processed_text) + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to sentencepiece tokens""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsTokens(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """convert Id to sentencpiece token""" + return self.sp.IdToPiece(Id) + + def TokenToId(self, token): + """convert sentencpiece token to Id""" + return self.sp.PieceToId(token) + + def DecodeIds(self, Ids): + """converts ids to a text string""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return self.sp.DecodeIds(Ids) + + def DecodeTokens(self, Tokens): + """converts sentencepiece tokens to a text string""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.sp.DecodeTokens(Tokens) + + +class BertWordPieceTokenizer(Tokenizer): + """ + Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization + in BERT training. Default to bert-large-uncased tokenizer. + """ + + def __init__(self, + tokenizer_model_type=None, + cache_dir=None, + add_block_symbols=False, + add_sentinel_token=0, + add_task_mask=False, + add_decoder_mask=False, + **kwargs): + # default to bert-large-uncased tokenizer + if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP: + tokenizer_model_type = 'bert-large-uncased' + if not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0: + print('loading BertWordPieceTokenizer (', tokenizer_model_type, + ') from cache_dir ', cache_dir) + do_lower_case = not ('-cased' in tokenizer_model_type + or 'chinese' in tokenizer_model_type) + self.text_tokenizer = BertTokenizer.from_pretrained( + tokenizer_model_type, + do_lower_case=do_lower_case, + cache_dir=cache_dir) + if not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0: + print('loaded', tokenizer_model_type) + # disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + + # set command tokens from wordpiece tokenizer values + self.num_command_tokens = 6 + self.num_tokens = len(self.text_tokenizer.vocab) + self.num_text_tokens = self.num_tokens - 5 + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']), + CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']), + CommandToken('MASK', '[MASK]', + self.text_tokenizer.vocab['[MASK]']), + CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']), + CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']), + CommandToken('eos', '[PAD]', self.text_tokenizer.vocab['[PAD]']), + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_task_mask: + self._command_tokens.extend([ + CommandToken('gMASK', '[gMASK]', self.num_tokens), + CommandToken('sMASK', '[sMASK]', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_decoder_mask: + self._command_tokens.extend( + [CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)]) + self.num_tokens += 1 + self.num_command_tokens += 1 + if add_sentinel_token > 0: + for i in range(1, add_sentinel_token): + self._command_tokens.extend([ + CommandToken(f'MASK{i}', f'[MASK{i}]', self.num_tokens), + CommandToken(f'sop{i}', f'<|startofpiece{i}|>', + self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = { + tok.token: tok + for tok in self._command_tokens + } + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + # set type tokens + self.type_tokens = [ + TypeToken('str0', '', 0), + TypeToken('str1', '', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + # parse tokens and vocabs from tokenizer + + self._tokens = list(self.text_tokenizer.vocab.keys()) + self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = { + k: v + for k, v in self.text_tokenizer.vocab.items() + } + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = { + t: Id + for Id, t in self.command_id_map.items() + } + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def _encode(self, text): + tokens = self.text_tokenizer.tokenize(text) + ids = self.text_tokenizer.convert_tokens_to_ids(tokens) + return ids + + def EncodeAsTokens(self, text, process_fn=None): + """convert wordpiece token to Id""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + """convert Id to sentencpiece token""" + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id in self.command_id_map: + return self.command_id_map[Id].token + return self.text_tokenizer.ids_to_tokens[Id] + + def TokenToId(self, token, type_token=False): + """convert sentencpiece token to Id""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.vocab[token] + + def DecodeIds(self, Ids, type_token=False): + """converts ids to wordpiece tokens and joins them as a text string""" + if type_token: + return ' '.join( + Id.token if isinstance(Id, TypeToken) else self. + type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + Tokens = [] + for Id in Ids: + if Id in self.command_id_map: + Tokens.append(self.command_id_map[Id].token) + elif Id in self.text_tokenizer.ids_to_tokens: + Tokens.append(self.text_tokenizer.ids_to_tokens[Id]) + new_tokens = [] + for token in Tokens: + if token.startswith('##') and len(new_tokens) > 0: + new_tokens[-1] += token[2:] + else: + new_tokens.append(token) + return ' '.join(new_tokens) + + def DecodeTokens(self, Tokens, type_token=False): + """converts wordpiece tokens to a text string""" + if type_token: + return ' '.join( + t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ' '.join(Tokens) + + +class GPT2BPETokenizer(Tokenizer): + + def __init__(self, + model_type_or_path, + cache_dir=None, + add_block_symbols=False, + add_task_mask=False, + add_decoder_mask=False, + **kwargs): + self.text_tokenizer = GPT2Tokenizer.from_pretrained( + model_type_or_path, cache_dir=cache_dir) + + # disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + self.num_tokens = len(self.text_tokenizer.encoder) + self.num_type_tokens = 2 + if model_type_or_path.startswith('roberta'): + self.num_command_tokens = 6 + self.num_text_tokens = self.num_tokens - 3 + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', + self.text_tokenizer.encoder['']), + CommandToken('eos', '<|endoftext|>', + self.text_tokenizer.encoder['']), + CommandToken('sep', '[SEP]', + self.text_tokenizer.encoder['']), + CommandToken('ENC', '[CLS]', + self.text_tokenizer.encoder['']), + CommandToken( + 'MASK', + '[MASK]', + self.text_tokenizer.encoder[''], + lstrip=True), + CommandToken('unk', '[UNK]', + self.text_tokenizer.encoder['']) + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + else: + self.num_command_tokens = 2 + self.num_text_tokens = self.num_tokens - 1 + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', + self.text_tokenizer.encoder['<|endoftext|>']), + CommandToken('eos', '<|endoftext|>', + self.text_tokenizer.encoder['<|endoftext|>']) + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1), + CommandToken('ENC', '[CLS]', self.num_tokens + 2), + CommandToken( + 'MASK', '[MASK]', self.num_tokens + 3, lstrip=True), + CommandToken('sep', '[SEP]', self.num_tokens + 4), + CommandToken('unk', '[UNK]', self.num_tokens + 5) + ]) + self.num_tokens += 6 + self.num_command_tokens += 6 + if add_block_symbols: + if add_task_mask: + self._command_tokens.extend([ + CommandToken( + 'gMASK', '[gMASK]', self.num_tokens, lstrip=True), + CommandToken( + 'sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_decoder_mask: + self._command_tokens.extend( + [CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)]) + self.num_tokens += 1 + self.num_command_tokens += 1 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = { + tok.token: tok + for tok in self._command_tokens + } + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + self.type_tokens = [ + TypeToken('str0', '', 0), + TypeToken('str1', '', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + self._tokens = list(self.text_tokenizer.encoder.keys()) + self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = { + k: v + for k, v in self.text_tokenizer.encoder.items() + } + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = { + t: Id + for Id, t in self.command_id_map.items() + } + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + for idx, tok in self.command_id_map.items(): + self.text_tokenizer.decoder[idx] = tok.token + + def EncodeAsIds(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + + def split_on_token(tok_extended: CommandToken, text): + result = [] + tok = tok_extended.token + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # CommandToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + # Strip white spaces on the right + if tok_extended.rstrip and i > 0: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + sub_text = sub_text.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and i < len(split_text) - 1: + sub_text = sub_text.rstrip() # Opposite here + + if i == 0 and not sub_text: + result.append(tok) + elif i == len(split_text) - 1: + if sub_text: + result.append(sub_text) + else: + pass + else: + if sub_text: + result.append(sub_text) + result.append(tok) + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + if not tok_list: + return self.text_tokenizer.encode(text) + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self._command_token_tokens: + tokenized_text.extend(split_on_token(tok, sub_text)) + else: + tokenized_text.append(sub_text) + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + (self.text_tokenizer.encode(token) + if token not in self._command_token_tokens else + [self.command_token_map[token].Id] + for token in tokenized_text))) + + no_split_tokens = self._command_tokens + Ids = split_on_tokens(no_split_tokens, processed_text) + tokenization = Tokenization(Ids, processed_text, text) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def _encode(self, text): + return self.text_tokenizer.encode(text) + + def EncodeAsTokens(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = [] + for token in re.findall(self.text_tokenizer.pat, processed_text): + token = ''.join(self.text_tokenizer.bye_encoder[b] + for b in token.encode('utf-8')) + tokens.extend( + bpe_token + for bpe_token in self.text_tokenizer.bpe(token).split(' ')) + tokenization = Tokenization(tokens, processed_text, text, asIds=False) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def DecodeAsTokens(self, Ids): + return [self.IdToToken(x) for x in Ids] + + def IdToToken(self, Id, type_token=False): + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id in self.command_id_map: + return self.command_id_map[Id].token + return self.text_tokenizer.decoder[Id] + + def TokenToId(self, token, type_token=False): + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.encoder[token] + + def DecodeIds(self, Ids, type_token=False): + if type_token: + return ' '.join( + Id.token if isinstance(Id, TypeToken) else self. + type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return self.text_tokenizer.decode(Ids) + + def DecodeTokens(self, Tokens, type_token=False): + if type_token: + return ' '.join( + t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.text_tokenizer.decode( + [self.TokenToId(tok) for tok in Tokens]) + + +class ChineseSPTokenizer(Tokenizer): + + def __init__(self, + model_path, + add_block_symbols=False, + add_task_mask=False, + add_decoder_mask=False, + **kwargs): + self.text_tokenizer = sp_tokenizer.from_pretrained(model_path) + + self.num_command_tokens = 0 + self.num_text_tokens = self.text_tokenizer.sp.vocab_size() + self.num_tokens = self.num_text_tokens + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', self.num_text_tokens), + CommandToken('eos', '<|endoftext|>', self.num_text_tokens), + CommandToken('sep', '[SEP]', self.num_text_tokens + 1), + CommandToken('ENC', '[CLS]', self.num_text_tokens + 2), + CommandToken( + 'MASK', '[MASK]', self.num_text_tokens + 3, lstrip=True), + CommandToken('unk', '[UNK]', self.num_text_tokens + 4) + ] + self.num_tokens += 5 + self.num_command_tokens += 6 + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens + 1), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 2) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_task_mask: + self._command_tokens.extend([ + CommandToken( + 'gMASK', '[gMASK]', self.num_tokens, lstrip=True), + CommandToken( + 'sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_decoder_mask: + self._command_tokens.extend( + [CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)]) + self.num_tokens += 1 + self.num_command_tokens += 1 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = { + tok.token: tok + for tok in self._command_tokens + } + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + self.type_tokens = [ + TypeToken('str0', '', 0), + TypeToken('str1', '', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + # self._tokens = list(self.text_tokenizer.encoder.keys()) + # self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + # + # self._text_tokens = list(self._tokens) + # self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = { + t: Id + for Id, t in self.command_id_map.items() + } + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def _encode(self, text): + ids = self.text_tokenizer.encode(text) + return ids + + def EncodeAsTokens(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + tokenization = Tokenization(tokens, processed_text, text, asIds=False) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + # return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id in self.command_id_map: + return self.command_id_map[Id].token + elif Id in self.type_id_map: + return self.type_id_map[Id].token + else: + return self.text_tokenizer.convert_id_to_token(int(Id)) + + def TokenToId(self, token, type_token=False): + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.convert_token_to_id(token) + + def DecodeIds(self, Ids, type_token=False): + if type_token: + return ' '.join( + Id.token if isinstance(Id, TypeToken) else self. + type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + Ids = list(map(int, Ids)) + pieces = [] + last = 0 + for i, token_id in enumerate(Ids): + if token_id in self.command_id_map: + pieces.append(Ids[last:i]) + pieces.append(token_id) + last = i + 1 + pieces.append(Ids[last:]) + text = '' + for piece in pieces: + if isinstance(piece, int): + text += self.command_id_map[piece].token + elif piece: + text += self.text_tokenizer.decode(piece) + return text + + def DecodeTokens(self, Tokens, type_token=False): + if type_token: + return ' '.join( + t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.text_tokenizer.decode( + [self.TokenToId(tok) for tok in Tokens]) diff --git a/modelscope/models/nlp/mglm/data_utils/tokenization_gpt2.py b/modelscope/models/nlp/mglm/data_utils/tokenization_gpt2.py new file mode 100644 index 00000000..d179e055 --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/tokenization_gpt2.py @@ -0,0 +1,359 @@ +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging +import os +import sys +from io import open + +import json +import regex as re + +from .file_utils import cached_path + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'gpt2': '.pytorch_pretrained_bert/gpt2-vocab.json', + 'roberta': '.pytorch_pretrained_bert/roberta-vocab.json' +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'gpt2': '.pytorch_pretrained_bert/gpt2-merges.txt', + 'roberta': '.pytorch_pretrained_bert/roberta-merges.txt' +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'gpt2': 1024, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path, + cache_dir=None, + *inputs, + **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ + pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[ + pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, + VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, + MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, + SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info('loading special tokens file {}'.format( + special_tokens_file)) + # redirect to the cache, if necessary + # try: + # resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + # resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + # except EnvironmentError: + # logger.error( + # "Model name '{}' was not found in model name list ({}). " + # "We assumed '{}' was a path or url but couldn't find files {} and {} " + # "at this path or url.".format( + # pretrained_model_name_or_path, + # ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + # pretrained_model_name_or_path, + # vocab_file, merges_file)) + # return None + # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + # logger.info("loading vocabulary file {}".format(vocab_file)) + # logger.info("loading merges file {}".format(merges_file)) + # else: + # logger.info("loading vocabulary file {} from cache at {}".format( + # vocab_file, resolved_vocab_file)) + # logger.info("loading merges file {} from cache at {}".format( + # merges_file, resolved_merges_file)) + resolved_vocab_file = vocab_file + resolved_merges_file = merges_file + logger.info('loading vocabulary file {}'.format(vocab_file)) + logger.info('loading merges file {}'.format(merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ + pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open( + special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls( + resolved_vocab_file, + resolved_merges_file, + special_tokens=special_tokens, + *inputs, + **kwargs) + return tokenizer + + def __init__(self, + vocab_file, + merges_file, + errors='replace', + special_tokens=None, + max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) + for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = { + v: k + for k, v in self.special_tokens.items() + } + logger.info('Special tokens {}'.format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: # noqa + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend( + bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 + and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this OpenAI GPT model ({} > {}). Running this' + ' sequence through the model will result in indexing errors'. + format(len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors=self.errors) + return text + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error('Vocabulary path ({}) should be a directory'.format( + vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, 'w', encoding='utf-8') as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted( + self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + 'Saving vocabulary to {}: BPE merge indices are not consecutive.' + ' Please check that the tokenizer is not corrupted!'. + format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + index = len(self.encoder) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted( + self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + 'Saving special tokens vocabulary to {}: BPE indices are not consecutive.' + ' Please check that the tokenizer is not corrupted!'. + format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return vocab_file, merge_file, special_tokens_file diff --git a/modelscope/models/nlp/mglm/data_utils/wordpiece.py b/modelscope/models/nlp/mglm/data_utils/wordpiece.py new file mode 100755 index 00000000..1cecffbd --- /dev/null +++ b/modelscope/models/nlp/mglm/data_utils/wordpiece.py @@ -0,0 +1,408 @@ +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py""" # noqa + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import collections +import logging +import os +import unicodedata +from io import open + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': + '.pytorch_pretrained_bert/bert-base-uncased-vocab.txt', + 'bert-large-uncased': + '.pytorch_pretrained_bert/bert-large-uncased-vocab.txt', + 'bert-base-cased': + '.pytorch_pretrained_bert/bert-base-cased-vocab.txt', + 'bert-large-cased': + '.pytorch_pretrained_bert/bert-large-cased-vocab.txt', + 'bert-base-multilingual-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt', + 'bert-base-multilingual-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt', + 'bert-base-chinese': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt', +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r', encoding='utf-8') as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, + vocab_file, + do_lower_case=True, + max_len=None, + do_basic_tokenize=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + .format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + if self.do_basic_tokenize: + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this BERT model ({} > {}). Running this' + ' sequence through BERT will result in indexing errors'.format( + len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path, + cache_dir=None, + *inputs, + **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ + pretrained_model_name_or_path] + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + 'associated to this path or url.'.format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info('loading vocabulary file {}'.format(vocab_file)) + else: + logger.info('loading vocabulary file {} from cache at {}'.format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ + pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # noqa + (cp >= 0x3400 and cp <= 0x4DBF) or # noqa + (cp >= 0x20000 and cp <= 0x2A6DF) or # noqa + (cp >= 0x2A700 and cp <= 0x2B73F) or # noqa + (cp >= 0x2B740 and cp <= 0x2B81F) or # noqa + (cp >= 0x2B820 and cp <= 0x2CEAF) or # noqa + (cp >= 0xF900 and cp <= 0xFAFF) or # noqa + (cp >= 0x2F800 and cp <= 0x2FA1F)): # noqa + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False diff --git a/modelscope/models/nlp/mglm/fp16/__init__.py b/modelscope/models/nlp/mglm/fp16/__init__.py new file mode 100644 index 00000000..90d20bcf --- /dev/null +++ b/modelscope/models/nlp/mglm/fp16/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fp16 import * # noqa +from .fp16util import (BN_convert_float, FP16Model, clip_grad_norm, + convert_module, convert_network, + master_params_to_model_params, + model_grads_to_master_grads, network_to_half, + prep_param_lists, to_python_float, tofp16) +from .loss_scaler import * # noqa diff --git a/modelscope/models/nlp/mglm/fp16/fp16.py b/modelscope/models/nlp/mglm/fp16/fp16.py new file mode 100755 index 00000000..10fbd804 --- /dev/null +++ b/modelscope/models/nlp/mglm/fp16/fp16.py @@ -0,0 +1,660 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stable version of apex FP16 Optimizer""" +import torch +from torch import nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from .fp16util import (clip_grad_norm, master_params_to_model_params, + model_grads_to_master_grads) +from .loss_scaler import DynamicLossScaler, LossScaler + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + + return conversion_helper(val, half_conversion) + + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +class FP16_Module(nn.Module): + + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.module.named_parameters(prefix=prefix, recurse=recurse) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + return self.module.load_state_dict(state_dict, strict=strict) + + +# TODO: Update overflow check + downscale to use Carl's fused kernel. +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage static or dynamic loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + # optional arg to control dynamic loss scaling behavior + # dynamic_loss_args={'scale_window' : 500}) + # Usually, dynamic_loss_args is not necessary. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. + verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` + should still work as intended. + """ # noqa + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=False): + if not torch.cuda.is_available: + raise SystemError('Cannot use fp16 without CUDA.') + + self.verbose = verbose + + self.optimizer = init_optimizer + # init_state_dict sets up an alternative way to cast per-param state tensors. + # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. + # init_state_dict = init_optimizer.state_dict() + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + self.maybe_print( + 'FP16_Optimizer processing param group {}:'.format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_fp16_params_this_group = [] + for i, param in enumerate(param_group['params']): + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.HalfTensor with {}' + .format(param.size())) + fp16_params_this_group.append(param) + master_param = param.detach().clone().float() + master_param.requires_grad = True + # Copythe model parallel flag. + master_param.model_parallel = param.model_parallel + param_group['params'][i] = master_param + fp32_from_fp16_params_this_group.append(master_param) + # Reset existing state dict key to the new master param. + # We still need to recast per-param state tensors, if any, to FP32. + if param in self.optimizer.state: + self.optimizer.state[ + master_param] = self.optimizer.state.pop(param) + elif param.type() == 'torch.cuda.FloatTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.FloatTensor with {}' + .format(param.size())) + fp32_params_this_group.append(param) + param_group['params'][i] = param + else: + raise TypeError( + 'Wrapped parameters must be either ' + 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' + 'Received {}'.format(param.type())) + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + # alternative way to cast per-param state tensors: + # self.optimizer.load_state_dict(init_state_dict) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + self.clip_grad_norm = clip_grad_norm + + def maybe_print(self, msg): + if self.verbose: + print(msg) + + def __getstate__(self): + raise RuntimeError( + 'FP16_Optimizer should be serialized using state_dict().') + + def __setstate__(self, state): + raise RuntimeError( + 'FP16_Optimizer should be deserialized using load_state_dict().') + + def zero_grad(self, set_grads_to_None=False): + """ + Zero fp32 and fp16 parameter grads. + """ + # In principle, only the .grad attributes of the model params need to be zeroed, + # because gradients are copied into the FP32 master params. However, we zero + # all gradients owned by the optimizer, just to be safe: + for group in self.optimizer.param_groups: + for p in group['params']: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + # Zero fp16 gradients owned by the model: + for fp16_group in self.fp16_groups: + for param in fp16_group: + if set_grads_to_None: + param.grad = None + else: + if param.grad is not None: + param.grad.detach_( + ) # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def _model_params_to_master_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp32_from_fp16_group, fp16_group) + + # To consider: Integrate distributed with this wrapper by registering a hook on each variable + # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + for group in self.optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.mul_(1. / self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). + """ # noqa + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return self.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[ + 'first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict[ + 'first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, + state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, :attr:`step` should be called after + ``fp16_optimizer_obj.backward(loss)``. + :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. + However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + # loss.backward() becomes: + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ # noqa + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self.maybe_print( + 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' + .format(scale, self.loss_scale)) + return + + if closure is not None: + retval = self._step_with_closure(closure) + else: + retval = self.optimizer.step() + + self._master_params_to_model_params() + + return retval + + def _step_with_closure(self, closure): + + def wrapped_closure(): + # helpful for debugging + # print("Calling wrapped_closure, first_closure_call_this_step = {}" + # .format(self.first_closure_call_this_step)) + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, so all wrapped_closure needs to do is call + # closure() and return the loss. + temp_loss = closure() + while (self.overflow): + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + self.maybe_print( + 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' + 'reducing to {}'.format(scale, self.loss_scale)) + temp_loss = closure() + return temp_loss + + retval = self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + return retval + + def backward(self, loss, update_master_grads=True, retain_graph=False): + """ + :attr:`backward` performs the following conceptual steps: + + 1. fp32_loss = loss.float() (see first Note below) + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). + 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. + 5. Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. warning:: + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. + retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ # noqa + # To consider: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. :attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ # noqa + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return # noqa + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + if self.overflow: + print( + 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' + 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' + ) + return None + else: + # The optimizer owns only references to master params. + master_grads_data = [] + for param_group in self.optimizer.param_groups: + master_grads_this_group = [] + for param in param_group['params']: + if param.grad is not None: + master_grads_this_group.append(param.grad.data) + else: + master_grads_this_group.append(None) + master_grads_data.append(master_grads_this_group) + return master_grads_data + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/modelscope/models/nlp/mglm/fp16/fp16util.py b/modelscope/models/nlp/mglm/fp16/fp16util.py new file mode 100644 index 00000000..3fcd3005 --- /dev/null +++ b/modelscope/models/nlp/mglm/fp16/fp16util.py @@ -0,0 +1,220 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable + +from modelscope.models.nlp.mglm import mpu + + +class tofp16(nn.Module): + """ + Utility module that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + """ + Utility function for network_to_half(). + + Retained for legacy purposes. + """ + if isinstance( + module, + torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + + Retained for legacy purposes. It is recommended to use FP16Model. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def convert_module(module, dtype): + """ + Converts a module's immediate parameters and buffers to dtype. + """ + for param in module.parameters(recurse=False): + if param is not None: + if param.data.dtype.is_floating_point: + param.data = param.data.to(dtype=dtype) + if param._grad is not None and param._grad.data.dtype.is_floating_point: + param._grad.data = param._grad.data.to(dtype=dtype) + + for buf in module.buffers(recurse=False): + if buf is not None and buf.data.dtype.is_floating_point: + buf.data = buf.data.to(dtype=dtype) + + +def convert_network(network, dtype): + """ + Converts a network's parameters and buffers to dtype. + """ + for module in network.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm + ) and module.affine is True: + continue + convert_module(module, dtype) + return network + + +class FP16Model(nn.Module): + """ + Convert model to half precision in a batchnorm-safe way. + """ + + def __init__(self, network): + super(FP16Model, self).__init__() + self.network = convert_network(network, dtype=torch.half) + + def forward(self, *inputs): + inputs = tuple(t.half() for t in inputs) + return self.network(*inputs) + + +def backwards_debug_hook(grad): + raise RuntimeError( + 'master_params recieved a gradient in the backward pass!') + + +def prep_param_lists(model, flat_master=False): + """ + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. + + .. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ # noqa + model_params = [ + param for param in model.parameters() if param.requires_grad + ] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors( + [param.data for param in model_params]).float() + except: # noqa + print( + 'Error in prep_param_lists: model may contain a mixture of parameters ' + 'of different types. Use flat_master=False, or use F16_Optimizer.' + ) + raise + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [ + param.clone().float().detach() for param in model_params + ] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, + master_params, + flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. + """ # noqa + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable( + master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, + master_params, + flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. + """ # noqa + if flat_master: + for model, master in zip( + model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) + + +# Backward compatibility fixes + + +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +clip_grad_norm = mpu.clip_grad_norm diff --git a/modelscope/models/nlp/mglm/fp16/loss_scaler.py b/modelscope/models/nlp/mglm/fp16/loss_scaler.py new file mode 100755 index 00000000..721571b3 --- /dev/null +++ b/modelscope/models/nlp/mglm/fp16/loss_scaler.py @@ -0,0 +1,245 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from modelscope.models.nlp.mglm import mpu + + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + + Args: + scale (float, optional, default=1.0): The loss scale. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. + """ # noqa + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan( + p.grad.data): + return True + + return False + + def has_overflow(self, params): + overflow = self.has_overflow_serial(params) + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + overflow_gpu = torch.cuda.ByteTensor([overflow]) + torch.distributed.all_reduce( + overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_model_parallel_group()) + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float( + 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + + if not hasattr(self, 'min_scale'): + self.min_scale = 1 + if not hasattr(self, 'delayed_shift'): + self.delayed_shift = 1 + if not hasattr(self, 'cur_hysteresis'): + self.cur_hysteresis = 1 + if not hasattr(self, 'consecutive_hysteresis'): + self.consecutive_hysteresis = True + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + self.cur_scale = max(self.cur_scale / self.scale_factor, + self.min_scale) + else: + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter + - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + +""" diff --git a/modelscope/models/nlp/mglm/generation_utils.py b/modelscope/models/nlp/mglm/generation_utils.py new file mode 100644 index 00000000..6db75b2d --- /dev/null +++ b/modelscope/models/nlp/mglm/generation_utils.py @@ -0,0 +1,483 @@ +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import UserDict +from typing import Iterable, List, Optional, Tuple + +import torch + +PROCESS_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated + scores of all non-finished beams. + - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens + to be added to the non-finished beam_hypotheses. + - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + +""" + +FINALIZE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The final scores of all non-finished beams. + final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The last tokens to be added to the non-finished beam_hypotheses. + final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + +""" + + +class BeamScorer(ABC): + """ + Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and + :meth:`~transformers.PretrainedModel.beam_sample`. + """ + + @abstractmethod + def process(self, input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs) -> Tuple[torch.Tensor]: + raise NotImplementedError('This is an abstract method.') + + @abstractmethod + def finalize(self, input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, **kwargs) -> torch.LongTensor: + raise NotImplementedError('This is an abstract method.') + + +class BeamSearchScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing standard beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + `__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. + num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], + dtype=torch.bool, + device=self.device) + + # if not isinstance(num_beams, int) or num_beams <= 1: + # raise ValueError( + # ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process(self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + mems=None) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.num_beams) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + device = next_scores.device + next_beam_scores = torch.zeros((batch_size, self.num_beams), + dtype=next_scores.dtype, + device=device) + next_beam_tokens = torch.zeros((batch_size, self.num_beams), + dtype=next_tokens.dtype, + device=device) + next_beam_indices = torch.zeros((batch_size, self.num_beams), + dtype=next_indices.dtype, + device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), 'Batch can only be done if at least {} beams have been generated'.format( + self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), 'generated beams >= num_beams -> eos_token_id and pad_token have to be defined' + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, + next_index) in enumerate( + zip(next_tokens[batch_idx], + next_scores[batch_idx], + next_indices[batch_idx])): + batch_beam_idx = batch_idx * self.num_beams + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() + in eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + mems=[mem[[next_index.item()]] + for mem in mems] if mems else None) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.num_beams: + break + + if beam_idx < self.num_beams: + raise ValueError( + f'At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected.' # noqa + ) # noqa + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len) + + return UserDict({ + 'next_beam_scores': next_beam_scores.view(-1), + 'next_beam_tokens': next_beam_tokens.view(-1), + 'next_beam_indices': next_beam_indices.view(-1), + }) + + def finalize(self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + mems=None) -> Tuple[torch.LongTensor, List[torch.Tensor]]: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add( + final_tokens, + final_score, + mems=[mem[[batch_beam_idx]] + for mem in mems] if mems else None) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp, mems = sorted_hyps.pop()[1:] + sent_lengths[self.num_beam_hyps_to_keep * i + + j] = len(best_hyp) + best.append((best_hyp, mems)) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item(), self.max_length) + decoded: torch.LongTensor = input_ids.new( + batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, '`pad_token_id` has to be defined' + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + mems = [] + for i, (hypo, mem) in enumerate(best): + decoded[i, :sent_lengths[i]] = hypo + if sent_lengths[i] < sent_max_len: + decoded[i, sent_lengths[i]] = eos_token_id + mems.append(mem) + mems = [ + torch.cat([mem[i] for mem in mems], dim=0) + for i in range(len(mems[0])) + ] if mems and mems[0] else None + return decoded, mems + + +class BeamHypotheses: + + def __init__(self, num_beams: int, max_length: int, length_penalty: float, + early_stopping: bool): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (max(hyp.shape[-1], 1)**self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp, mems)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([ + (s, idx) for idx, (s, _, _) in enumerate(self.beams) + ]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len**self.length_penalty + ret = self.worst_score >= cur_score + return ret + + +class LogitsProcessor(ABC): + """Abstract base class for all logit processors that can be applied during generation.""" + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for processing logits.""" + raise NotImplementedError( + f'{self.__class__} is an abstract class. Only classes inheriting this class can be called.' + ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from + list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsProcessor` to the inputs. + """ + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_id: int): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError( + f'`min_length` has to be a positive integer, but is {min_length}' + ) + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError( + f'`eos_token_id` has to be a positive integer, but is {eos_token_id}' + ) + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + scores[:, self.eos_token_id] = -float('inf') + return scores + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq + `__. + + Args: + ngram_size (:obj:`int`): + All ngrams of size :obj:`ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError( + f'`ngram_size` has to be a strictly positive integer, but is {ngram_size}' + ) + self.ngram_size = ngram_size + + def __call__(self, input_ids: torch.LongTensor, + scores: torch.FloatTensor) -> torch.FloatTensor: + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = self._calc_banned_ngram_tokens( + input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float('inf') + + return scores + + def _calc_banned_ngram_tokens(self, prev_input_ids: torch.Tensor, + num_hypos: int, + cur_len: int) -> List[Iterable[int]]: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < self.ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] + for i in range(self.ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get( + prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - self.ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, + start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [ + _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos) + ] + return banned_tokens diff --git a/modelscope/models/nlp/mglm/mglm_for_text_summarization.py b/modelscope/models/nlp/mglm/mglm_for_text_summarization.py new file mode 100644 index 00000000..ea1dfb5a --- /dev/null +++ b/modelscope/models/nlp/mglm/mglm_for_text_summarization.py @@ -0,0 +1,469 @@ +# Copyright (c) 2022 Zhipu.AI + +import os +import random +from os import path as osp +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from . import mpu +from .arguments import get_args +from .generation_utils import BeamSearchScorer +from .train_utils import get_model +from .utils import load_checkpoint + +__all__ = ['MGLMForTextSummarization'] + + +def setup_args(args): + args.block_lm = True + args.task_mask = True + args.cloze_eval = True + args.num_layers = 24 + args.hidden_size = 1536 + args.num_attention_heads = 16 + args.max_position_embeddings = 1024 + args.tokenizer_type = 'ChineseSPTokenizer' + args.load_pretrained = '' + args.DDP_impl = 'none' + args.model_parallel_size = 1 + args.fp16 = True + args.cache_dir = 'cache' + args.out_seq_length = 200 + args.seq_length = 512 + args.temperature = 0.9 + args.top_k = 2 + args.top_p = 0.8 + args.frequency_penalty = 0.1 + args.presence_penalty = 0.1 + args.mem_length = args.seq_length + args.mem_length - 1 + return args + + +def setup_model(args): + """Setup model and optimizer.""" + + model = get_model(args, model_type='generation') + + if args.load_pretrained is not None: + args.no_load_optim = True + args.load = args.load_pretrained + _ = load_checkpoint(model, None, None, args) + + return model + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + + if seed is not None and seed > 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def get_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask, + loss_mask=None, + attention_mask=None, + set_loss_mask=False, + mem_length=None): + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if mem_length: + if attention_mask is None: + attention_mask = torch.ones( + (1, seq_length, seq_length + mem_length), device=data.device) + attention_mask = torch.tril( + torch.triu(attention_mask, 1 - seq_length + mem_length), + mem_length) + else: + if reset_attention_mask: + att_mask_batch = batch_size + else: + att_mask_batch = 1 + if attention_mask is None: + attention_mask = torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device) + attention_mask = torch.tril(attention_mask) + attention_mask = attention_mask.unsqueeze(1) + + # Loss mask. + if loss_mask is None: + loss_mask = torch.ones( + data.size(), dtype=torch.float, device=data.device) + + # Position ids. + position_ids = torch.arange( + seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + if set_loss_mask: + loss_mask[data == eod_token] = 0.0 + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + return attention_mask, loss_mask, position_ids + + +def initialize_distributed(args): + """Initialize torch.distributed.""" + + # Manually set the device ids. + device = args.rank % torch.cuda.device_count() + if args.local_rank is not None: + device = args.local_rank + torch.cuda.set_device(device) + # Call the init process + init_method = 'tcp://' + args.master_ip = os.getenv('MASTER_ADDR', 'localhost') + args.master_port = os.getenv('MASTER_PORT', '6000') + init_method += args.master_ip + ':' + args.master_port + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, + rank=args.rank, + init_method=init_method) + + # Set the model-parallel / data-parallel communicators. + mpu.initialize_model_parallel(args.model_parallel_size) + + # Optional DeepSpeed Activation Checkpointing Features + # + if hasattr( + args, 'deepspeed' + ) and args.deepspeed and args.deepspeed_activation_checkpointing: + set_deepspeed_activation_checkpointing(args) + + +def get_batch(context_tokens, device, args): + tokens = context_tokens + tokens = tokens.view(args.batch_size, -1).contiguous() + tokens = tokens.to(device) + + # Get the masks and postition ids. + if args.block_lm: + attention_mask = torch.tensor([tokens.size(1)], + device=device, + dtype=torch.long) + position_ids = torch.arange( + tokens.size(1), device=device, dtype=torch.long) + if not args.no_block_position: + block_position_ids = torch.zeros( + tokens.size(1), device=device, dtype=torch.long) + position_ids = torch.stack((position_ids, block_position_ids), + dim=0) + position_ids = position_ids.unsqueeze(0) + else: + attention_mask, loss_mask, position_ids = get_masks_and_position_ids( + tokens, + args.eod_token, + reset_position_ids=False, + reset_attention_mask=False, + set_loss_mask=False, + mem_length=args.mem_length) + + return tokens, attention_mask, position_ids + + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + + return logits + + +def sample_sequence(model, + tokenizer, + context_tokens, + context_length, + args, + device, + mems=None, + end_tokens=None): + if not args.block_lm: + context_tokens, attention_mask, position_ids = get_batch( + context_tokens, device, args) + tokens = torch.empty((args.num_beams, 0), + device=context_tokens.device, + dtype=torch.long) + else: + tokens = context_tokens.new_full((1, 1), + tokenizer.get_command('sop').Id) + counter = 0 + if mems is None: + mems = [] + if end_tokens is None: + end_tokens = [args.eod_token] + + last_beam_num = 1 + output_tokens_list = [] + generated_tokens_list = [] + + while counter < args.out_seq_length: + if counter == 0 and not args.block_lm: + next_token_logits, *mems = model(context_tokens, position_ids, + attention_mask, *mems) + else: + if args.block_lm: + if args.no_block_position: + position_ids = context_tokens.new_full( + (last_beam_num, 1), context_length + counter) + else: + position_ids = context_tokens.new_ones(last_beam_num, 2, 1) + position_ids[:, 0] = context_length + position_ids[:, 1] = counter + 1 + attention_mask = context_tokens.new_zeros( + [1], device=context_tokens.device, dtype=torch.long) + else: + position_ids = context_tokens.new_ones((last_beam_num, 1)) * ( + context_length + counter - 1) + attention_mask = context_tokens.new_ones( + last_beam_num, + 1, + 1, + args.mem_length + 1, + device=context_tokens.device, + dtype=torch.float) + last_token = tokens[:, -1:] + next_token_logits, *mems = model(last_token, position_ids, + attention_mask, *mems) + next_token_logits = next_token_logits[:, -1] + + next_token_logits /= args.temperature + frequency_count = torch.zeros(next_token_logits.shape) + for tk in output_tokens_list: + frequency_count[0][tk] += 1 + + next_token_logits -= (args.frequency_penalty + * frequency_count).to(device) + next_token_logits -= ( + args.presence_penalty * # noqa + (frequency_count > 0)).to(device) + + next_token_logits = top_k_logits( + next_token_logits, top_k=args.top_k, top_p=args.top_p) + log_probs = F.softmax(next_token_logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1)[0] + is_end = prev.item() in end_tokens + if is_end: + break + decode_tokens = tokenizer.DecodeIds([prev.item()]) # noqa + generated_tokens_list.append(prev.item()) + prev = prev.view(1, 1) + tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1) + counter += 1 + output_tokens_list = tokens.view(-1).contiguous() + return torch.cat((context_tokens, tokens), dim=1), mems + + +def read_context(tokenizer, args, context): + terminate_runs, skip_run = 0, 0 # noqa + if mpu.get_model_parallel_rank() == 0: + while True: + # raw_text = input("\nContext prompt (stop to exit) >>> ") + raw_text = context + if not raw_text: + print('Prompt should not be empty!') + break + # if raw_text == "stop": + # terminate_runs = 1 + # break + generation_mask = '[gMASK]' if args.task_mask else '[MASK]' + if args.block_lm and 'MASK]' not in raw_text: + raw_text += ' ' + generation_mask + # output.write(raw_text) + context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization + if args.block_lm: + context_tokens = [tokenizer.get_command('ENC').Id + ] + context_tokens + if not raw_text.endswith('[gMASK]'): + context_tokens = context_tokens + [ + tokenizer.get_command('eos').Id + ] + context_length = len(context_tokens) + + if context_length >= args.seq_length: + print('\nContext length', context_length, + '\nPlease give smaller context than the window length!') + break + break + else: + context_length = 0 + + terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) + torch.distributed.broadcast( + terminate_runs_tensor, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + terminate_runs = terminate_runs_tensor[0].item() + + if terminate_runs == 1: + return terminate_runs, None, None, None + + context_length_tensor = torch.cuda.LongTensor([context_length]) + + torch.distributed.broadcast( + context_length_tensor, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + context_length = context_length_tensor[0].item() + if mpu.get_model_parallel_rank() == 0: + context_tokens_tensor = torch.cuda.LongTensor(context_tokens) + else: + context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) + torch.distributed.broadcast( + context_tokens_tensor, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + if mpu.get_model_parallel_rank() != 0: + raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) + return terminate_runs, raw_text, context_tokens_tensor, context_length + + +@MODELS.register_module(Tasks.text_summarization, module_name=Models.mglm) +class MGLMForTextSummarization(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text summarization model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + from .configure_data import prepare_tokenizer + # Disable CuDNN. + torch.backends.cudnn.enabled = False + # Arguments. + self.args = setup_args(get_args()) + self.args.load_pretrained = model_dir + # Pytorch distributed. + try: + initialize_distributed(self.args) + except (RuntimeError): + print('group process initialized twice') + # Random seeds for reproducability. + set_random_seed(self.args.seed) + # setting default batch size to 1 + self.args.batch_size = 1 + self.args.tokenizer_path = model_dir + self.tokenizer = prepare_tokenizer(self.args) + self.model = setup_model(self.args) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + + def forward(self, input: Dict[str, str]) -> Dict[str, str]: + pass + + def generate(self, input: Dict[str, str]) -> Dict[str, str]: + model = self.model + tokenizer = self.tokenizer + args = self.args + device = torch.cuda.current_device() + model.eval() + + context = input['text'] + self.cfg.model.prompt + with torch.no_grad(): + terminate_runs, raw_text, context_tokens_tensor, context_length = read_context( + tokenizer, args, context) + mems = [] + tokens, attention_mask, position_ids = get_batch( + context_tokens_tensor, device, args) + mask_tokens = ['MASK', 'sMASK', 'gMASK' + ] if args.task_mask else ['MASK'] + mask_tokens = [ + tokenizer.get_command(token).Id for token in mask_tokens + ] + end_tokens = [tokenizer.get_command('eop').Id, args.eod_token] + + mask_positions = [] + for token in mask_tokens: + mask_positions += (context_tokens_tensor == token).nonzero( + as_tuple=True)[0].tolist() + mask_positions.sort() + if args.no_block_position: + for mask_position in mask_positions: + position_ids[0, mask_position + 1:] += args.out_seq_length + _, *mems = model(tokens, position_ids, attention_mask, *mems) + for mask_position in mask_positions: + if args.no_block_position: + position = position_ids[0, mask_position].item() + else: + position = mask_position + tokens, mems, = sample_sequence( + model, + tokenizer, + tokens, + position, + args, + device, + mems=mems, + end_tokens=end_tokens) + output_tokens_list = tokens.view(-1).contiguous() + trim_decode_tokens = tokenizer.DecodeIds( + output_tokens_list.tolist()) + res = trim_decode_tokens.split('<|startofpiece|>')[-1] + print(res) + return {OutputKeys.TEXT: res} diff --git a/modelscope/models/nlp/mglm/model/__init__.py b/modelscope/models/nlp/mglm/model/__init__.py new file mode 100755 index 00000000..84c55ae3 --- /dev/null +++ b/modelscope/models/nlp/mglm/model/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .distributed import (DistributedDataParallel, + PyTorchDistributedDataParallel) +from .downstream import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, + GLMForSequenceClassification, GLMForSingleTokenCloze) +from .modeling_glm import (GLMModel, + glm_get_params_for_weight_decay_optimization) diff --git a/modelscope/models/nlp/mglm/model/distributed.py b/modelscope/models/nlp/mglm/model/distributed.py new file mode 100755 index 00000000..a3c84e9f --- /dev/null +++ b/modelscope/models/nlp/mglm/model/distributed.py @@ -0,0 +1,127 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable +from torch.nn.modules import Module +from torch.nn.parallel.distributed import DistributedDataParallel as DDP + +from modelscope.models.nlp.mglm import mpu + + +class PyTorchDistributedDataParallel(DDP): + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.module.named_parameters(prefix=prefix, recurse=recurse) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + sd = self.module.state_dict(destination, prefix, keep_vars) + return sd + + def load_state_dict(self, state_dict, strict=True): + return self.module.load_state_dict(state_dict, strict=strict) + + +class DistributedDataParallel(Module): + + def __init__(self, module): + super(DistributedDataParallel, self).__init__() + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.module = module + self.data_parallel_group = mpu.get_data_parallel_group() + src_rank = mpu.get_model_parallel_rank() + for p in self.module.parameters(): + if torch.is_tensor(p): + dist.broadcast(p, src_rank, group=self.data_parallel_group) + + def allreduce_params(reduce_after=True, + no_scale=False, + fp32_allreduce=False): + if (self.needs_reduction): + self.needs_reduction = False + buckets = {} + for name, param in self.module.named_parameters(): + if param.requires_grad and param.grad is not None: + tp = (param.data.type()) + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if self.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print( + 'WARNING: gloo dist backend for half parameters may be extremely slow. It is recommended to use the NCCL backend in this case.' # noqa + ) + self.warn_on_half = False + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + if fp32_allreduce: + coalesced = coalesced.float() + if not no_scale and not reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + dist.all_reduce(coalesced, group=self.data_parallel_group) + torch.cuda.synchronize() + if not no_scale and reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + for buf, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + self.hook_handles = [] + self.hooks = [] + for param in list(self.module.parameters()): + + def allreduce_hook(*unused): + Variable._execution_engine.queue_callback(allreduce_params) + + self.allreduce_params = allreduce_params + + def forward(self, *inputs, **kwargs): + self.needs_reduction = True + return self.module(*inputs, **kwargs) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + sd = self.module.state_dict(destination, prefix, keep_vars) + return sd + + def load_state_dict(self, state_dict, strict=True): + return self.module.load_state_dict(state_dict, strict=strict) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.module.named_parameters(prefix=prefix, recurse=recurse) + + ''' + def _sync_buffers(self): + buffers = list(self.module._all_buffers()) + if len(buffers) > 0: + # cross-node buffer sync + flat_buffers = _flatten_dense_tensors(buffers) + dist.broadcast(flat_buffers, 0) + for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): + buf.copy_(synced) + def train(self, mode=True): + # Clear NCCL communicator and CUDA event cache of the default group ID, + # These cache will be recreated at the later call. This is currently a + # work-around for a potential NCCL deadlock. + if dist._backend == dist.dist_backend.NCCL: + dist._clear_group_cache() + super(DistributedDataParallel, self).train(mode) + self.module.train(mode) + ''' diff --git a/modelscope/models/nlp/mglm/model/downstream.py b/modelscope/models/nlp/mglm/model/downstream.py new file mode 100644 index 00000000..61b1e807 --- /dev/null +++ b/modelscope/models/nlp/mglm/model/downstream.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022 Zhipu.AI +"""Multiple choice model.""" + +import torch +import torch.nn + +from .modeling_glm import GLMModel + + +class GLMForMultiTokenCloze(torch.nn.Module): + + def __init__(self, + language_model: GLMModel, + take_softmax=True, + length_penalty=0.0): + super(GLMForMultiTokenCloze, self).__init__() + self.model = language_model + self.take_softmax = take_softmax + self.length_penalty = length_penalty + + def state_dict(self, destination=None, prefix='', keep_vars=False): + # [h.remove() for h in self.hook_handles] + sd = self.model.state_dict(destination, prefix, keep_vars) + return sd + + def load_state_dict(self, state_dict, strict=True): + return self.model.load_state_dict(state_dict, strict=strict) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.model.named_parameters(prefix=prefix, recurse=recurse) + + def forward(self, + input_ids, + position_ids, + attention_mask, + target_ids=None, + logit_mask=None, + prompt_pos=None): + if target_ids is None: + return self.model(input_ids, position_ids, attention_mask) + num_choices = None + if len(input_ids.shape) == 3: + batch_size, num_choices = input_ids.shape[:2] + input_ids = input_ids.reshape(-1, input_ids.size(-1)) + attention_mask = attention_mask.reshape(-1, + *attention_mask.size()[2:]) + position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) + target_ids = target_ids.reshape(-1, target_ids.size(-1)) + logit_mask = logit_mask.reshape(-1, logit_mask.size(-1)) + if prompt_pos is not None: + prompt_pos = prompt_pos.reshape(-1, prompt_pos.size(-1)) + outputs, *mems = self.model( + input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) + if self.take_softmax: + outputs = torch.nn.functional.log_softmax(outputs, dim=-1) + # select the target logits + batch_ids = torch.arange( + target_ids.size(0), dtype=torch.long, device=target_ids.device) + batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) + seq_ids = torch.arange( + target_ids.size(-1), dtype=torch.long, device=target_ids.device) + seq_ids = seq_ids.unsqueeze(0).expand_as(target_ids) + logits = outputs[batch_ids, seq_ids, target_ids] + logits = (logits * logit_mask).sum(dim=1) + if self.length_penalty > 0.0: + logits = logits / logit_mask.sum(dim=1)**self.length_penalty + if num_choices is not None: + logits = logits.view(-1, num_choices) + return (logits, *mems) + + +class GLMForMultiTokenClozeFast(torch.nn.Module): + + def __init__(self, language_model, take_softmax=True, length_penalty=0.0): + super(GLMForMultiTokenClozeFast, self).__init__() + self.model = language_model + self.take_softmax = take_softmax + self.length_penalty = length_penalty + + def forward(self, input_ids, position_ids, attention_mask, dec_input_ids, + dec_position_ids, dec_attention_mask, dec_target_ids, + dec_logit_mask): + # encoder + outputs, *mems = self.model( + input_ids, + position_ids, + attention_mask, + return_memory=True, + detach_memory=False) + batch_size, num_choices, max_dec_len = dec_input_ids.size() + max_enc_len = input_ids.size(-1) + + enc_mems = [] + for hidden in mems: + hidden = hidden.unsqueeze(1).expand(-1, num_choices, -1, + -1).reshape( + batch_size * num_choices, + *hidden.size()[1:]) + enc_mems.append(hidden) + + def build_dec_mask_matrix(seq_length, sep, memory_length=0): + m = enc_mems[0].new_ones((1, seq_length, seq_length)) + m = torch.tril(m) + + # sep = dec_attention_mask + ids = torch.arange( + memory_length, device=sep.device, dtype=sep.dtype).view(1, -1) + mask = ids < sep.view(-1, 1) # batch * mem + mask = mask.unsqueeze(1).float().expand(-1, seq_length, -1) + + m = m.expand(batch_size * num_choices, -1, -1) + m = torch.cat((mask, m), dim=2) + m = m.unsqueeze(1) + return m + + dec_input_ids = dec_input_ids.reshape(-1, max_dec_len) + dec_position_ids = dec_position_ids.reshape( + -1, + *dec_position_ids.size()[2:]) + # dec_attention_mask = dec_attention_mask.reshape(-1, *dec_attention_mask.size()[2:]).unsqueeze(1) + dec_attention_mask = build_dec_mask_matrix( + max_dec_len, dec_attention_mask.reshape(-1), max_enc_len) + dec_target_ids = dec_target_ids.reshape(-1, dec_target_ids.size(-1)) + dec_logit_mask = dec_logit_mask.reshape(-1, dec_logit_mask.size(-1)) + + outputs, *mems = self.model(dec_input_ids, dec_position_ids, + dec_attention_mask, *enc_mems) + if self.take_softmax: + outputs = torch.nn.functional.log_softmax(outputs, dim=-1) + + batch_ids = torch.arange( + dec_target_ids.size(0), + dtype=torch.long, + device=dec_target_ids.device) + batch_ids = batch_ids.unsqueeze(1).expand_as(dec_target_ids) + seq_ids = torch.arange( + dec_target_ids.size(-1), + dtype=torch.long, + device=dec_target_ids.device) + seq_ids = seq_ids.unsqueeze(0).expand_as(dec_target_ids) + logits = outputs[batch_ids, seq_ids, dec_target_ids] + logits = (logits * dec_logit_mask).sum(dim=1) + if self.length_penalty > 0.0: + logits = logits / dec_logit_mask.sum(dim=1)**self.length_penalty + if num_choices is not None: + logits = logits.view(-1, num_choices) + return (logits, *mems) + + +class GLMForSingleTokenCloze(torch.nn.Module): + + def __init__(self, language_model, take_softmax=False): + super().__init__() + self.model = language_model + self.take_softmax = take_softmax + + def state_dict(self, destination=None, prefix='', keep_vars=False): + # [h.remove() for h in self.hook_handles] + sd = self.model.state_dict(destination, prefix, keep_vars) + return sd + + def load_state_dict(self, state_dict, strict=True): + return self.model.load_state_dict(state_dict, strict=strict) + + def named_parameters(self, prefix: str = '', recurse: bool = True): + return self.model.named_parameters(prefix=prefix, recurse=recurse) + + def forward(self, + input_ids, + position_ids, + attention_mask, + target_ids=None, + logit_mask=None, + prompt_pos=None): + if target_ids is None: + return self.model(input_ids, position_ids, attention_mask) + assert len(input_ids.shape) == 2 + outputs, *mems = self.model( + input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) + batch_ids = torch.arange( + outputs.size(0), + dtype=attention_mask.dtype, + device=attention_mask.device) + target_logits = outputs[batch_ids, attention_mask] + if self.take_softmax: + target_prob = torch.nn.functional.log_softmax( + target_logits, dim=-1) + else: + target_prob = target_logits + batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) + output = target_prob[batch_ids, target_ids] + + return (output, target_logits, *mems) + + +class GLMForSequenceClassification(torch.nn.Module): + + def __init__(self, + language_model, + hidden_size, + hidden_dropout, + pool_token, + num_class=1): + super().__init__() + self.pool_token = pool_token + self.model = language_model + self.num_class = num_class + # Multi-choice head. + self.pool_layer = torch.nn.Linear(hidden_size, hidden_size) + self.multichoice_dropout = torch.nn.Dropout(hidden_dropout) + self.multichoice_head = torch.nn.Linear(hidden_size, num_class) + + def forward(self, input_ids, position_ids, attention_mask): + num_choices = None + if len(input_ids.shape) == 3: + assert self.num_class == 1 + batch_size, num_choices = input_ids.shape[:2] + input_ids = input_ids.reshape(-1, input_ids.size(-1)) + attention_mask = attention_mask.reshape(-1, + *attention_mask.size()[2:]) + position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) + outputs, *mems = self.model(input_ids, position_ids, attention_mask) + if self.pool_token == 'start': + output = outputs[torch.arange( + outputs.size(0), + dtype=attention_mask.dtype, + device=attention_mask.device), attention_mask] + elif self.pool_token == 'pad': + output = outputs[torch.arange( + outputs.size(0), + dtype=attention_mask.dtype, + device=attention_mask.device), attention_mask - 1] + elif self.pool_token == 'cls': + output = outputs[:, 0] + else: + raise NotImplementedError + output = torch.tanh(self.pool_layer(output)) + multichoice_output = self.multichoice_dropout(output) + logits = self.multichoice_head(multichoice_output) + if num_choices is not None: + logits = logits.view(-1, num_choices) + return (logits, *mems) diff --git a/modelscope/models/nlp/mglm/model/modeling_bert.py b/modelscope/models/nlp/mglm/model/modeling_bert.py new file mode 100644 index 00000000..965f82a7 --- /dev/null +++ b/modelscope/models/nlp/mglm/model/modeling_bert.py @@ -0,0 +1,1576 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import copy +import logging +import math +import os +import shutil +import tarfile +import tempfile + +import json +import mpu +import torch +import torch.nn.functional as F +from data_utils.file_utils import cached_path +from torch import nn +from torch.nn import CrossEntropyLoss + +# from torch.utils.checkpoint import checkpoint + + +def normal_init_method(mean, std): + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +def scaled_init_method(mean, std, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = std / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + return extended_attention_mask + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': + '/root/data/bert-base-uncased.tar.gz', + 'bert-large-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz', + 'bert-base-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz', + 'bert-large-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz', + 'bert-base-multilingual-uncased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz', + 'bert-base-multilingual-cased': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz', + 'bert-base-chinese': + 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz', +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print( + 'Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print('Converting TensorFlow checkpoint from {}'.format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print('Loading TF weight {} with shape {}'.format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ['adam_v', 'adam_m'] for n in name): + print('Skipping {}'.format('/'.join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) # noqa + else: + l = [m_name] # noqa + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print('Initialize PyTorch weight {}'.format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + deep_init=False, + fp32_layernorm=False, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-12): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open( + vocab_size_or_config_json_file, 'r', + encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.deep_init = deep_init + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + else: + raise ValueError( + 'First argument must be either a vocabulary size (int)' + 'or the path to a pretrained model config file (str)') + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r', encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + +try: + from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm +except ImportError: + print( + 'Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.' + ) + + class BertLayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + # self.word_embeddings = mpu.VocabParallelEmbedding( + # config.vocab_size, config.hidden_size, + # init_method=normal_init_method(mean=0.0, + # std=config.initializer_range)) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.fp32_layernorm = config.fp32_layernorm + self.fp32_embedding = config.fp32_embedding + self.fp32_tokentypes = config.fp32_tokentypes + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + if not self.fp32_tokentypes: + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + if self.fp32_embedding and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_embedding: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + else: + embeddings = words_embeddings.float() + position_embeddings.float( + ) + token_type_embeddings.float() + if self.fp32_tokentypes and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_tokentypes: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + previous_type = attention_probs.type() # noqa + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( # noqa + mean=0.0, std=config.initializer_range) + self.dense = nn.Linear( + config.hidden_size, config.hidden_size, bias=True) + # self.dense = mpu.RowParallelLinear( + # input_size=config.hidden_size, + # output_size=config.hidden_size, + # bias=True, + # input_is_parallel=True, + # stride=1, + # init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + # self.self = mpu.BertParallelSelfAttention( + # hidden_size=config.hidden_size, + # num_attention_heads=config.num_attention_heads, + # dropout_prob=config.attention_probs_dropout_prob, + # output_parallel=True, + # init_method=normal_init_method(mean=0.0, + # std=config.initializer_range)) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear( + config.hidden_size, config.intermediate_size, bias=True) + # self.dense = mpu.ColumnParallelLinear( + # input_size=config.hidden_size, + # output_size=config.intermediate_size, + # bias=True, + # gather_output=False, + # stride=1, + # init_method=normal_init_method(mean=0.0, + # std=config.initializer_range)) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( # noqa + mean=0.0, std=config.initializer_range) + self.dense = nn.Linear( + config.intermediate_size, config.hidden_size, bias=True) + # self.dense = mpu.RowParallelLinear( + # input_size=config.intermediate_size, + # output_size=config.hidden_size, + # bias=True, + # input_is_parallel=True, + # stride=1, + # init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + # layer = BertLayer(config) + # self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + + # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + # all_encoder_layers = [] + # for layer_module in self.layer: + # hidden_states = layer_module(hidden_states, attention_mask) + # if output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # if not output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # return all_encoder_layers + def forward(self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + checkpoint_activations=False): + all_encoder_layers = [] + + def custom(start, end): + + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer(x_, inputs[1]) + return x_ + + return custom_forward + + if checkpoint_activations: + l = 0 # noqa + num_layers = len(self.layer) + chunk_length = 1 # math.ceil(math.sqrt(num_layers)) + while l < num_layers: + hidden_states = mpu.checkpoint( + custom(l, l + chunk_length), hidden_states, + attention_mask * 1) + l += chunk_length # noqa + # decoder layers + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attention_mask) + + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if not output_all_encoded_layers or checkpoint_activations: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + # self.decoder_weight = bert_model_embedding_weights + # self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + # self.bias.model_parallel = True + self.fp32_embedding = config.fp32_embedding + self.fp32_layernorm = config.fp32_layernorm + + def convert_to_type(tensor): + if self.fp32_embedding: + return tensor.half() + else: + return tensor + + self.type_converter = convert_to_type + self.converted = False + + def forward(self, hidden_states): + if not self.converted: + self.converted = True + if self.fp32_embedding: + self.transform.half() + if self.fp32_layernorm: + self.transform.LayerNorm.float() + hidden_states = self.transform(self.type_converter(hidden_states)) + hidden_states = self.decoder(hidden_states) + self.bias + # hidden_states = mpu.copy_to_model_parallel_region(hidden_states) + # hidden_states = F.linear(self.type_converter(hidden_states), + # self.type_converter(self.decoder_weight), + # self.type_converter(self.bias)) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + for p in self.seq_relationship.parameters(): + if p is None: + continue + pooled_output = pooled_output.type_as(p) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + 'Parameter config in `{}(config)` should be an instance of class `BertConfig`. ' + 'To create a model from a Google pretrained model use ' + '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( + self.__class__.__name__, self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, + pretrained_model_name, + state_dict=None, + cache_dir=None, + fp32_layernorm=False, + fp32_embedding=False, + layernorm_epsilon=1e-12, + fp32_tokentypes=False, + *inputs, + **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ # noqa + if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + else: + archive_file = pretrained_model_name + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path( + archive_file, cache_dir=cache_dir) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + 'associated to this path or url.'.format( + pretrained_model_name, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file)) + return None + if resolved_archive_file == archive_file: + logger.info('loading archive file {}'.format(archive_file)) + else: + logger.info('loading archive file {} from cache at {}'.format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file): + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info('extracting archive file {} to temp dir {}'.format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + config.fp32_layernorm = fp32_layernorm + config.fp32_embedding = fp32_embedding + config.layernorm_epsilon = layernorm_epsilon + config.fp32_tokentypes = fp32_tokentypes + logger.info('Model config {}'.format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, + True, missing_keys, unexpected_keys, + error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + print('Weights of {} not initialized from pretrained model: {}'. + format(model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print('Weights from pretrained model not used in {}: {}'.format( + model.__class__.__name__, unexpected_keys)) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + return model + + +class BertModel(PreTrainedBertModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ # noqa + + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + checkpoint_activations=False): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations) + sequence_output = encoded_layers[-1] + for p in self.pooler.parameters(): + if p is None: + continue + sequence_output = sequence_output.type_as(p) + break + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers or checkpoint_activations: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class BertForPreTraining(PreTrainedBertModel): + """BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads: + - the masked language modeling head, and + - the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `masked_lm_labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `masked_lm_labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForPreTraining(config) + masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertForPreTraining, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads( + config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + checkpoint_activations=False): + sequence_output, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output) + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size).float(), + masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2).float(), + next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + return total_loss + else: + return prediction_scores, seq_relationship_score + + +class BertForMaskedLM(PreTrainedBertModel): + """BERT model with the masked language modeling head. + This module comprises the BERT model followed by the masked language modeling head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `masked_lm_labels` is not `None`: + Outputs the masked language modeling loss. + if `masked_lm_labels` is `None`: + Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForMaskedLM(config) + masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config, + self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + prediction_scores = self.cls(sequence_output) + + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + masked_lm_labels.view(-1)) + return masked_lm_loss + else: + return prediction_scores + + +class BertForNextSentencePrediction(PreTrainedBertModel): + """BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `next_sentence_label` is not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `next_sentence_label` is `None`: + Outputs the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForNextSentencePrediction(config) + seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + next_sentence_label=None, + checkpoint_activations=False): + _, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + seq_relationship_score = self.cls(pooled_output) + + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), + next_sentence_label.view(-1)) + return next_sentence_loss + else: + return seq_relationship_score + + +class BertForSequenceClassification(PreTrainedBertModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, num_labels=2): + super(BertForSequenceClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + _, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForMultipleChoice(PreTrainedBertModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertForMultipleChoice, self).__init__(config) + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + batch_size, num_choices = input_ids.shape[:2] + flat_input_ids = input_ids.reshape(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.reshape(-1, + token_type_ids.size(-1)) + flat_attention_mask = attention_mask.reshape(-1, + attention_mask.size(-1)) + _, pooled_output = self.bert( + flat_input_ids, + flat_token_type_ids, + flat_attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.reshape(-1, num_choices) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return loss + else: + return reshaped_logits + + +class BertForTokenClassification(PreTrainedBertModel): + """BERT model for token-level classification. + This module is composed of the BERT model with a linear layer on top of + the full hidden state of the last layer. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForTokenClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, num_labels=2): + super(BertForTokenClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + # self.classifier = mpu.RowParallelLinear( + # input_size=config.hidden_size, + # output_size=num_labels, + # bias=True, + # input_is_parallel=True, + # stride=1, + # init_method=normal_init_method(mean=0.0, + # std=config.initializer_range)) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + with mpu.get_cuda_rng_tracker().fork(): + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForQuestionAnswering(PreTrainedBertModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens of shape [batch_size, sequence_length]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForQuestionAnswering(config) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.bert = BertModel(config) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + # self.qa_outputs = mpu.RowParallelLinear( + # input_size=config.hidden_size, + # output_size=2, + # bias=True, + # input_is_parallel=True, + # stride=1, + # init_method=normal_init_method(mean=0.0, + # std=config.initializer_range)) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + start_positions=None, + end_positions=None, + checkpoint_activations=False): + sequence_output, _ = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss + else: + return start_logits, end_logits diff --git a/modelscope/models/nlp/mglm/model/modeling_glm.py b/modelscope/models/nlp/mglm/model/modeling_glm.py new file mode 100644 index 00000000..80f61cef --- /dev/null +++ b/modelscope/models/nlp/mglm/model/modeling_glm.py @@ -0,0 +1,245 @@ +# Modified by Zhipu.AI +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GPT-2 model.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.nlp.mglm import mpu +from modelscope.models.nlp.mglm.model.prompt import PromptSpell +from modelscope.models.nlp.mglm.utils import print_rank_0 + + +def init_method_normal(std=0.02): + """Init method based on normal distribution. + + This is only used for embeddings. The transformer has its + own initializer. + """ + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GLMModel(torch.nn.Module): + """GLM Language model. + + The output of the forward method are the logits (parallel or + serial depending on the `parallel_output` flag. + """ + + def __init__( + self, + num_layers, + vocab_size, + hidden_size, + num_attention_heads, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + max_sequence_length, + max_memory_length, + checkpoint_activations, + checkpoint_num_layers=1, + parallel_output=True, + relative_encoding=False, + block_position_encoding=False, + output_predict=True, + spell_length=None, + spell_func='lstm', + attention_scale=1.0, + ): + + super(GLMModel, self).__init__() + + self.parallel_output = parallel_output + self.output_predict = output_predict + self.hidden_size = hidden_size + + init_method = init_method_normal(std=0.02) + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init_method) + + # Transformer + self.transformer = mpu.GPT2ParallelTransformer( + num_layers, + hidden_size, + num_attention_heads, + max_sequence_length, + max_memory_length, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + checkpoint_activations, + checkpoint_num_layers, + attention_scale=attention_scale, + relative_encoding=relative_encoding, + block_position_encoding=block_position_encoding) + if spell_length is not None: + self.prompt_spell = PromptSpell(spell_length, self.hidden_size, + spell_func) + + def freeze_transformer(self, tune_prefix_layers=None): + log_str = 'Freeze transformer' + self.word_embeddings.requires_grad_(False) + self.transformer.requires_grad_(False) + if tune_prefix_layers is not None: + log_str += f' tune {tune_prefix_layers} prefix layers' + for i in range(tune_prefix_layers): + self.transformer.layers[i].requires_grad_(True) + print_rank_0(log_str) + + def forward(self, + input_ids, + position_ids, + attention_mask, + *mems, + return_memory=False, + detach_memory=True, + prompt_pos=None): + # Embeddings. + batch_size = input_ids.size(0) + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + if prompt_pos is not None: + embeddings = embeddings.clone() + prompt_embeds = self.prompt_spell() + batch_index = torch.arange( + batch_size, device=input_ids.device).unsqueeze(1) + embeddings[batch_index, prompt_pos] = prompt_embeds + # Transformer. + transformer_output = self.transformer( + embeddings, + position_ids, + attention_mask, + mems, + return_memory=return_memory, + detach_memory=detach_memory) + logits, hidden_layers = transformer_output + outputs = hidden_layers + + if self.output_predict: + # Parallel logits. + logits_parallel = mpu.copy_to_model_parallel_region(logits) + logits_parallel = F.linear(logits_parallel, + self.word_embeddings.weight) + + if self.parallel_output: + return (logits_parallel, *outputs) + + return (mpu.gather_from_model_parallel_region(logits_parallel), + *outputs) + else: + return (logits, *outputs) + + +class EncoderDecoder(torch.nn.Module): + """Seq2Seq Transformer Model + The output of the forward method are the logits (parallel or serial depending on the `parallel_output` flag). + """ + + def __init__(self, + num_layers, + vocab_size, + hidden_size, + num_attention_heads, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + max_sequence_length, + max_memory_length, + checkpoint_activations, + checkpoint_num_layers=1, + parallel_output=True, + output_predict=True): + super(EncoderDecoder, self).__init__() + + self.parallel_output = parallel_output + self.output_predict = output_predict + + init_method = init_method_normal(std=0.02) + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init_method) + + # Transformer + self.encoder = mpu.GPT2ParallelTransformer( + num_layers, hidden_size, num_attention_heads, max_sequence_length, + max_memory_length, embedding_dropout_prob, attention_dropout_prob, + output_dropout_prob, checkpoint_activations, checkpoint_num_layers) + self.decoder = mpu.GPT2ParallelTransformer( + num_layers, + hidden_size, + num_attention_heads, + max_sequence_length, + max_memory_length, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + checkpoint_activations, + checkpoint_num_layers, + use_decoder_layer=True) + + def forward(self, source_ids, target_ids, source_position_ids, + target_position_ids, source_mask, target_mask): + # Embeddings. + source_embeddings = self.word_embeddings(source_ids) + target_embeddings = self.word_embeddings(target_ids) + + # Transformer. + encoder_output, _ = self.encoder(source_embeddings, + source_position_ids, source_mask) + decoder_output, _ = self.decoder(target_embeddings, + target_position_ids, target_mask) + if self.output_predict: + # Parallel logits. + output_parallel = mpu.copy_to_model_parallel_region(decoder_output) + logits_parallel = F.linear(output_parallel, + self.word_embeddings.weight) + + if self.parallel_output: + return (logits_parallel, ) + + return (mpu.gather_from_model_parallel_region(logits_parallel), ) + else: + return (decoder_output, ) + + +def glm_get_params_for_weight_decay_optimization(module): + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in module.modules(): + if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): + no_weight_decay_params['params'].extend([ + p for p in list(module_._parameters.values()) + if p is not None and p.requires_grad + ]) + else: + weight_decay_params['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and p.requires_grad and n != 'bias' + ]) + no_weight_decay_params['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and p.requires_grad and n == 'bias' + ]) + + return weight_decay_params, no_weight_decay_params diff --git a/modelscope/models/nlp/mglm/model/prompt.py b/modelscope/models/nlp/mglm/model/prompt.py new file mode 100644 index 00000000..a29ceda0 --- /dev/null +++ b/modelscope/models/nlp/mglm/model/prompt.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 Zhipu.AI + +import random + +import torch + + +class PromptSpell(torch.nn.Module): + + def __init__(self, spell_length, hidden_size, spell_func): + super(PromptSpell, self).__init__() + self.spell_length = spell_length + self.hidden_size = hidden_size + self.spell_embeddings = torch.nn.Embedding(self.spell_length, + self.hidden_size) + self.spell_func = spell_func + if self.spell_func == 'lstm': + self.lstm_head = torch.nn.LSTM( + input_size=self.hidden_size, + hidden_size=self.hidden_size, + num_layers=2, + # dropout=self.lstm_dropout, + bidirectional=True, + batch_first=True) # .to(torch.device("cuda")) + self.mlp_head = torch.nn.Sequential( + torch.nn.Linear(2 * self.hidden_size, self.hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(self.hidden_size, self.hidden_size)) + elif self.spell_func == 'mlp': + self.mlp_head = torch.nn.Sequential( + torch.nn.Linear(self.hidden_size, self.hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(self.hidden_size, self.hidden_size)) + elif self.spell_func != 'none': + raise NotImplementedError('Prompt function ' + self.spell_func) + + def init_embedding(self, word_embeddings=None, task_tokens=None): + num_words = 5000 + with torch.no_grad(): + for i in range(self.spell_length): + rand_token = random.randrange(num_words) + if task_tokens is None: + target_embedding = word_embeddings[rand_token] + else: + word_embedding = word_embeddings[rand_token] + task_token = random.choice(task_tokens) + task_embedding = word_embeddings[task_token] + ratio = random.random() + target_embedding = word_embedding * ratio + task_embedding * ( + 1 - ratio) + self.spell_embeddings.weight.data[i] = target_embedding + + def forward(self): + prompt_embeds = self.spell_embeddings.weight.unsqueeze(0) + if self.spell_func == 'lstm': + prompt_embeds = self.lstm_head(prompt_embeds)[0] + if self.spell_func == 'lstm' or self.spell_func == 'mlp': + prompt_embeds = self.mlp_head(prompt_embeds) + return prompt_embeds diff --git a/modelscope/models/nlp/mglm/mpu/__init__.py b/modelscope/models/nlp/mglm/mpu/__init__.py new file mode 100755 index 00000000..8cca4e2c --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model parallel utility interface.""" + +from .cross_entropy import vocab_parallel_cross_entropy +from .data import broadcast_data +from .grads import clip_grad_norm +from .initialize import (destroy_model_parallel, get_data_parallel_group, + get_data_parallel_rank, get_data_parallel_world_size, + get_model_parallel_group, get_model_parallel_rank, + get_model_parallel_src_rank, + get_model_parallel_world_size, + initialize_model_parallel, + model_parallel_is_initialized) +from .layers import (ColumnParallelLinear, ParallelEmbedding, + RowParallelLinear, VocabParallelEmbedding) +from .mappings import (copy_to_model_parallel_region, + gather_from_model_parallel_region, + reduce_from_model_parallel_region, + scatter_to_model_parallel_region) +from .random import (checkpoint, get_cuda_rng_tracker, + model_parallel_cuda_manual_seed, + partition_activations_in_checkpoint) +from .transformer import (BertParallelSelfAttention, + BertParallelTransformerLayer, + GPT2ParallelTransformer, LayerNorm) diff --git a/modelscope/models/nlp/mglm/mpu/cross_entropy.py b/modelscope/models/nlp/mglm/mpu/cross_entropy.py new file mode 100644 index 00000000..2ebcf7a8 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/cross_entropy.py @@ -0,0 +1,110 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import (get_model_parallel_group, get_model_parallel_rank, + get_model_parallel_world_size) +from .utils import VocabUtility + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + + # Copy so the input remains unchanged. + logits = vocab_parallel_logits.clone() + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_model_parallel_group()) + # Subtract the maximum value. + logits.sub_(logits_max.unsqueeze(dim=-1)) + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = logits.exp() + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + + # Get the partition's vocab indecies + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_model_parallel_rank() + world_size = get_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | ( + target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange( + start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce( + predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange( + start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, + masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_parallel_cross_entropy(vocab_parallel_logits, target): + """Helper function for the cross entropy.""" + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/modelscope/models/nlp/mglm/mpu/data.py b/modelscope/models/nlp/mglm/mpu/data.py new file mode 100644 index 00000000..6f595f0f --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/data.py @@ -0,0 +1,117 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import (get_model_parallel_group, get_model_parallel_rank, + get_model_parallel_src_rank) + +_MAX_DATA_DIM = 5 + + +def _check_data_types(keys, data, target_dtype): + """Check that all the keys have the same target data type.""" + for key in keys: + assert data[key].dtype == target_dtype, '{} has data type {} which '\ + 'is different than {}'.format(key, data[key].dtype, target_dtype) + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if get_model_parallel_rank() == 0: + offset = 0 + for key in keys: + assert data[key].dim( + ) < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast( + sizes_cuda, + get_model_parallel_src_rank(), + group=get_model_parallel_group()) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data disctionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries( + keys, data) + + # Pack on rank zero. + if get_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty( + total_numel, device=torch.cuda.current_device(), dtype=datatype) + + # Boradcast + torch.distributed.broadcast( + flatten_data, + get_model_parallel_src_rank(), + group=get_model_parallel_group()) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/modelscope/models/nlp/mglm/mpu/grads.py b/modelscope/models/nlp/mglm/mpu/grads.py new file mode 100644 index 00000000..a7dc6c5c --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/grads.py @@ -0,0 +1,72 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import torch +from torch._six import inf + +from .initialize import get_model_parallel_group, get_model_parallel_rank + + +def clip_grad_norm(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + max_norm = float(max_norm) + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all GPUs. + torch.distributed.all_reduce( + total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=get_model_parallel_group()) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0 + for p in parameters: + if p.model_parallel or (get_model_parallel_rank() == 0): + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item()**norm_type + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce( + total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1: + for p in parameters: + p.grad.data.mul_(clip_coef) + return total_norm diff --git a/modelscope/models/nlp/mglm/mpu/initialize.py b/modelscope/models/nlp/mglm/mpu/initialize.py new file mode 100644 index 00000000..33f8dbda --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/initialize.py @@ -0,0 +1,130 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model and data parallel groups.""" + +import torch + +from .utils import ensure_divisibility + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + + +def initialize_model_parallel(model_parallel_size_): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel grous as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + if torch.distributed.get_rank() == 0: + print('> initializing model parallel with size {}'.format( + model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + ensure_divisibility(world_size, model_parallel_size) + rank = torch.distributed.get_rank() + + # Build the data parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % model_parallel_size): + _DATA_PARALLEL_GROUP = group + + # Build the model parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, \ + 'model parallel group is already initialized' + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // model_parallel_size): + _MODEL_PARALLEL_GROUP = group + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_model_parallel_world_size(): + """Return world size for the model parallel group.""" + return torch.distributed.get_world_size(group=get_model_parallel_group()) + + +def get_model_parallel_rank(): + """Return my rank for the model parallel group.""" + return torch.distributed.get_rank(group=get_model_parallel_group()) + + +def get_model_parallel_src_rank(): + """Calculate the global rank corresponding to a local rank zeor + in the model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None diff --git a/modelscope/models/nlp/mglm/mpu/layers.py b/modelscope/models/nlp/mglm/mpu/layers.py new file mode 100644 index 00000000..4eb94b50 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/layers.py @@ -0,0 +1,357 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import math + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm +from torch.nn.parameter import Parameter + +from .initialize import get_model_parallel_rank, get_model_parallel_world_size +from .mappings import (copy_to_model_parallel_region, + gather_from_model_parallel_region, + reduce_from_model_parallel_region, + scatter_to_model_parallel_region) +from .random import get_cuda_rng_tracker +from .utils import VocabUtility, divide, split_tensor_along_last_dim + + +def _initialize_affine_weight(weight, + output_size, + input_size, + per_partition_size, + partition_dim, + init_method, + stride=1, + return_master_weight=False): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + # If we only use 1 process for model parallelism, bypass scatter. + world_size = get_model_parallel_world_size() + if world_size == 1: + init_method(weight) + if return_master_weight: + return weight + return None + + # Initialize master weight + master_weight = torch.empty( + output_size, input_size, dtype=weight.dtype, requires_grad=False) + init_method(master_weight) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split( + master_weight, per_partition_per_stride_size, dim=partition_dim) + rank = get_model_parallel_rank() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__(self, + num_embeddings, + embedding_dim, + init_method=init.xavier_normal_): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = \ + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_model_parallel_rank(), + get_model_parallel_world_size()) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # noqa + + # Allocate weights. + self.weight = Parameter( + torch.Tensor(self.num_embeddings_per_partition, + self.embedding_dim)) + self.weight.model_parallel = True + # And initialize. + _initialize_affine_weight(self.weight, self.num_embeddings, + self.embedding_dim, + self.num_embeddings_per_partition, 0, + init_method) + + def forward(self, input_): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + # Mask the output embedding. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_model_parallel_region(output_parallel) + return output + + +class ParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the embedding dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + + def __init__(self, + num_embeddings, + embedding_dim, + init_method=init.xavier_normal_, + keep_master_weight_for_test=False): + super(ParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set some detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + # Divide the weight matrix along the embedding dimension. + world_size = get_model_parallel_world_size() + self.embedding_dim_per_partition = divide(self.embedding_dim, + world_size) + + # Allocate weights. + self.weight = Parameter( + torch.Tensor(self.num_embeddings, + self.embedding_dim_per_partition)) + self.weight.model_parallel = True + # And initialize. + _initialize_affine_weight( + self.weight, + self.num_embeddings, + self.embedding_dim, + self.embedding_dim_per_partition, + 1, + init_method, + stride=1, + return_master_weight=False) + + def forward(self, input_): + input_parallel = copy_to_model_parallel_region(input_) + output_parallel = F.embedding(input_parallel, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + output = gather_from_model_parallel_region(output_parallel) + return output + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gether on output and make Y avaiable + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + + def __init__(self, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter( + torch.Tensor(self.output_size_per_partition, self.input_size)) + self.weight.model_parallel = True + if bias: + self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) + self.bias.model_parallel = True + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test) + + def forward(self, input_): + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + + def __init__(self, + input_size, + output_size, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter( + torch.Tensor(self.output_size, self.input_size_per_partition)) + self.weight.model_parallel = True + if bias: + self.bias = Parameter(torch.Tensor(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test) + + def forward(self, input_): + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + input_parallel = scatter_to_model_parallel_region(input_) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight) + # All-reduce across all the partitions. + output_ = reduce_from_model_parallel_region(output_parallel) + if self.bias is not None: + output = output_ + self.bias + else: + output = output_ + return output diff --git a/modelscope/models/nlp/mglm/mpu/mappings.py b/modelscope/models/nlp/mglm/mpu/mappings.py new file mode 100644 index 00000000..b3056dd7 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/mappings.py @@ -0,0 +1,144 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import get_model_parallel_group +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the the input tensor across model parallel group.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=group) + + return input_ + + +def _split(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # Split along last dimension. + world_size = torch.distributed.get_world_size(group=group) + input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = torch.distributed.get_rank(group=group) + output = input_list[rank].contiguous() + + return output + + +def _gather(input_): + """Gather tensors and concatinate along the last dimension.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-redcue the input from the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def forward(ctx, input_): + return _split(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def forward(ctx, input_): + return _gather(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + + +def copy_to_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) diff --git a/modelscope/models/nlp/mglm/mpu/random.py b/modelscope/models/nlp/mglm/mpu/random.py new file mode 100755 index 00000000..2cdf236d --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/random.py @@ -0,0 +1,408 @@ +# Modified by Samyam Rajbhandari +# Used to partition the activations stored for backward propagation +# Therefore reduces the memory consumption + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +import contextlib + +import torch +import torch.distributed as dist +from torch import _C +from torch.cuda import _lazy_call +from torch.cuda import device as device_ctx_manager + +from .initialize import (get_data_parallel_rank, get_model_parallel_group, + get_model_parallel_rank, + get_model_parallel_world_size) + +# from torch.utils.checkpoint import detach_variable + +PARTITION_ACTIVATIONS = False +PA_CORRECTNESS_TEST = False + + +def see_memory_usage(message, force=False): + if not force: + return + dist.barrier() + if dist.get_rank() == 0: + print(message) + print('Memory Allocated ', + torch.cuda.memory_allocated() / (1024 * 1024 * 1024), + 'GigaBytes') + print('Max Memory Allocated ', + torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), + 'GigaBytes') + print('Cache Allocated ', + torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') + print('Max cache Allocated ', + torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), + 'GigaBytes') + print(' ') + # input("Press Any Key To Continue ..") + + +mp_rank = None # get_model_parallel_rank() +mp_size = None # get_model_parallel_world_size() +mp_group = None # get_model_parallel_group() + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' +transport_stream = None +cuda_device = None + + +def detach_variable(inputs, device=None): + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + requires_grad = inp.requires_grad + + if device is not None: + x = inp.to(device=device) + else: + x = inp + + x = x.detach() + x.requires_grad = requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + 'Only tuple of tensors is supported. Got Unsupported input type: ', + type(inputs).__name__) + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-model-parallel regions. + model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + model_parallel_seed = offset + get_model_parallel_rank() + # Data parallel gets the original sedd. + data_parallel_seed = seed + + if torch.distributed.get_rank() == 0: + print( + '> initializing model parallel cuda seeds on global rank {}, ' + 'model parallel rank {}, and data parallel rank {} with ' + 'model parallel seed: {} and data parallel seed: {}'.format( + torch.distributed.get_rank(), get_model_parallel_rank(), + get_data_parallel_rank(), model_parallel_seed, + data_parallel_seed), + flush=True) + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, + model_parallel_seed) + + +def get_partition_start(item): + global mp_rank, mp_size, mp_group + partition_size = get_partition_size(item) + start = partition_size * mp_rank + return int(start) + + +def get_partition_size(item): + global mp_rank, mp_size, mp_group + size = item.numel() + partition_size = size / mp_size + return int(partition_size) + + +def get_full_inputs(tensors): + inputs = [] + for i in range(int(len(tensors) / 2) - 1): + item = tensors[2 * i] + size = tensors[2 * i + 1] + partition_size = item.numel() + tensor_size = partition_size * mp_size + flat_tensor = torch.zeros([tensor_size], + dtype=item.dtype, + device=item.device) + partitions = [] + for i in range(mp_size): + part_i = flat_tensor.narrow(0, partition_size * i, partition_size) + if i == mp_rank: + part_i.copy_(item) + partitions.append(part_i) + dist.all_gather(partitions, partitions[mp_rank], group=mp_group) + input_tensor = flat_tensor.view(list(size.numpy())) + item.data = input_tensor.data + + inputs.append(item) + inputs.append(tensors[-2]) + + return tuple(inputs) + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + @staticmethod + def forward(ctx, run_function, *args): + ctx.run_function = run_function + global mp_rank, mp_size, mp_group + if mp_rank is None: + mp_rank = get_model_parallel_rank() + mp_size = get_model_parallel_world_size() + mp_group = get_model_parallel_group() + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + if cuda_device is None: + if dist.get_rank() == 0: + print( + f'Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}' + ) + + cuda_device = torch.cuda.current_device() + # The transport stream is used to overlap the allgather communication for the activations + # with the computation in the backward pass + transport_stream = torch.cuda.Stream(device=cuda_device) + + if PARTITION_ACTIVATIONS: + inputs = [ + item.detach().contiguous().view(-1).narrow( + 0, get_partition_start(item), + get_partition_size(item)).clone() for item in args[:-1] + ] + inputs.append(args[-1]) + + # just in case something funky is happening such as reuse of inputs + inputs_cuda = [item.to(cuda_device) for item in args] + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # ctx.save_for_backward(*args) + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + + del inputs_cuda + + if PARTITION_ACTIVATIONS: + new_args = [] + for arg, inp in zip(args, inputs): + size = torch.tensor(arg.size()) + arg.data = inp.data + new_args.append(arg) + new_args.append(size) + ctx.save_for_backward(*new_args) + else: + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError('Checkpointing is not compatible with .grad(), ' + 'please use .backward() if possible') + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + + if PARTITION_ACTIVATIONS: + with torch.cuda.stream(transport_stream): + inputs = get_full_inputs(ctx.saved_tensors) + detached_inputs = detach_variable(inputs) + else: + inputs = ctx.saved_tensors + detached_inputs = detach_variable(inputs) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + if PARTITION_ACTIVATIONS: + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(transport_stream) + + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs, ) + torch.autograd.backward(outputs, args) + return (None, ) + tuple(inp.grad for inp in detached_inputs) + + +def checkpoint(function, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, *args) + + +def partition_activations_in_checkpoint(partition_activation): + global PARTITION_ACTIVATIONS + PARTITION_ACTIVATIONS = partition_activation + if dist.get_rank() == 0: + print( + f'**************Partition Activations {PARTITION_ACTIVATIONS}************' + ) diff --git a/modelscope/models/nlp/mglm/mpu/tests/__init__.py b/modelscope/models/nlp/mglm/mpu/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/mglm/mpu/tests/commons.py b/modelscope/models/nlp/mglm/mpu/tests/commons.py new file mode 100644 index 00000000..ecfd5e72 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/commons.py @@ -0,0 +1,86 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random + +import mpu +import numpy +import torch + + +class IdentityLayer(torch.nn.Module): + + def __init__(self, size, scale=1.0): + super(IdentityLayer, self).__init__() + self.weight = torch.nn.Parameter(scale * torch.randn(size)) + + def forward(self): + return self.weight + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def initialize_distributed(backend='nccl'): + """Initialize torch.distributed.""" + # Get local rank in case it is provided. + parser = argparse.ArgumentParser() + parser.add_argument( + '--local_rank', + type=int, + default=None, + help='local rank passed from distributed launcher') + args = parser.parse_args() + local_rank = args.local_rank + + # Get rank and world size. + rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv('WORLD_SIZE', '1')) + + print('> initializing torch.distributed with local rank: {}, ' + 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) + + # Set the device id. + device = rank % torch.cuda.device_count() + if local_rank is not None: + device = local_rank + torch.cuda.set_device(device) + + # Call the init process. + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method) + + +def print_separator(message): + torch.distributed.barrier() + filler_len = (78 - len(message)) // 2 + filler = '-' * filler_len + string = '\n' + filler + ' {} '.format(message) + filler + if torch.distributed.get_rank() == 0: + print(string, flush=True) + torch.distributed.barrier() diff --git a/modelscope/models/nlp/mglm/mpu/tests/test_cross_entropy.py b/modelscope/models/nlp/mglm/mpu/tests/test_cross_entropy.py new file mode 100644 index 00000000..47fd1d7e --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/test_cross_entropy.py @@ -0,0 +1,106 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys + +import mpu +import torch +import torch.nn.functional as F +from commons import (IdentityLayer, initialize_distributed, print_separator, + set_random_seed) +from mpu.cross_entropy import vocab_parallel_cross_entropy + +sys.path.append('../..') + + +def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, + seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + target = torch.cuda.LongTensor(size=(batch_size, + seq_length)).random_(0, vocab_size) + loss = F.cross_entropy( + logits.view(-1, + logits.size()[-1]), target.view(-1), + reduction='none').view_as(target).mean() + loss.backward() + return loss, identity.weight.grad + + +def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + logits_parallel = mpu.scatter_to_model_parallel_region(logits) + target = torch.cuda.LongTensor(size=(batch_size, + seq_length)).random_(0, vocab_size) + loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() + loss.backward() + return loss, identity.weight.grad + + +def test_cross_entropy(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cross entropy with model parallel size {} ...'.format( + model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + batch_size = 13 + seq_length = 17 + vocab_size_per_partition = 11 + logits_scale = 1000.0 + vocab_size = vocab_size_per_partition * model_parallel_size + seed = 1234 + + loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed) + + error = loss_torch.sub_(loss_mpu).abs().max() + print(' max error in loss on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = grad_torch.sub_(grad_mpu).abs().max() + print(' max error in grad on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test cross entropy') + test_cross_entropy(model_parallel_size) + model_parallel_size *= 2 diff --git a/modelscope/models/nlp/mglm/mpu/tests/test_data.py b/modelscope/models/nlp/mglm/mpu/tests/test_data.py new file mode 100644 index 00000000..66575300 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/test_data.py @@ -0,0 +1,91 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import operator +import sys + +import mpu +import torch +from commons import initialize_distributed, print_separator +from mpu import data as data_utils + +sys.path.append('../..') + + +def test_boradcast_data(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print( + '> testing boradcast_data with model parallel size {} ...'.format( + model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + torch.manual_seed(1234 + mpu.get_data_parallel_rank()) + model_parallel_size = mpu.get_model_parallel_world_size() + + key_size_t = { + 'key1': [7, 11], + 'key2': [8, 2, 1], + 'key3': [13], + 'key4': [5, 1, 2], + 'key5': [5, 12] + } + keys = list(key_size_t.keys()) + + data = {} + data_t = {} + for key in key_size_t: + data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) + data_t[key] = data[key].clone() + data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) + data_t['keyX'] = data['keyX'].clone() + if mpu.get_model_parallel_rank() != 0: + data = None + + data_utils._check_data_types(keys, data_t, torch.int64) + key_size, key_numel, \ + total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) + for key in keys: + assert key_size[key] == key_size_t[key] + total_numel_t = 0 + for key in keys: + target_size = functools.reduce(operator.mul, key_size_t[key], 1) + assert key_numel[key] == target_size + total_numel_t += target_size + assert total_numel == total_numel_t + + data_b = data_utils.broadcast_data(keys, data, torch.int64) + for key in keys: + tensor = data_t[key].cuda() + assert data_b[key].sub(tensor).abs().max() == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test test boradcast data') + test_boradcast_data(model_parallel_size) + model_parallel_size *= 2 diff --git a/modelscope/models/nlp/mglm/mpu/tests/test_initialize.py b/modelscope/models/nlp/mglm/mpu/tests/test_initialize.py new file mode 100644 index 00000000..df62d213 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/test_initialize.py @@ -0,0 +1,95 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import mpu +import torch +from commons import initialize_distributed, print_separator + +sys.path.append('../..') + + +def test_initialize_model_parallel(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing initialize_model_parallel with size {} ...'.format( + model_parallel_size)) + model_parallel_size_ = min(model_parallel_size, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(model_parallel_size_) + assert mpu.model_parallel_is_initialized() + + # Checks. + def check(group, world_size, rank): + assert world_size == torch.distributed.get_world_size(group=group) + assert rank == torch.distributed.get_rank(group=group) + + # Model parallel. + world_size = model_parallel_size_ + rank = torch.distributed.get_rank() % model_parallel_size_ + assert world_size == mpu.get_model_parallel_world_size() + assert rank == mpu.get_model_parallel_rank() + check(mpu.get_model_parallel_group(), world_size, rank) + + # Data parallel. + world_size = torch.distributed.get_world_size() // model_parallel_size_ + rank = torch.distributed.get_rank() // model_parallel_size + assert world_size == mpu.get_data_parallel_world_size() + assert rank == mpu.get_data_parallel_rank() + check(mpu.get_data_parallel_group(), world_size, rank) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_get_model_parallel_src_rank(model_parallel_size_): + + if torch.distributed.get_rank() == 0: + print('> testing get_model_parallel_src_rank with size {} ...'.format( + model_parallel_size_)) + model_parallel_size = min(model_parallel_size_, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(model_parallel_size) + assert mpu.model_parallel_is_initialized() + + # Checks + src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank() + assert mpu.get_model_parallel_src_rank() == src_rank + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test initialize model parallel') + test_initialize_model_parallel(model_parallel_size) + print_separator('test model parallel source rank') + test_get_model_parallel_src_rank(model_parallel_size) + model_parallel_size *= 2 diff --git a/modelscope/models/nlp/mglm/mpu/tests/test_layers.py b/modelscope/models/nlp/mglm/mpu/tests/test_layers.py new file mode 100644 index 00000000..2dbc987a --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/test_layers.py @@ -0,0 +1,533 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys + +import mpu +import torch +import torch.nn.init as init +from commons import initialize_distributed, print_separator, set_random_seed +from mpu import layers +from torch.nn.parameter import Parameter + +sys.path.append('../..') + + +def test_parallel_embedding(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing parallel embedding with model parallel size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + batch_size = 17 + seq_length = 23 + vocab_size = 48 + hidden_size = 16 + seed = 1236 + + set_random_seed(123) + input_data = torch.LongTensor(size=(batch_size, seq_length)).random_( + 0, vocab_size).cuda() + loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() + + set_random_seed(seed) + embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() + + output = embedding_original(input_data) + loss_original = torch.mul(output, loss_weight).sum() + loss_original.backward() + + set_random_seed(seed) + embedding_parallel = layers.ParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_parallel(input_data) + loss_parallel = torch.mul(output, loss_weight).sum() + loss_parallel.backward() + + set_random_seed(seed) + embedding_vocab_parallel = layers.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_vocab_parallel(input_data) + loss_vocab_parallel = torch.mul(output, loss_weight).sum() + loss_vocab_parallel.backward() + + torch.distributed.barrier() + error = loss_parallel.sub(loss_original).abs() + print(' error in loss (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + torch.distributed.barrier() + error = loss_vocab_parallel.sub(loss_original).abs() + print(' error in loss (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + hidden_size // model_parallel_size, + 1)[mpu.get_model_parallel_rank()] + error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() + print(' error in grad (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + vocab_size // model_parallel_size, + 0)[mpu.get_model_parallel_rank()] + error = embedding_vocab_parallel.weight.grad.sub( + weight_grad_orig).abs().max() + print(' error in grad (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_initialize_affine_weight(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing initialize_affine_weight with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + + # --------------- + # Column parallel + # --------------- + weight = torch.empty(output_size_coeff, input_size) + set_random_seed(seed) + layers._initialize_affine_weight(weight, output_size, input_size, + output_size_coeff, 0, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_model_parallel_rank() + my_weight = torch.split( + master_weight, output_size_coeff, dim=0)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' column parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # ------------ + # Row parallel + # ------------ + weight = torch.empty(output_size, input_size_coeff) + set_random_seed(seed) + mpu.layers._initialize_affine_weight(weight, output_size, input_size, + input_size_coeff, 1, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_model_parallel_rank() + my_weight = torch.split( + master_weight, input_size_coeff, dim=1)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' row parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer2D(torch.nn.Module): + + def __init__(self, m, n): + super(IdentityLayer2D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def test_column_parallel_linear(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing ColumnParallelLinear with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.ColumnParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_model_parallel_rank() + my_dLdA = torch.split( + dLdA, output_size_coeff, dim=0)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + my_dLdb = torch.split( + dLdb, output_size_coeff, dim=0)[rank].contiguous().clone() + error = my_dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def test_row_parallel_linear(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing RowParallelLinear with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.RowParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_model_parallel_rank() + my_dLdA = torch.split( + dLdA, input_size_coeff, dim=1)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer3D(torch.nn.Module): + + def __init__(self, m, n, k): + super(IdentityLayer3D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n, k)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, + sequence_length): + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( + ) # noqa + hidden_size = hidden_size_per_att_head * num_att_heads + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, + dropout_prob).cuda() + loss_weight = torch.randn([batch_size, sequence_length, + hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = attention_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, model_parallel_size, loss, \ + attention_layer, identity_layer + + +def test_parallel_self_attention(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelSelfAttention with model parallel ' + 'size: {}'.format(model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + dropout_prob = 0.0 # has to be zero + batch_size = 5 + sequence_length = 13 + + rank_1, hideen_size_1, model_parallel_size_1, loss_1, \ + attention_layer_1, identity_layer_1 = parallel_self_attention( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + + rank, hidden_size, model_parallel_size, loss, \ + attention_layer, identity_layer = parallel_self_attention( + model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + assert hideen_size_1 == hidden_size + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + my_lin_grad_list = torch.split( + attention_layer_1.query_key_value.weight.grad, + hidden_size // model_parallel_size, 0)[rank::model_parallel_size] + my_lin_grad = torch.cat(my_lin_grad_list, dim=0) + error = my_lin_grad.sub( + attention_layer.query_key_value.weight.grad).abs().max() + torch.distributed.barrier() + print(' weight gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def parallel_transformer(model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, + sequence_length): + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( + ) + hidden_size = hidden_size_per_att_head * num_att_heads + intermediate_size = 4 * hidden_size + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + transformer_layer = mpu.BertParallelTransformerLayer( + hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, + torch.nn.functional.relu, 1.0e-5).cuda() + + loss_weight = torch.randn([batch_size, sequence_length, + hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = transformer_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, model_parallel_size, loss, \ + transformer_layer, identity_layer + + +def test_parallel_transformer_layer(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelTransformerLayer with model parallel ' + 'size: {}'.format(model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ + transformer_layer_1, identity_layer_1 = parallel_transformer( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + rank, hidden_size, model_parallel_size, loss, \ + transformer_layer, identity_layer = parallel_transformer( + model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +if __name__ == '__main__': + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + print_separator('test initialize affine weight') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_initialize_affine_weight(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test parallel embedding') + test_parallel_embedding(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test column-parallel linear') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_column_parallel_linear(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test row-parallel linear') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_row_parallel_linear(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test parallel self-attention') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_parallel_self_attention(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test parallel transformer') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_parallel_transformer_layer(model_parallel_size) + model_parallel_size *= 2 diff --git a/modelscope/models/nlp/mglm/mpu/tests/test_random.py b/modelscope/models/nlp/mglm/mpu/tests/test_random.py new file mode 100644 index 00000000..55cc2351 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/tests/test_random.py @@ -0,0 +1,206 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import mpu +import torch +from commons import initialize_distributed, print_separator + +sys.path.append('../..') + + +def test_set_cuda_rng_state(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing set_rng_state with size {} ...'.format( + model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + size = 123 + seed = 1234 + torch.cuda.manual_seed(seed) + tensor = torch.cuda.FloatTensor(size) + + # Get the state + rng_state = torch.cuda.get_rng_state() + rng_state_copy = rng_state.clone() + + # Do some stuff. + for _ in range(5): + torch.randn(size, out=tensor) + result_1 = tensor.clone() + + assert rng_state.sub(rng_state_copy).max() == 0 + assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 + + # State should be different. + new_rng_state = torch.cuda.get_rng_state() + max_diff = new_rng_state.sub(rng_state).max() + print( + ' max diff in rng state (should be non-zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), max_diff)) + assert max_diff > 0 + + # Reset the rng state and do the same stuff. + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + result_2 = tensor.clone() + + # Results should be the same + error = result_2.sub(result_1).abs().max() + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Input state should have remained intact. + error = rng_state.sub(rng_state_copy).max() + print(' max error in rng state (should be zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), error)) + assert error == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_cuda_rng_tracker(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cuda rng tracker with size {} ...'.format( + model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed_1 = 1234 + seed_2 = 4321 + size = [12, 21] + tensor = torch.cuda.FloatTensor(size) + + # Set to seed_1 and generate two tensors. + torch.cuda.manual_seed(seed_1) + torch.randn(size, out=tensor) + target_11 = tensor.clone() + torch.randn(size, out=tensor) + target_12 = tensor.clone() + + # Set to seed_2 and generate two tensors. + torch.cuda.manual_seed(seed_2) + torch.randn(size, out=tensor) + target_21 = tensor.clone() + torch.randn(size, out=tensor) + target_22 = tensor.clone() + + # Now if we interleave seed_1 and seed_2, + # we should still get the same tensors + torch.cuda.manual_seed(seed_1) + mpu.get_cuda_rng_tracker().add('test', seed_2) + + torch.randn(size, out=tensor) + result_11 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_21 = tensor.clone() + + torch.randn(size, out=tensor) + result_12 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_22 = tensor.clone() + + diff = result_11.sub(result_21).abs().max() + diff = min(diff, result_12.sub(result_22).abs().max()) + print(' max diff in generated tensors (should be non-zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) + assert diff > 1.0e-6 + error = max( + result_11.sub(target_11).abs().max(), + result_12.sub(target_12).abs().max()) + error = max(error, result_21.sub(target_21).abs().max()) + error = max(error, result_22.sub(target_22).abs().max()) + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_model_parallel_cuda_manual_seed(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing model parallel cuda manual seed with size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + mpu.model_parallel_cuda_manual_seed(12345) + assert torch.cuda.initial_seed() == 12345 + with mpu.get_cuda_rng_tracker().fork(): + assert torch.cuda.initial_seed() == (12345 + 2718 + + mpu.get_model_parallel_rank()) + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test set rng state') + test_set_cuda_rng_state(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test cuda rng tracker') + test_cuda_rng_tracker(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test model parallel cuda manual seed') + test_model_parallel_cuda_manual_seed(model_parallel_size) + model_parallel_size *= 2 diff --git a/modelscope/models/nlp/mglm/mpu/transformer.py b/modelscope/models/nlp/mglm/mpu/transformer.py new file mode 100755 index 00000000..c12b2e10 --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/transformer.py @@ -0,0 +1,1200 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformer.""" + +import math + +import deepspeed +import torch +import torch.nn.init as init +from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm + +from .initialize import get_model_parallel_world_size +from .layers import ColumnParallelLinear, RowParallelLinear +from .mappings import gather_from_model_parallel_region +from .random import checkpoint, get_cuda_rng_tracker +from .utils import divide, split_tensor_along_last_dim + + +class PositionalEmbedding(torch.nn.Module): + + def __init__(self, hidden_size): + super(PositionalEmbedding, self).__init__() + + self.hidden_size = hidden_size + + inv_freq = 1 / ( + 10000**(torch.arange(0.0, hidden_size, 2.0) / hidden_size)) # noqa + self.register_buffer('inv_freq', inv_freq) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + + if bsz is not None: + return pos_emb[None, :, :].expand(bsz, -1, -1) + else: + return pos_emb[None, :, :] + + +class ParallelCrossAttention(torch.nn.Module): + """Parallel cross-attention layer for Transformer""" + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=None): + super(ParallelCrossAttention, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide( + num_attention_heads, world_size) + # Strided linear layer. + self.query = ColumnParallelLinear( + hidden_size, + hidden_size, + gather_output=False, + init_method=init_method) + self.key_value = ColumnParallelLinear( + hidden_size, + 2 * hidden_size, + stride=2, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + # Output. + self.dense = RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, # noqa + self.hidden_size_per_attention_head) # noqa + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, encoder_states, cross_mask): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + mixed_query_layer = self.query(hidden_states) + mixed_x_layer = self.key_value(encoder_states) + (mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + if cross_mask is not None: + # Apply the left to right attention mask. + attention_scores = torch.mul(attention_scores, cross_mask) - \ + 10000.0 * (1.0 - cross_mask) # noqa + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) # noqa + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +class ParallelSelfAttention(torch.nn.Module): + """Parallel self-attention layer for GPT2. + + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence lenght, and h is the hidden size + and creates output of the same size. + Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size to be divisible by n. + attention_dropout_prob: dropout probability for the attention scores. + init_method: weight initialization. + output_layer_init_method: output layer initialization. If None, use + `init_method`. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=None, + relative_encoding=False, + performer=False, + attention_scale=1.0): + super(ParallelSelfAttention, self).__init__() + self.performer = performer + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide( + num_attention_heads, world_size) + self.relative_encoding = relative_encoding + self.attention_scale = attention_scale + # Strided linear layer. + self.query_key_value = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + if relative_encoding: + self.relative = ColumnParallelLinear( + hidden_size, + hidden_size, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + # Output. + self.dense = RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, # noqa + self.hidden_size_per_attention_head) # noqa + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + @staticmethod + def _rel_shift(x, zero_triu=False): + # ql x kl x bsz x h + # bsz x h x ql x kl + zero_pad = torch.zeros((*x.size()[:-2], x.size(-2), 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:-2], x.size(-1) + 1, x.size(-2)) + + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(0), x.size(1))) + x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None] + + return x + + def forward(self, + hidden_states, + ltor_mask, + position_embeddings=None, + r_w_bias=None, + r_r_bias=None, + mem=None): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + query_length = hidden_states.size(1) + + if mem is None: + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim( + mixed_x_layer, 3) + else: + cat = torch.cat((mem, hidden_states), 1) + mixed_x_layer = self.query_key_value(cat) + (mixed_query_layer, mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim( + mixed_x_layer, 3) + mixed_query_layer = mixed_query_layer[:, -query_length:] + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + if self.relative_encoding: + relative_layer = self.relative(position_embeddings) + relative_layer = self._transpose_for_scores( + relative_layer) # 1 (bsz) x n_head x klen x d_head + # Raw attention scores. [b, np, qs, ks] + rw_head_q = query_layer + r_w_bias.unsqueeze(1) + ac_score = torch.matmul(rw_head_q, key_layer.transpose(-1, -2)) + rr_head_q = query_layer + r_r_bias.unsqueeze(1) + bd_score = torch.matmul(rr_head_q, + relative_layer.transpose(-1, -2)) + bd_score = self._rel_shift(bd_score) # qlen x klen x bsz x n_head + # bd_score = bd_score.permute(2, 3, 0, 1) # bsz n_head qlen klen + + attention_scores = ac_score + bd_score + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + else: + if self.attention_scale > 1.0: + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul( + query_layer / math.sqrt(self.attention_scale), + key_layer.transpose(-1, -2) + / math.sqrt(self.hidden_size_per_attention_head + * self.attention_scale)) + else: + attention_scores = torch.matmul( + query_layer, + key_layer.transpose(-1, -2) + / math.sqrt(self.hidden_size_per_attention_head)) + + # Apply the left to right attention mask. + attention_scores = torch.mul(attention_scores, ltor_mask) + if self.attention_scale > 1.0: + max_attention_scores = attention_scores.max( + dim=-1, keepdim=True)[0] + attention_scores -= max_attention_scores + attention_scores *= self.attention_scale + # if torch.distributed.get_rank() == 0: + # print(min_attention_scores, attention_scores.max().item()) + attention_scores = attention_scores + (-65504.0) * (1.0 - ltor_mask) + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) # noqa + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * ( + 1.0 + torch.tanh(0.7978845608028654 * x * # noqa + (1.0 + 0.044715 * x * x))) # noqa + + +def gelu(x): + return gelu_impl(x) + + +class ParallelMLP(torch.nn.Module): + """MLP for GPT2. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform gelu transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + + Arguments: + hidden_size: The hidden size of the self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + output_layer_init_method: output layer initialization. If None, + use `init_method`. + """ + + def __init__(self, + hidden_size, + output_dropout_prob, + init_method, + output_layer_init_method=None): + super(ParallelMLP, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + # Project to 4h. + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + gather_output=False, + init_method=init_method) + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + self.dropout = torch.nn.Dropout(output_dropout_prob) + + def forward(self, hidden_states): + # [b, s, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = gelu(intermediate_parallel) + + # [b, s, h] + output = self.dense_4h_to_h(intermediate_parallel) + output = self.dropout(output) + return output + + +class ParallelDecoderLayer(torch.nn.Module): + """A single layer transformer for GPT2. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + output_layer_init_method: output layers (attention output and + mlp output) initialization. If None, + use `init_method`. + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + init_method, + output_layer_init_method=None): + super(ParallelDecoderLayer, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + # Self attention. + self.self_attention = ParallelSelfAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + # Layernorm after the self attention. + self.post_self_layernorm = LayerNorm( + hidden_size, eps=layernorm_epsilon) + + self.cross_attention = ParallelCrossAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + # Layernorm after the cross attention. + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=layernorm_epsilon) + + # MLP + self.mlp = ParallelMLP( + hidden_size, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + def forward(self, + hidden_states, + encoder_states, + ltor_mask, + cross_mask=None): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + self_attention_output = self.self_attention(layernorm_output, + ltor_mask) + # Residual connection. + self_layernorm_input = hidden_states + self_attention_output + # Layer norm post the self attention. + self_layernorm_output = self.post_self_layernorm(self_layernorm_input) + # Cross attention + attention_output = self.cross_attention(self_layernorm_output, + encoder_states, cross_mask) + # Residual connection + layernorm_input = self_layernorm_input + attention_output + # Layer norm post the cross attention + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + # Second residual connection. + output = layernorm_input + mlp_output + return output + + +class ParallelTransformerLayer(torch.nn.Module): + """A single layer transformer for GPT2. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + output_layer_init_method: output layers (attention output and + mlp output) initialization. If None, + use `init_method`. + """ + + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + init_method, + output_layer_init_method=None, + relative_encoding=False, + performer=False, + attention_scale=1.0): + super(ParallelTransformerLayer, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + # Self attention. + self.attention = ParallelSelfAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method, + relative_encoding=relative_encoding, + performer=performer, + attention_scale=attention_scale) + + # Layernorm on the input data. + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=layernorm_epsilon) + + # MLP + self.mlp = ParallelMLP( + hidden_size, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + def forward(self, + hidden_states, + ltor_mask, + position_embeddings=None, + r_w_bias=None, + r_r_bias=None, + mem=None): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + mem = self.input_layernorm(mem) if mem is not None else None + # Self attention. + attention_output = self.attention(layernorm_output, ltor_mask, + position_embeddings, r_w_bias, + r_r_bias, mem) + # Residual connection. + layernorm_input = hidden_states + attention_output + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + # Second residual connection. + output = layernorm_input + mlp_output + + return output + + +def unscaled_init_method(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPT2ParallelTransformer(torch.nn.Module): + """GPT-2 transformer. + + This module takes input from embedding layer and it's output can + be used directly by a logit layer. It consists of L (num-layers) + blocks of: + layer norm + self attention + residual connection + layer norm + mlp + residual connection + followed by a final layer norm. + + Arguments: + num_layers: Number of transformer layers. + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + checkpoint_activations: if True, checkpoint activations. + checkpoint_num_layers: number of layers to checkpoint. This + is basically the chunk size in checkpoitning. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method_std: standard deviation of the init method which has + the form N(0, std). + use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers) + scaling for the output weights ( + output of self attention and mlp). + """ + + def __init__( + self, + num_layers, + hidden_size, + num_attention_heads, + max_sequence_length, + max_memory_length, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + checkpoint_activations, + checkpoint_num_layers=1, + layernorm_epsilon=1.0e-5, + init_method_std=0.02, + use_scaled_init_for_output_weights=True, + relative_encoding=False, + block_position_encoding=False, + performer=False, + use_decoder_layer=False, + attention_scale=1.0, + ): + super(GPT2ParallelTransformer, self).__init__() + self.hidden_size = hidden_size + # Store activation checkpoiting flag. + self.checkpoint_activations = checkpoint_activations + self.checkpoint_num_layers = checkpoint_num_layers + self.max_memory_length = max_memory_length + self.performer = performer + self.use_decoder_layer = use_decoder_layer + assert not (performer and relative_encoding) + + output_layer_init_method = None + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method( + init_method_std, num_layers) + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + self.relative_encoding = relative_encoding + self.block_position_encoding = block_position_encoding + if relative_encoding: + # Relative position embedding + self.position_embeddings = PositionalEmbedding(hidden_size) + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_attention_head = divide( + hidden_size, num_attention_heads) + self.num_attention_heads_per_partition = divide( + num_attention_heads, world_size) + self.r_w_bias = torch.nn.Parameter( + torch.Tensor(self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + self.r_w_bias.model_parallel = True + self.r_r_bias = torch.nn.Parameter( + torch.Tensor(self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + self.r_r_bias.model_parallel = True + # Always initialize bias to zero. + with torch.no_grad(): + self.r_w_bias.zero_() + self.r_r_bias.zero_() + else: + # Position embedding (serial). + if block_position_encoding: + self.position_embeddings = torch.nn.Embedding( + max_sequence_length + 1, hidden_size) + self.block_position_embeddings = torch.nn.Embedding( + max_sequence_length + 1, hidden_size) + torch.nn.init.normal_( + self.block_position_embeddings.weight, + mean=0.0, + std=init_method_std) + else: + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, hidden_size) + # Initialize the position embeddings. + torch.nn.init.normal_( + self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def get_layer(): + if use_decoder_layer: + return ParallelDecoderLayer( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + unscaled_init_method(init_method_std), + output_layer_init_method=output_layer_init_method) + else: + return ParallelTransformerLayer( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + unscaled_init_method(init_method_std), + output_layer_init_method=output_layer_init_method, + relative_encoding=relative_encoding, + performer=performer, + attention_scale=attention_scale) + + # Transformer layers. + self.layers = torch.nn.ModuleList( + [get_layer() for _ in range(num_layers)]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + def forward(self, + hidden_states, + position_ids, + attention_mask, + memory_states=None, + encoder_states=None, + return_memory=False, + detach_memory=True): + batch_size, query_length = hidden_states.size()[:2] + memory_length = memory_states[0].size(1) if memory_states else 0 + key_length = query_length + memory_length + # attention mask is the beginning postion of B region, \in [0, query_len) + is_scalar = torch.numel(attention_mask) == 1 + is_sep = is_scalar or torch.numel(attention_mask) == batch_size + if self.performer: + assert is_scalar, 'attention_mask should be a scalar to indicate the seperation position.' + assert memory_length == 0, 'Do not support transformer-xl.' + if is_sep: + sep = attention_mask.item() if is_scalar else attention_mask + + # conventional transformer + def build_mask_matrix(seq_length, sep, memory_length=0): + m = hidden_states.new_ones((1, seq_length, seq_length)) + m = torch.tril(m) + if is_scalar: + m[0, :, :sep] = 1 + else: + m = m.expand(batch_size, -1, -1) + ids = torch.arange( + seq_length, device=sep.device, + dtype=sep.dtype).view(1, -1) + mask = ids < sep.view(-1, 1) + m = m.masked_fill(mask.unsqueeze(1).expand_as(m), 1) + if memory_length > 0: + m = m.expand(batch_size, -1, -1) + m = torch.cat( + (hidden_states.new_ones((batch_size, seq_length, + memory_length)), m), # noqa + dim=2) # noqa + m = m.unsqueeze(1) + return m + + if not self.performer: + attention_mask = build_mask_matrix( + query_length, sep, memory_length=memory_length) + else: + attention_mask = attention_mask[:, :, :, + -query_length - memory_length:] + + if self.relative_encoding: + position_sequence = torch.arange( + key_length - 1, + -1, + -1.0, + device=hidden_states.device, + dtype=hidden_states.dtype) + position_embeddings = self.position_embeddings(position_sequence) + # Apply dropout + position_embeddings = self.embedding_dropout(position_embeddings) + else: + if self.block_position_encoding: + position_ids, block_position_ids = position_ids[:, + 0], position_ids[:, + 1] + position_embeddings = self.position_embeddings(position_ids) + hidden_states = hidden_states + position_embeddings + if self.block_position_encoding: + block_position_embeddings = self.block_position_embeddings( + block_position_ids) + hidden_states = hidden_states + block_position_embeddings + hidden_states = self.embedding_dropout(hidden_states) + + def check_detach(_hidden_states): + if detach_memory: + return _hidden_states.detach() + return _hidden_states + + if self.max_memory_length > 0 or return_memory: + mem_layers = [check_detach(hidden_states)] + else: + mem_layers = [] + + def custom(start, end): + + def custom_forward(*inputs): + layers_ = self.layers[start:end] + x_, inputs = inputs[0], inputs[1:] + if self.relative_encoding: + inputs, mems_ = inputs[:4], inputs[4:] + else: + inputs, mems_ = inputs[:1], inputs[1:] + for i, layer in enumerate(layers_): + mem_i_ = mems_[i] if mems_ else None + x_ = layer(x_, *inputs, mem=mem_i_) + if self.max_memory_length > 0 or return_memory: + mem_layers.append(check_detach(x_)) + return x_ + + return custom_forward + + if self.checkpoint_activations: + l = 0 # noqa + num_layers = len(self.layers) + chunk_length = self.checkpoint_num_layers + while l < num_layers: + args = [hidden_states, attention_mask + ] if not self.use_decoder_layer else [ + hidden_states, + encoder_states, + attention_mask # noqa + ] # noqa + if self.relative_encoding: + args += [position_embeddings, self.r_w_bias, self.r_r_bias] + if memory_states: + args += memory_states[l:l + chunk_length] + hidden_states = checkpoint(custom(l, l + chunk_length), *args) + l += chunk_length # noqa + else: + for i, layer in enumerate(self.layers): + args = [hidden_states, attention_mask + ] if not self.use_decoder_layer else [ + hidden_states, + encoder_states, + attention_mask # noqa + ] # noqa + if self.relative_encoding: + args += [position_embeddings, self.r_w_bias, self.r_r_bias] + mem_i = memory_states[i] if memory_states else None + hidden_states = layer(*args, mem=mem_i) + if self.max_memory_length > 0 or return_memory: + mem_layers.append(check_detach(hidden_states)) + + # Final layer norm. + output = self.final_layernorm(hidden_states) + if self.max_memory_length > 0 or return_memory: + mem_layers = self.update_mems( + mem_layers, memory_states, return_memory=return_memory) + + return (output, mem_layers) + + def update_mems(self, hiddens, mems, return_memory=False): + memory_length = mems[0].size(1) if mems else 0 + query_length = hiddens[0].size(1) + new_memory_length = memory_length + query_length + if not return_memory: + new_memory_length = min(self.max_memory_length, new_memory_length) + new_mems = [] + # with torch.no_grad(): + for i in range(len(hiddens)): + if new_memory_length <= query_length: + new_mems.append(hiddens[i][:, -new_memory_length:]) + else: + new_mems.append( + torch.cat((mems[i][:, -new_memory_length + query_length:], + hiddens[i]), + dim=1)) + return new_mems + + +class BertParallelSelfAttention(torch.nn.Module): + """Parallel self-attention layer for BERT. + + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence lenght, and h is the hidden size + and creates output of the same size. + Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size be divisible by n. + dropout_prob: dropout probability for the attention scores. + output_parallel: If true, no all-gather is done on the output and + the output values will be per partition. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + + def __init__(self, + hidden_size, + num_attention_heads, + dropout_prob, + output_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelSelfAttention, self).__init__() + # Input configuration. + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.dropout_prob = dropout_prob + self.output_parallel = output_parallel + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide( + num_attention_heads, world_size) + # Strided linear layer. + self.query_key_value = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.dropout = torch.nn.Dropout(dropout_prob) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, # noqa + self.hidden_size_per_attention_head) # noqa + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + + # Attention heads. [b, s, hp] + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + # Raw attention scores. [b, np, s, s] + norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head)) + attention_scores = torch.matmul( + query_layer / norm_factor, + key_layer.transpose(-1, -2) / norm_factor) + # Apply the attention mask. + attention_scores += attention_mask + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with get_cuda_rng_tracker().fork(): + attention_probs = self.dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, ) # noqa + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + if self.output_parallel: + output = context_layer + else: + output = gather_from_model_parallel_region(context_layer) + + return output + + +class BertParallelTransformerOutput(torch.nn.Module): + """The output layer used after self attention and intermediate + parts of transformer layer.""" + + def __init__(self, + input_size, + output_size, + dropout_prob, + layernorm_epsilon=1.0e-12, + input_is_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelTransformerOutput, self).__init__() + # Components. + self.dense = RowParallelLinear( + input_size, + output_size, + input_is_parallel=input_is_parallel, + init_method=init_method) + self.dropout = torch.nn.Dropout(dropout_prob) + self.layernorm = LayerNorm(output_size, eps=layernorm_epsilon) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + layernorm_input = hidden_states + input_tensor + hidden_states = self.layernorm(layernorm_input) + return hidden_states + + +class BertParallelTransformerLayer(torch.nn.Module): + """A single layer transformer for Bert. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + intermediate_size: size of the intermediate state after + self attention. In both BERT and GPT + this is set to be 4 times the hidden + size. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + intermediate_activation_fn: activation function for output + of intermediate. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + """ + + def __init__(self, + hidden_size, + intermediate_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + intermediate_activation_fn, + layernorm_epsilon, + init_method=init.xavier_normal_): + super(BertParallelTransformerLayer, self).__init__() + + # Self attention. + self.attention = BertParallelSelfAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_parallel=True, + init_method=init_method) + # Self attention output. + self.self_output = BertParallelTransformerOutput( + hidden_size, + hidden_size, + output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + # Intermediate. + self.intermediate = ColumnParallelLinear( + hidden_size, + intermediate_size, + gather_output=False, + init_method=init_method) + self.intermediate_activation_fn = intermediate_activation_fn + # Output. + self.output = BertParallelTransformerOutput( + intermediate_size, + hidden_size, + output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + + def forward(self, hidden_states, attention_mask): + # [b, s, hp] + attention_output_parallel = self.attention(hidden_states, + attention_mask) + # [b, s, h] + attention_self_output = self.self_output(attention_output_parallel, + hidden_states) + # [b, s, ip] + intermediate_output_parallel = self.intermediate(attention_self_output) + intermediate_output_parallel = self.intermediate_activation_fn( + intermediate_output_parallel) + # [b, s, h] + layer_output = self.output(intermediate_output_parallel, + attention_self_output) + + return layer_output diff --git a/modelscope/models/nlp/mglm/mpu/utils.py b/modelscope/models/nlp/mglm/mpu/utils.py new file mode 100644 index 00000000..76c37a2b --- /dev/null +++ b/modelscope/models/nlp/mglm/mpu/utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, + num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indecies in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, + world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/modelscope/models/nlp/mglm/process_grid.py b/modelscope/models/nlp/mglm/process_grid.py new file mode 100644 index 00000000..d425c970 --- /dev/null +++ b/modelscope/models/nlp/mglm/process_grid.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022 Zhipu.AI + +import glob +import os +import statistics +import sys + +import json + +path_pattern = sys.argv[1] +target_type = sys.argv[2] +best_value, best_result, best_name = None, None, None +mean_result = {} +print(path_pattern) +for dir_path in glob.glob(path_pattern, recursive=True): + entry = os.path.basename(dir_path) + valid_result = None + test_found = os.path.exists(os.path.join(dir_path, 'test_results.json')) + valid_path = os.path.join(dir_path, 'results.json') + if os.path.exists(valid_path): + print(entry) + with open(valid_path) as file: + valid_result = json.load(file) + else: + print(f'{entry} no validation results') + continue + if not test_found: + print(f'{entry} not tested yet') + if target_type == 'max': + metric = sys.argv[3] + metric_value = valid_result[metric] + if best_value is None or metric_value > best_value: + best_value = metric_value + best_result = valid_result + best_name = entry + elif target_type == 'mean' or target_type == 'median': + if mean_result: + for metric, value in valid_result.items(): + if metric not in ['type', 'epoch']: + mean_result[metric].append(value) + else: + mean_result = { + metric: [value] + for metric, value in valid_result.items() + if metric not in ['type', 'epoch'] + } + +if target_type == 'max': + print(f'Best result found at {best_name}: {best_result}') +elif target_type == 'mean': + mean_result = { + metric: sum(value) / len(value) + for metric, value in mean_result.items() + } + print(f'Mean result {mean_result}') +elif target_type == 'median': + mean_result = { + metric: statistics.median(value) + for metric, value in mean_result.items() + } + print(f'Mean result {mean_result}') diff --git a/modelscope/models/nlp/mglm/requirements.txt b/modelscope/models/nlp/mglm/requirements.txt new file mode 100644 index 00000000..e44ae5d1 --- /dev/null +++ b/modelscope/models/nlp/mglm/requirements.txt @@ -0,0 +1,22 @@ +boto3 +botocore +deepspeed +fasttext +filelock +ftfy +langdetect +lsh +matplotlib +mpi4py +nltk +pandas +regex +requests +rouge_score +scikit_learn +scipy +sentencepiece +termcolor +tldextract +tqdm +transformers diff --git a/modelscope/models/nlp/mglm/run_test.py b/modelscope/models/nlp/mglm/run_test.py new file mode 100644 index 00000000..2f568265 --- /dev/null +++ b/modelscope/models/nlp/mglm/run_test.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022 Zhipu.AI + +import sys + +if sys.argv[1] == 'block': + from test.test_block import main + main() +elif sys.argv[1] == 'rel_shift': + from test.test_rel_shift import main + main() diff --git a/modelscope/models/nlp/mglm/tasks/data_utils.py b/modelscope/models/nlp/mglm/tasks/data_utils.py new file mode 100644 index 00000000..179d304e --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/data_utils.py @@ -0,0 +1,389 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tasks data utility.""" +import copy +import pickle +import re +from typing import Dict, List, Optional + +import json +import numpy as np +import torch +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from modelscope.models.nlp.mglm import mpu + + +def clean_text(text): + """Remove new lines and multiple spaces and adjust end of sentence dot.""" + + text = text.replace('\n', ' ') + text = re.sub(r'\s+', ' ', text) + for _ in range(3): + text = text.replace(' . ', '. ') + + return text + + +class InputExample(object): + """A raw input example consisting of one or two segments of text and a label""" + + def __init__(self, + guid, + text_a, + text_b=None, + label=None, + logits=None, + meta: Optional[Dict] = None, + idx=-1, + num_choices=1): + """ + Create a new InputExample. + + :param guid: a unique textual identifier + :param text_a: the sequence of text + :param text_b: an optional, second sequence of text + :param label: an optional label + :param logits: an optional list of per-class logits + :param meta: an optional dictionary to store arbitrary meta information + :param idx: an optional numeric index + """ + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.label = label + self.logits = logits + self.idx = idx + self.num_choices = num_choices + self.meta = meta if meta else {} + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serialize this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serialize this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + @staticmethod + def load_examples(path: str) -> List['InputExample']: + """Load a set of input examples from a file""" + with open(path, 'rb') as fh: + return pickle.load(fh) + + @staticmethod + def save_examples(examples: List['InputExample'], path: str) -> None: + """Save a set of input examples to a file""" + with open(path, 'wb') as fh: + pickle.dump(examples, fh) + + +def num_special_tokens_to_add(text_a_ids, + text_b_ids, + answer_ids, + add_cls, + add_sep, + add_piece, + add_eos=True): + num_tokens = 0 + if add_cls: + num_tokens += 1 + if text_b_ids and add_sep: + num_tokens += 1 + if add_eos: + num_tokens += 1 + if not answer_ids and add_piece: + num_tokens += 1 + return num_tokens + + +def build_input_from_ids(text_a_ids, + text_b_ids, + answer_ids, + max_seq_length, + tokenizer, + args=None, + add_cls=True, + add_sep=False, + add_piece=False, + add_eos=True, + mask_id=None): + if mask_id is None: + mask_id = tokenizer.get_command('MASK').Id + eos_id = tokenizer.get_command('eos').Id + cls_id = tokenizer.get_command('ENC').Id + sep_id = tokenizer.get_command('sep').Id + ids = [] + types = [] + paddings = [] + # CLS + if add_cls: + ids.append(cls_id) + types.append(0) + paddings.append(1) + # A + len_text_a = len(text_a_ids) + ids.extend(text_a_ids) + types.extend([0] * len_text_a) + paddings.extend([1] * len_text_a) + # B + if text_b_ids is not None: + # SEP + if add_sep: + ids.append(sep_id) + types.append(0) + paddings.append(1) + len_text_b = len(text_b_ids) + ids.extend(text_b_ids) + types.extend([1] * len_text_b) + paddings.extend([1] * len_text_b) + eos_length = 1 if add_eos else 0 + # Cap the size. + if len(ids) >= max_seq_length - eos_length: + max_seq_length_m1 = max_seq_length - 1 + ids = ids[0:max_seq_length_m1] + types = types[0:max_seq_length_m1] + paddings = paddings[0:max_seq_length_m1] + end_type = 0 if text_b_ids is None else 1 + if add_eos: + ids.append(eos_id) + types.append(end_type) + paddings.append(1) + sep = len(ids) + target_ids = [0] * len(ids) + loss_masks = [0] * len(ids) + position_ids = list(range(len(ids))) + block_position_ids = [0] * len(ids) + # Piece + if add_piece or answer_ids is not None: + sop_id = tokenizer.get_command('sop').Id + mask_position = ids.index( + mask_id + ) if not args.sentinel_token else args.max_position_embeddings + ids.append(sop_id) + types.append(end_type) + paddings.append(1) + position_ids.append(mask_position) + block_position_ids.append(1) + if answer_ids is not None: + len_answer = len(answer_ids) + ids.extend(answer_ids[:-1]) + types.extend([end_type] * (len_answer - 1)) + paddings.extend([1] * (len_answer - 1)) + position_ids.extend([mask_position] * (len_answer - 1)) + if not args.no_block_position: + block_position_ids.extend(range(2, len(answer_ids) + 1)) + else: + block_position_ids.extend([1] * (len(answer_ids) - 1)) + target_ids.extend(answer_ids) + loss_masks.extend([1] * len(answer_ids)) + else: + target_ids.append(0) + loss_masks.append(1) + # Padding. + padding_length = max_seq_length - len(ids) + if padding_length > 0: + ids.extend([eos_id] * padding_length) + types.extend([eos_id] * padding_length) + paddings.extend([0] * padding_length) + position_ids.extend([0] * padding_length) + block_position_ids.extend([0] * padding_length) + target_ids.extend([0] * padding_length) + loss_masks.extend([0] * padding_length) + if not args.masked_lm: + position_ids = [position_ids, block_position_ids] + return ids, types, paddings, position_ids, sep, target_ids, loss_masks + + +def build_decoder_input(enc_ids, answer_ids, max_seq_length, + max_dec_seq_length, tokenizer): + mask_id = tokenizer.get_command('MASK').Id + eos_id = tokenizer.get_command('eos').Id + sop_id = tokenizer.get_command('sop').Id + enc_len = len(enc_ids) # noqa + masks = [] + # TODO: it probably takes too much memory + # for i in range(max_dec_seq_length): + # m = [1]*enc_len + [0]*(max_seq_length - enc_len) + [1]*(i+1) + [0]*(max_dec_seq_length-1-i) + # masks.append(m) + mask_position = enc_ids.index(mask_id) + len_answer = len(answer_ids) + ids = [sop_id] + answer_ids[:-1] + types = [0] * len_answer # not used + paddings = [1] * len_answer + position_ids = [mask_position] * len_answer + block_position_ids = list(range(1, len_answer + 1)) + target_ids = answer_ids + loss_masks = [1] * len_answer + # Padding. + padding_length = max_dec_seq_length - len(ids) + if padding_length > 0: + ids.extend([eos_id] * padding_length) + types.extend([0] * padding_length) + paddings.extend([0] * padding_length) + position_ids.extend([0] * padding_length) + block_position_ids.extend([0] * padding_length) + target_ids.extend([0] * padding_length) + loss_masks.extend([0] * padding_length) + position_ids = [position_ids, block_position_ids] + return ids, types, paddings, position_ids, masks, target_ids, loss_masks + + +def build_sample(ids, + types=None, + paddings=None, + positions=None, + masks=None, + label=None, + unique_id=None, + target=None, + logit_mask=None, + segment_ids=None, + prompt_ids=None): + """Convert to numpy and return a sample consumed by the batch producer.""" + + ids_np = np.array(ids, dtype=np.int64) + sample = {'text': ids_np, 'label': int(label)} + if types is not None: + types_np = np.array(types, dtype=np.int64) + sample['types'] = types_np + if paddings is not None: + paddings_np = np.array(paddings, dtype=np.int64) + sample['padding_mask'] = paddings_np + if positions is not None: + positions_np = np.array(positions, dtype=np.int64) + sample['position'] = positions_np + if masks is not None: + masks_np = np.array(masks, dtype=np.int64) + sample['mask'] = masks_np + if target is not None: + target_np = np.array(target, dtype=np.int64) + sample['target'] = target_np + if logit_mask is not None: + logit_mask_np = np.array(logit_mask, dtype=np.int64) + sample['logit_mask'] = logit_mask_np + if segment_ids is not None: + segment_ids = np.array(segment_ids, dtype=np.int64) + sample['segment_id'] = segment_ids + if prompt_ids is not None: + prompt_ids = np.array(prompt_ids, dtype=np.int64) + sample['prompt_pos'] = prompt_ids + if unique_id is not None: + sample['uid'] = unique_id + return sample + + +def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, + dec_logit_mask): + sample['dec_text'] = np.array(dec_ids) + sample['dec_position'] = np.array(dec_position) + sample['dec_mask'] = np.array(dec_masks) + sample['dec_target'] = np.array(dec_target) + sample['dec_logit_mask'] = np.array(dec_logit_mask) + return sample + + +def my_collate(batch): + new_batch = [{key: value + for key, value in sample.items() if key != 'uid'} + for sample in batch] + text_list = [sample['text'] for sample in batch] + + def pad_choice_dim(data, choice_num): + if len(data) < choice_num: + data = np.concatenate([data] + + [data[0:1]] * (choice_num - len(data))) + return data + + if len(text_list[0].shape) == 2: + choice_nums = list(map(len, text_list)) + max_choice_num = max(choice_nums) + for i, sample in enumerate(new_batch): + for key, value in sample.items(): + if key != 'label': + sample[key] = pad_choice_dim(value, max_choice_num) + else: + sample[key] = value + sample['loss_mask'] = np.array( + [1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), + dtype=np.int64) + + if 'dec_text' in new_batch[0]: + choice_nums = [len(sample['dec_text']) for sample in new_batch] + if choice_nums.count(choice_nums[0]) != len(choice_nums): + max_choice_num = max(choice_nums) + for i, sample in enumerate(new_batch): + for key, value in sample.items(): + if key.startswith('dec_'): + sample[key] = pad_choice_dim(value, max_choice_num) + sample['loss_mask'] = np.array( + [1] * choice_nums[i] + [0] * # noqa + (max_choice_num - choice_nums[i]), + dtype=np.int64) + + new_batch = default_collate(new_batch) + if 'uid' in batch[0]: + uid_list = [sample['uid'] for sample in batch] + new_batch['uid'] = uid_list + return new_batch + + +class FakeDataloader: + + def __init__(self, num_iters): + self.num_iters = num_iters + + def __iter__(self): + if self.num_iters is not None: + for _ in range(self.num_iters): + yield None + else: + while True: + yield None + + +def build_data_loader(dataset, + batch_size, + num_workers, + drop_last, + shuffle=True, + only_rank0=False): + """Data loader. Note that batch-size is the local (per GPU) batch-size.""" + + # Sampler. + if only_rank0: + rank, world_size = 0, 1 + else: + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + shuffle=False, + num_workers=num_workers, + drop_last=drop_last, + pin_memory=True, + collate_fn=my_collate) + + return data_loader diff --git a/modelscope/models/nlp/mglm/tasks/eval_utils.py b/modelscope/models/nlp/mglm/tasks/eval_utils.py new file mode 100644 index 00000000..da23a884 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/eval_utils.py @@ -0,0 +1,249 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation utilities.""" + +import datetime +import os +import random +import time +from collections import OrderedDict +from typing import List + +import mpu +import torch +from finetune_glm import process_batch +from sklearn.metrics import f1_score +from tasks.data_utils import InputExample, build_data_loader +from utils import debug_finetune_data, get_spare_port, print_rank_0 + + +def accuracy_metric(predictions, labels, examples): + count = 0 + num_predictions = max(len(predictions), 1) + assert len(predictions) == len(labels) + for prediction, label in zip(predictions, labels): + count += prediction == label + return count * 100.0 / num_predictions + + +def f1_metric(predictions, labels, examples): + return f1_score(labels, predictions) + + +def f1_macro_metric(predictions, labels, examples): + return f1_score(labels, predictions, average='macro') + + +global_tokenizer = None + + +def accuracy_func_provider(single_dataset_provider, + metric_dict, + args, + is_test=False, + eval_func=None, + output_func=None, + only_rank0=True, + tokenizer=None): + """Provide function that calculates accuracies.""" + # Build dataloaders. + global global_tokenizer + global_tokenizer = tokenizer + if only_rank0 and torch.distributed.is_initialized( + ) and torch.distributed.get_rank() != 0: + return None + if is_test and not args.eval_valid: + datapaths = args.test_data if args.test_data is not None else ['test'] + else: + datapaths = args.valid_data if args.valid_data is not None else ['dev'] + if eval_func is None: + eval_func = multichoice_evaluate + dataloaders = [] + eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size + for datapath in datapaths: + dataset = single_dataset_provider(datapath) + dataloader = build_data_loader( + dataset, + eval_batch_size, + num_workers=args.num_workers, + drop_last=False, + shuffle=False, + only_rank0=only_rank0) + dataloaders.append((dataset.dataset_name, dataloader)) + + def metrics_func(model, + epoch, + output_predictions=False, + summary_writer=None): + print_rank_0('calculating metrics ...') + score_dict = OrderedDict([(key, 0.0) for key in metric_dict + ]) if isinstance(metric_dict, dict) else { + metric_dict: 0.0 + } # noqa + total = 0 + for name, dataloader in dataloaders: + example_dict = None + if hasattr(dataloader.dataset, 'examples'): + example_dict = dataloader.dataset.examples + start_time = time.time() + predictions, labels, examples = eval_func(model, dataloader, + example_dict, args) + elapsed_time = time.time() - start_time + if output_predictions and torch.distributed.get_rank() == 0: + filename = os.path.join(args.log_dir, name + '.jsonl') + output_func(predictions, examples, filename) + total_count = len(predictions) + single_dict = { + key: metric(predictions, labels, examples) + for key, metric in metric_dict.items() + } + output_str = ' > |epoch: {}| metrics for {}: total {}'.format( + epoch, name, total_count) + for key, value in single_dict.items(): + output_str += ' {} = {:.4f} %'.format(key, value) + if summary_writer is not None and epoch >= 0 and not is_test and len( + dataloaders) > 1: + summary_writer.add_scalar(f'Train/valid_{name}_{key}', + value, epoch) + output_str += ' elapsed time (sec): {:.3f}'.format(elapsed_time) + if len(dataloaders) > 1: + print_rank_0(output_str) + for key in score_dict: + score_dict[key] += single_dict[key] * total_count + total += total_count + score_dict = { + key: score / float(total) + for key, score in score_dict.items() + } + output_str = ' >> |epoch: {}| overall: total = {}'.format(epoch, total) + for key, score in score_dict.items(): + output_str += ' {} = {:.4f}'.format(key, score) + if summary_writer is not None and epoch >= 0 and not is_test: + summary_writer.add_scalar(f'Train/valid_{key}', score, epoch) + print_rank_0(output_str) + return score_dict + + return metrics_func + + +segment_length = 10 + + +def multichoice_evaluate(model, dataloader, example_dict, args): + """Calculate correct over total answers and return prediction if the + `output_predictions` is true.""" + model.eval() + port = get_spare_port(args) + print_rank_0(f'Using port {port}') + store = torch.distributed.TCPStore(args.master_ip, port, + torch.distributed.get_world_size(), + torch.distributed.get_rank() == 0, + datetime.timedelta(seconds=30)) + # file_path = os.path.join("/cache", args.experiment_name + "_store") + # print_rank_0(f"Using file store at {file_path}") + # store = torch.distributed.FileStore(file_path, torch.distributed.get_world_size()) + with torch.no_grad(): + # For all the batches in the dataset. + for _, batch in enumerate(dataloader): + # Run the model forward. + data = process_batch(batch, args) + if args.pretrained_bert: + tokens, types, labels_, attention_mask = data['text'], data[ + 'types'], data['label'], data['padding_mask'] + inputs = [tokens, types, attention_mask] + elif args.cloze_eval: + tokens, labels_, position_ids = data['text'], data[ + 'label'], data['position'] + attention_mask, target_ids, logit_mask = data['mask'], data[ + 'target'], data['logit_mask'] + if not args.fast_decode: + inputs = [ + tokens, position_ids, attention_mask, target_ids, + logit_mask + ] + if args.continuous_prompt: + prompt_pos = data['prompt_pos'] + inputs.append(prompt_pos) + else: + dec_input_ids, dec_position_ids, dec_attention_mask = data[ + 'dec_text'], data['dec_position'], data['dec_mask'] + dec_target_ids, dec_logit_mask = data['dec_target'], data[ + 'dec_logit_mask'] + inputs = [ + tokens, position_ids, attention_mask, dec_input_ids, + dec_position_ids, dec_attention_mask, dec_target_ids, + dec_logit_mask + ] + else: + tokens, labels_, position_ids, attention_mask = data[ + 'text'], data['label'], data['position'], data['mask'] + inputs = [tokens, position_ids, attention_mask] + if len(inputs[0].shape + ) == 3 and inputs[0].size(1) > segment_length: + logit_list = [] + for i in range((inputs[0].size(1) - 1) // segment_length + 1): + input_batch = [ + arg[:, i * segment_length:(i + 1) * segment_length] + for arg in inputs + ] + if args.pretrained_bert: + logits = model(*input_batch) + else: + logits, *mems = model(*input_batch) + logit_list.append(logits) + logits = torch.cat(logit_list, dim=1) + elif args.cloze_eval and args.fast_decode: + logit_list = [] + num_choices = inputs[3].size(1) + for i in range((num_choices - 1) // segment_length + 1): + input_batch = inputs[:3] + [ + arg[:, i * segment_length:(i + 1) * segment_length] + for arg in inputs[3:] + ] + logits, *mems = model(*input_batch) + logit_list.append(logits) + logits = torch.cat(logit_list, dim=1) + else: + if args.pretrained_bert: + logits = model(*inputs) + else: + logits, *mems = model(*inputs) + if 'segment_id' in data: + from torch_scatter import scatter_sum + if 'loss_mask' in data: + logits = logits * data['loss_mask'] + logits = scatter_sum(logits, data['segment_id'], dim=1) + elif 'loss_mask' in data: + loss_mask = data['loss_mask'] + logits = logits * loss_mask - 10000.0 * (1.0 - loss_mask) + uid_list = batch['uid'] + if isinstance(uid_list, torch.Tensor): + uid_list = uid_list.cpu().numpy().tolist() + predicted = torch.argmax(logits, dim=-1).tolist() + labels = labels_.tolist() + if args.task.lower() == 'wsc': + predicted = [1 if pred == 0 else 0 for pred in predicted] + if mpu.get_model_parallel_rank() == 0: + for uid, prediction, label in zip(uid_list, predicted, labels): + store.set(uid, str((prediction, label))) + model.train() + torch.distributed.barrier() + predictions, labels, examples = [], [], [] + for uid, example in example_dict.items(): + prediction, label = eval(store.get(uid)) + predictions.append(prediction) + labels.append(label) + examples.append(example) + torch.distributed.barrier() + return predictions, labels, examples diff --git a/modelscope/models/nlp/mglm/tasks/language_model/dataset.py b/modelscope/models/nlp/mglm/tasks/language_model/dataset.py new file mode 100644 index 00000000..cfdfa714 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/language_model/dataset.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022 Zhipu.AI + +import math +from bisect import bisect_right +from itertools import accumulate + +import json +import numpy as np +import torch +from tasks.data_utils import build_input_from_ids, num_special_tokens_to_add +from tasks.language_model.detokenizer import get_detokenizer +from utils import print_rank_0 + + +class LMDataset(torch.utils.data.Dataset): + + def __init__(self, args, documents, tokenizer, num_original_tokens, + num_tokenized_tokens): + self.args = args + self.documents = documents + self.max_seq_len = args.seq_length - 1 + self.tokenizer = tokenizer + self.overalapping_eval = args.overlapping_eval + if self.overalapping_eval is None: + self.overalapping_eval = self.max_seq_len + self.overalapping_eval = max(1, self.overalapping_eval) + self.num_original_tokens = num_original_tokens + self.num_tokenized_tokens = num_tokenized_tokens + # remove first sequence tokens + targets = [ + max(len(tokens) - self.max_seq_len, 0) for tokens in self.documents + ] + self.num_sequences = [ + max(math.ceil(target / self.overalapping_eval) + 1, 1) + for target in targets + ] + self.weights = list(accumulate(self.num_sequences)) + self.left_weights = [0] + self.weights[:-1] + self.unidirectional = args.unidirectional + self.block_lm = args.block_lm + mask_token = 'gMASK' if args.task_mask else 'MASK' + self.mask_id = self.tokenizer.get_command(mask_token).Id + + def __len__(self): + return sum(self.num_sequences) + + def __getitem__(self, idx): + document_idx = bisect_right(self.weights, idx) + idx = idx - self.left_weights[document_idx] + start_idx = idx * self.overalapping_eval + end_idx = start_idx + self.max_seq_len + tokens = self.documents[document_idx][start_idx:end_idx] + if self.block_lm: + if idx == 0 or self.unidirectional: + prompt, text = tokens[:1], tokens[1:] + else: + prompt_length = self.max_seq_len - self.overalapping_eval + prompt, text = tokens[:prompt_length], tokens[prompt_length:] + prompt = prompt + [self.mask_id] + num_special_tokens = num_special_tokens_to_add( + prompt, + None, + text, + add_cls=True, + add_sep=False, + add_piece=True, + add_eos=False) + data = build_input_from_ids( + prompt, + None, + text, + self.max_seq_len + num_special_tokens + 1, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True, + add_eos=False, + mask_id=self.mask_id) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + if idx != 0 and self.unidirectional: + loss_masks = np.array(loss_masks, dtype=np.int64) + loss_masks[:-self.overalapping_eval] = 0 + return { + 'text': np.array(ids, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_masks, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64) + } + else: + loss_masks = [1] * len(tokens) + if len(tokens) < self.max_seq_len: + tokens = tokens + [0] * (self.max_seq_len - len(tokens)) + loss_masks = loss_masks + [0] * ( + self.max_seq_len - len(loss_masks)) + if idx != 0: + loss_masks = np.array(loss_masks, dtype=np.int64) + loss_masks[:-self.overalapping_eval] = 0 + return { + 'text': np.array(tokens, dtype=np.int64), + 'loss_mask': np.array(loss_masks, dtype=np.int64) + } + + +class LambadaDataset(torch.utils.data.Dataset): + + def __init__(self, args, tokenizer, strict=True): + data_path = args.valid_data[0] + print_rank_0( + '> building lambada dataset from {} ...'.format(data_path)) + self.args = args + self.max_seq_length = args.seq_length + self.tokenizer = tokenizer + self.pad_idx = tokenizer.get_command('pad').Id + self.strict = strict + self.block_lm = args.block_lm + self.unidirectional = args.unidirectional + mask_token = 'gMASK' if args.task_mask else 'MASK' + self.mask_id = self.tokenizer.get_command(mask_token).Id + + self.tokens = [] + self.labels = [] + with open(data_path, 'r') as f: + for line in f.readlines(): + text = json.loads(line)['text'] + tokens, labels = self.get_tokens(text) + self.tokens.append(tokens) + self.labels.append(labels) + + def get_tokens(self, text): + if not self.strict: + tokens = self.tokenizer.EncodeAsIds(text).tokenization + return tokens[:-1], [tokens[-1]] + last_token = text.split()[-1] + start_idx = text.rfind(last_token) + beginning_tokens = self.tokenizer.EncodeAsIds( + text[:start_idx].strip()).tokenization + last_token = self.tokenizer.EncodeAsIds(' ' + last_token).tokenization + return beginning_tokens, last_token + + def __len__(self): + return len(self.tokens) + + def __getitem__(self, idx): + tokens, answer = self.tokens[idx], self.labels[idx] + if self.block_lm: + if self.unidirectional: + tokens, answer_tokens = tokens[:1], tokens[1:] + answer + else: + answer_tokens = answer + tokens = tokens + [self.mask_id] + num_special_tokens = num_special_tokens_to_add( + tokens, + None, + answer_tokens, + add_cls=True, + add_sep=False, + add_piece=True) + left_shift = len(tokens) + len( + answer_tokens) + num_special_tokens - self.max_seq_length + if left_shift > 0: + tokens = tokens[left_shift:] + data = build_input_from_ids( + tokens, + None, + answer_tokens, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True, + mask_id=self.mask_id) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + if self.unidirectional: + loss_masks = np.array(loss_masks, dtype=np.int64) + last_index = len(loss_masks) + while loss_masks[last_index - 1] == 0: + last_index -= 1 + loss_masks[:last_index - len(answer)] = 0 + return { + 'text': np.array(ids, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_masks, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64) + } + else: + left_shift = len(tokens) - self.max_seq_length + if left_shift > 0: + tokens = tokens[left_shift:] + ids = tokens + answer + if len(ids) < self.max_seq_length: + ids = ids + [0] * (self.max_seq_length - len(ids)) + loss_masks = [0] * len(tokens) + [1] * len(answer) + if len(loss_masks) < self.max_seq_length: + loss_masks = loss_masks + [0] * ( + self.max_seq_length - len(loss_masks)) + return { + 'text': np.array(ids, dtype=np.int64), + 'loss_mask': np.array(loss_masks, dtype=np.int64) + } + + +def build_lambada_dataset(tokenizer, args): + """Build lambada dataset.""" + assert len(args.valid_data) == 1 + val_dataset = LambadaDataset(args, tokenizer, strict=True) + print_rank_0(' > found {} samples, {} label tokens.'.format( + len(val_dataset), sum(map(len, val_dataset.labels)))) + return val_dataset + + +def build_lm_dataset(tokenizer, args): + documents = [] + num_tokens, num_original_tokens = 0, 0 + with open(args.valid_data[0], encoding='utf-8') as file: + for line in file: + tokens = tokenizer.EncodeAsIds(line.strip()).tokenization + num_tokens += len(tokens) + num_original_tokens += len(line.strip().split(' ')) + documents.append(tokens) + val_dataset = LMDataset(args, documents, tokenizer, num_original_tokens, + num_tokens) + print_rank_0( + ' > number of document: {}, number of original tokens {}, number of detokenized tokens: {}' + .format(len(documents), num_original_tokens, num_tokens)) + return val_dataset + + +def build_wikitext103_dataset(tokenizer, args): + """""" + + assert len(args.valid_data) == 1 + with open(args.valid_data[0], 'rb') as reader: + entire_data = reader.read().decode('utf-8') + num_original_tokens = len(entire_data.strip().split(' ')) + entire_data = get_detokenizer('wikitext')(entire_data) + print_rank_0(entire_data[:1024]) + tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization + num_tokenized_tokens = len(tokenized_data) + + val_dataset = LMDataset(args, [tokenized_data], tokenizer, + num_original_tokens, num_tokenized_tokens) + print_rank_0(' > number of original tokens: {}, number of detokenized ' + 'tokens: {}'.format(num_original_tokens, + num_tokenized_tokens)) + return val_dataset diff --git a/modelscope/models/nlp/mglm/tasks/language_model/detokenizer.py b/modelscope/models/nlp/mglm/tasks/language_model/detokenizer.py new file mode 100755 index 00000000..dc1524de --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/language_model/detokenizer.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 Zhipu.AI + +import re + + +def ptb_detokenizer(string): + string = string.replace(" '", "'") + string = string.replace(' \n', '\n') + string = string.replace('\n ', '\n') + string = string.replace(" n't", "n't") + string = string.replace(' N ', '1 ') + string = string.replace('$ 1', '$1') + string = string.replace('# 1', '#1') + return string + + +def wikitext_detokenizer(string): + # contractions + string = string.replace("s '", "s'") + string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) + # number separators + string = string.replace(' @-@ ', '-') + string = string.replace(' @,@ ', ',') + string = string.replace(' @.@ ', '.') + # punctuation + string = string.replace(' : ', ': ') + string = string.replace(' ; ', '; ') + string = string.replace(' . ', '. ') + string = string.replace(' ! ', '! ') + string = string.replace(' ? ', '? ') + string = string.replace(' , ', ', ') + # double brackets + string = re.sub(r'\(\s*([^\)]*?)\s*\)', r'(\1)', string) + string = re.sub(r'\[\s*([^\]]*?)\s*\]', r'[\1]', string) + string = re.sub(r'{\s*([^}]*?)\s*}', r'{\1}', string) + string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) + string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) + # miscellaneous + string = string.replace('= = = =', '====') + string = string.replace('= = =', '===') + string = string.replace('= =', '==') + string = string.replace(' ' + chr(176) + ' ', chr(176)) + string = string.replace(' \n', '\n') + string = string.replace('\n ', '\n') + string = string.replace(' N ', ' 1 ') + string = string.replace(" 's", "'s") + + return string + + +def lambada_detokenizer(string): + return string + + +def get_detokenizer(dataset): + return DETOKENIZERS[dataset] + + +DETOKENIZERS = { + 'ptb': ptb_detokenizer, + 'wikitext': wikitext_detokenizer, + 'lambada': lambada_detokenizer, +} diff --git a/modelscope/models/nlp/mglm/tasks/language_model/finetune.py b/modelscope/models/nlp/mglm/tasks/language_model/finetune.py new file mode 100644 index 00000000..b6089e6f --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/language_model/finetune.py @@ -0,0 +1,254 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GPT2 zero-shot evaluation.""" + +import functools +import math + +import mpu +import torch +from finetune_glm import finetune +from pretrain_glm import get_batch +from tasks.data_utils import build_data_loader +from tasks.language_model.dataset import (build_lambada_dataset, + build_lm_dataset, + build_wikitext103_dataset) +from utils import print_rank_0 + +global_tokenizer = None + + +def lm_forward_step(data, model, args, timers, mems, eval_metric=None): + """Forward step.""" + + # Get the batch. + if timers is not None: + timers('batch generator').start() + if 'mask' in data: + data['attention_mask'] = data.pop('mask') + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data, args) + if timers is not None: + timers('batch generator').stop() + + def print_masked_text(batch_id): + block_position_ids = position_ids[:, 1] + position_ids_ = position_ids[:, 0] + output_tokens = [] + sep = attention_mask[batch_id].item() + for i, token in enumerate(tokens[batch_id, :sep].tolist()): + if global_tokenizer is not None: + token = global_tokenizer.IdToToken(token) + if token.startswith('[MASK'): + token = f'[{position_ids_[batch_id, i].item()}, {token}]' + if token.startswith('##') and len( + output_tokens) > 0 and not output_tokens[-1].endswith( + ']'): + output_tokens[-1] += token[2:] + else: + output_tokens.append(token) + else: + output_tokens.append(str(token)) + print(' '.join(output_tokens)) + last_index = None + for i in range(sep, tokens.size(1)): + if global_tokenizer.IdToToken( + tokens[batch_id, i].item()).startswith('<|startofpiece'): + if last_index is not None: + print( + global_tokenizer.DecodeIds( + tokens[batch_id, last_index:i].tolist()), '|', + global_tokenizer.DecodeIds( + labels[batch_id, last_index:i].tolist())), + print(position_ids_[batch_id, last_index:i].tolist(), + block_position_ids[batch_id, last_index:i].tolist()) + last_index = i + if last_index is not None: + print( + global_tokenizer.DecodeIds(tokens[batch_id, + last_index:].tolist()), '|', + global_tokenizer.DecodeIds(labels[batch_id, + last_index:].tolist())) + print(position_ids_[batch_id, last_index:].tolist(), + block_position_ids[batch_id, last_index:].tolist()) + + # Forward model. + if args.continuous_prompt: + prompt_pos = data['prompt_pos'].long().cuda() + logits, *mems = model( + tokens, position_ids, attention_mask, *mems, prompt_pos=prompt_pos) + else: + logits, *mems = model(tokens, position_ids, attention_mask, *mems) + + if eval_metric is None or eval_metric == 'loss': + losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), + labels) + loss_mask = loss_mask.view(-1) + # The loss is not normalized for fair comparison + loss = torch.sum(losses.view(-1) * loss_mask) + if eval_metric is None: + loss = loss / loss_mask.sum() + return loss, mems, 'bert' + elif eval_metric == 'accuracy' or eval_metric == 'classify': + logits = mpu.gather_from_model_parallel_region(logits) + outputs = torch.argmax(logits, -1) + correct = (outputs == labels).float() + correct[(1 - loss_mask).bool()] = 1 + correct = correct.prod(-1) + if eval_metric == 'accuracy': + correct = correct.sum() + return correct, mems, 'bert' + else: + raise NotImplementedError( + 'Metric {} not implemented'.format(eval_metric)) + + +def classify_evaluate(model, dataloader, example_dict, args): + """Evaluation.""" + # Turn on evaluation mode which disables dropout. + model.eval() + predictions, labels, examples = [], [], [] + with torch.no_grad(): + # For all the batches in the dataset. + for iteration, batch in enumerate(dataloader): + # Forward evaluation. + output, _, _ = lm_forward_step( + batch, model, args, None, [], eval_metric='classify') + uid_list = batch['uid'] + example_batch = [example_dict[uid] for uid in uid_list] + predictions.extend(output.long().tolist()) + label = batch['label'].tolist() + labels.extend(label) + examples.extend(example_batch) + return predictions, labels, examples + + +def evaluate(model, dataloader, eval_metric, args): + """Evaluation.""" + # Turn on evaluation mode which disables dropout. + model.eval() + total_output, total_count = 0.0, 0 + total_tokens = 0 + with torch.no_grad(): + # For all the batches in the dataset. + for iteration, batch in enumerate(dataloader): + if (iteration + 1) % args.log_interval == 0: + print_rank_0('> working on iteration: {}'.format(iteration)) + # Forward evaluation. + output, _, _ = lm_forward_step( + batch, model, args, None, [], eval_metric=eval_metric) + count = batch['text'].size(0) + count = torch.cuda.LongTensor([count]) + # Reduce across processes. + torch.distributed.all_reduce( + output, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce( + count, group=mpu.get_data_parallel_group()) + + total_output += output.item() + total_count += count.item() + total_tokens += batch['loss_mask'].sum().item() + totals = torch.cuda.FloatTensor([total_output, total_tokens]) + torch.distributed.all_reduce(totals, group=mpu.get_data_parallel_group()) + total_output, total_tokens = totals.tolist() + print(total_tokens) + return {eval_metric: total_output}, total_count + + +def evaluate_and_print_results(data_loader, model, eval_metric, args): + """Evaluate and print results on screen.""" + + # Evaluate and get results. + output, _ = evaluate(model, data_loader, eval_metric, args) + + string = '' + if eval_metric == 'loss': + output = output['loss'] + num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens + num_original_tokens = data_loader.dataset.num_original_tokens + val_loss = output / (num_tokenized_tokens - 1) + ppl = math.exp(min(20, val_loss)) + token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) + adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) + string += 'avg loss: {:.4E} | '.format(val_loss) + string += 'ppl: {:.4E} | '.format(ppl) + string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) + string += 'token ratio: {} |'.format(token_ratio) + score_dict = { + 'avg loss': val_loss, + 'ppl': ppl, + 'adjusted ppl': adjusted_ppl + } + + elif eval_metric == 'accuracy': + output = output['accuracy'] + num_examples = len(data_loader.dataset) + acc = output / num_examples * 100 + string += 'number correct: {} | '.format(output) + string += 'total examples: {} | '.format(num_examples) + string += 'avg accuracy: {:.2f}'.format(acc) + score_dict = {'accuracy': acc} + else: + raise NotImplementedError('evaluation method for {} metric is not ' + 'implemented yet.'.format(eval_metric)) + + length = len(string) + 1 + print_rank_0('-' * length) + print_rank_0(string) + print_rank_0('-' * length) + return score_dict + + +def metrics_func_provider(args, tokenizer, is_test): + """Privde metrics callback function.""" + + if args.task.lower() == 'lambda': + eval_metric = 'accuracy' + dataset = build_lambada_dataset(tokenizer, args) + elif args.task == 'wikitext': + eval_metric = 'loss' + dataset = build_wikitext103_dataset(tokenizer, args) + elif args.task == 'language_model': + eval_metric = 'loss' + dataset = build_lm_dataset(tokenizer, args) + else: + raise NotImplementedError('{} task is not implemented.'.format( + args.task)) + # Data stuff + dataloader = build_data_loader( + dataset, + args.eval_batch_size, + args.num_workers, + drop_last=False, + shuffle=False) + + def metrics_func(model, + epoch, + output_predictions=False, + summary_writer=None): + return evaluate_and_print_results( + dataloader, model, eval_metric=eval_metric, args=args) + + global global_tokenizer + global_tokenizer = tokenizer + return metrics_func + + +def main(args): + """Main program.""" + finetune( + args, + None, {}, + end_of_epoch_callback_provider=metrics_func_provider, + forward_step=lm_forward_step) diff --git a/modelscope/models/nlp/mglm/tasks/seq2seq/dataset.py b/modelscope/models/nlp/mglm/tasks/seq2seq/dataset.py new file mode 100644 index 00000000..6a4e275f --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/seq2seq/dataset.py @@ -0,0 +1,667 @@ +# Copyright (c) 2022 Zhipu.AI + +import os +import random + +import json +import numpy as np +import torch +import torch.utils.data +from data_utils.corpora import punctuation_standardization +from tasks.data_utils import InputExample +from tqdm import tqdm +from utils import print_rank_0 + + +def gigaword_detokenize(string, is_target=False): + _tok_dict = { + '(': '-lrb-', + ')': '-rrb-', + '[': '-lsb-', + ']': '-rsb-', + '{': '-lcb-', + '}': '-rcb-', + '&': '&', + '<': '<', + '>': '>' + } + string = string.replace('UNK', '[UNK]') + string = string.replace('', '[UNK]') + for key, value in _tok_dict.items(): + string = string.replace(value, key) + # string = string.replace("''", "\"") + # string = string.replace("``", "\"") + # string = string.replace("`", "'") + # string = string.replace(" n't", "n't") + # string = string.replace(" 's", "'s") + # string = string.replace(" 'd", "'d") + # string = string.replace(" 'll", "'ll") + return string + + +def cnndm_detokenize(string, is_target=False): + _tok_dict = { + '(': '-LRB-', + ')': '-RRB-', + '[': '-LSB-', + ']': '-RSB-', + '{': '-LCB-', + '}': '-RCB-' + } + if not is_target: + string = string.replace('', '') + else: + string = string.replace('', '[SEP]') + for key, value in _tok_dict.items(): + string = string.replace(value, key) + string = string.replace("''", "\"") + string = string.replace('``', "\"") + string = string.replace('`', "'") + string = string.replace(" n't", "n't") + string = string.replace(" 's", "'s") + string = string.replace(" 'd", "'d") + string = string.replace(" 'll", "'ll") + return string + + +def blanklm_detokenize(string, is_target=False): + string = string.replace('_UNK', '[UNK]') + string = string.replace('', '[MASK]') + return string + + +class SummmaryProcessor: + + def __init__(self, task, data_dir, tokenizer): + self.task = task + self.data_dir = data_dir + self.tokenizer = tokenizer + + def create_examples(self, split): + if split == 'train': + filename = 'train' + elif split == 'dev': + filename = 'val' + elif split == 'test': + filename = 'test' + else: + raise NotImplementedError(split) + print_rank_0( + f'Creating {self.task}-{split} dataset from {self.data_dir}') + if self.task == 'gigaword': + detokenizer = gigaword_detokenize + elif self.task == 'cnn_dm': + detokenizer = cnndm_detokenize + else: + detokenizer = None + source_texts, target_texts = [], [] + with open( + os.path.join(self.data_dir, f'{filename}.source'), + encoding='utf-8') as file: + for line in file: + line = line.strip() + line = punctuation_standardization(line) + line = detokenizer(line) if detokenizer else line + source_texts.append(line) + with open( + os.path.join(self.data_dir, f'{filename}.target'), + encoding='utf-8') as file: + for line in file: + line = line.strip() + line = punctuation_standardization(line) + line = detokenizer( + line, is_target=True) if detokenizer else line + target_texts.append(line) + assert len(source_texts) == len(target_texts) + example_list = [] + for idx, (source_text, + target_text) in enumerate(zip(source_texts, target_texts)): + if (idx + 1) % 20000 == 0: + print_rank_0(f'Complete {idx + 1} examples') + guid = '%s-%s' % (split, idx) + meta = { + 'ref': + self.tokenizer.DecodeIds( + self.tokenizer.EncodeAsIds(target_text).tokenization) + } + example = InputExample( + guid=guid, text_a=source_text, text_b=target_text, meta=meta) + if idx < 10: + print_rank_0( + (source_text.encode('utf-8'), target_text.encode('utf-8'), + meta['ref'].encode('utf-8'))) + example_list.append(example) + return example_list + + +class SQuADProcessor: + + def __init__(self, data_dir, tokenizer): + self.data_dir = data_dir + self.tokenizer = tokenizer + + def create_examples(self, split): + if split == 'train': + filename = 'train.json' + elif split == 'dev': + filename = 'dev.json' + elif split == 'test': + filename = 'test.json' + else: + raise NotImplementedError(split) + print_rank_0(f'Creating SQuAD-{split} dataset from {self.data_dir}') + example_list = [] + idx = 0 + with open( + os.path.join(self.data_dir, filename), + encoding='utf-8') as file: + dataset = json.load(file) + for paragraphs in dataset: + for paragraph in paragraphs['paragraphs']: + context = paragraph['context'] + for qa in paragraph['qas']: + question = qa['question'] + answers = {answer['text'] for answer in qa['answers']} + answer_starts = { + answer['text']: answer['answer_start'] + for answer in qa['answers'] + } + for answer in answers: + guid = '%s-%s' % (split, idx) + meta = { + 'answer_start': + answer_starts[answer], + 'answer': + answer, + 'question': + question, + 'ref': + self.tokenizer.DecodeIds( + self.tokenizer.EncodeAsIds( + question).tokenization) + } + example = InputExample( + guid=guid, text_a=context, meta=meta) + if idx < 10: + print_rank_0((context.encode('utf-8'), + answer.encode('utf-8'), + meta['ref'].encode('utf-8'))) + example_list.append(example) + idx += 1 + print_rank_0(f'Creating {len(example_list)} examples for {split}') + return example_list + + +class XSumProcessor: + + def __init__(self, data_dir, tokenizer): + self.data_dir = data_dir + self.tokenizer = tokenizer + + def create_examples(self, split): + if split == 'train': + key = 'train' + elif split == 'dev': + key = 'validation' + elif split == 'test': + key = 'test' + else: + raise NotImplementedError(split) + print_rank_0(f'Creating XSUM-{split} dataset from {self.data_dir}') + with open( + os.path.join( + self.data_dir, + 'XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json')) as file: + id_list = json.load(file) + id_list = id_list[key] + source_texts, target_texts = [], [] + for i, idx in enumerate(id_list): + with open(os.path.join(self.data_dir, f'{idx}.summary')) as file: + key, sentences = None, [] + source_text, target_text = None, None + for line in file: + line = line.strip() + if line.startswith('[SN]'): + if key is not None: + if key == 'RESTBODY': + source_text = ' '.join(sentences) + elif key == 'FIRST-SENTENCE': + target_text = ' '.join(sentences) + key = line[4:-4] + sentences = [] + elif line: + sentences.append(line) + if key is not None: + if key == 'RESTBODY': + source_text = ' '.join(sentences) + elif key == 'FIRST-SENTENCE': + target_text = ' '.join(sentences) + source_texts.append(source_text) + target_texts.append(target_text) + if (i + 1) % 1000 == 0: + print_rank_0(f'Complete {i + 1} examples') + assert len(source_texts) == len(target_texts) + example_list = [] + for idx, (source_text, + target_text) in enumerate(zip(source_texts, target_texts)): + if (idx + 1) % 20000 == 0: + print_rank_0(f'Complete {idx + 1} examples') + guid = '%s-%s' % (split, idx) + meta = { + 'ref': + self.tokenizer.DecodeIds( + self.tokenizer.EncodeAsIds(target_text).tokenization) + } + example = InputExample( + guid=guid, text_a=source_text, text_b=target_text, meta=meta) + if idx < 10: + print_rank_0( + (source_text.encode('utf-8'), target_text.encode('utf-8'), + meta['ref'].encode('utf-8'))) + example_list.append(example) + return example_list + + +class Seq2SeqDataset(torch.utils.data.Dataset): + + def __init__(self, args, split, tokenizer): + self.args = args + self.task, self.data_dir = args.task.lower(), args.data_dir + self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length + self.split = split + self.tokenizer = tokenizer + self.dataset_name = split + if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original']: + self.processor = SummmaryProcessor(self.task, self.data_dir, + tokenizer) + elif self.task in ['xsum']: + self.processor = XSumProcessor(self.data_dir, tokenizer) + elif self.task in ['squad_generation']: + self.processor = SQuADProcessor(self.data_dir, tokenizer) + else: + raise NotImplementedError + example_list = self.processor.create_examples(split) + self.example_list = example_list + self.examples = {example.guid: example for example in example_list} + + print_rank_0(f'Return {len(self.examples)} {split} examples') + + def __len__(self): + return len(self.example_list) + + def __getitem__(self, idx): + example = self.example_list[idx] + cls_id = self.tokenizer.get_command('ENC').Id + mask_token = 'sMASK' if self.args.task_mask else 'MASK' + mask_id = self.tokenizer.get_command(mask_token).Id + pad_id = self.tokenizer.get_command('pad').Id + sop_id = self.tokenizer.get_command('sop').Id + eop_id = self.tokenizer.get_command('eop').Id + if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original', 'xsum']: + source_text, target_text = example.text_a, example.text_b + source_tokens = self.tokenizer.EncodeAsIds( + ' ' + source_text).tokenization + prompt = [cls_id, mask_id + ] + self.tokenizer.EncodeAsIds(' Content:').tokenization + if len(source_tokens) > self.max_src_length - len(prompt): + source_tokens = source_tokens[:self.max_src_length + - len(prompt)] + source_tokens = prompt + source_tokens + elif self.task == 'squad_generation': + source_text = example.text_a + target_text, answer = example.meta['question'], example.meta[ + 'answer'] + source_tokens = self.tokenizer.EncodeAsIds( + source_text.rstrip() + ' Question:').tokenization + answer_tokens = self.tokenizer.EncodeAsIds(' Answer: ' + + answer).tokenization + if len(source_tokens + ) > self.max_src_length - len(answer_tokens) - 2: + max_src_length = self.max_src_length - len(answer_tokens) - 2 + answer_pattern = self.tokenizer.EncodeAsIds( + ' ' + answer).tokenization + + def sub_finder(mylist, pattern): + matches = [] + for i in range(len(mylist)): + if mylist[i] == pattern[0] and mylist[ + i:i + len(pattern)] == pattern: + matches.append(i) + return matches + + answer_indices = sub_finder(source_tokens, answer_pattern) + if len(answer_indices) == 0: + print(f'Answer {answer} not exists in the source text') + source_tokens = source_tokens[:max_src_length] + else: + start_index = max(answer_indices[0] - max_src_length // 2, + 0) + source_tokens = source_tokens[start_index:start_index + + max_src_length] + source_tokens = [cls_id] + source_tokens + [mask_id + ] + answer_tokens + else: + raise NotImplementedError + if len(source_tokens) < self.max_src_length: + source_tokens = source_tokens + [pad_id] * ( + self.max_src_length - len(source_tokens)) + sep = len(source_tokens) + position_ids = list(range(len(source_tokens))) + block_position_ids = [0] * len(source_tokens) + mask_pos = source_tokens.index(mask_id) + if self.split == 'train': + target_tokens = self.tokenizer.EncodeAsIds( + ' ' + target_text).tokenization + target_tokens = target_tokens + [eop_id] + if len(target_tokens) > self.max_tgt_length: + target_tokens = target_tokens[:self.max_tgt_length] + loss_mask = [1] * len(target_tokens) + if len(target_tokens) < self.max_tgt_length: + loss_mask += [0] * (self.max_tgt_length - len(target_tokens)) + target_tokens += [pad_id] * ( + self.max_tgt_length - len(target_tokens)) + tokens = source_tokens + [sop_id] + target_tokens[:-1] + loss_mask = [0] * len(source_tokens) + loss_mask + target_ids = [0] * len(source_tokens) + target_tokens + position_ids += [mask_pos] * len(target_tokens) + if self.args.no_block_position: + block_position_ids += [1] * len(target_tokens) + else: + block_position_ids += list(range(1, len(target_tokens) + 1)) + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_mask, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + else: + tokens = source_tokens + [sop_id] + position_ids = position_ids + [mask_pos] + block_position_ids = block_position_ids + [1] + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + return sample + + +class ExtractionDataset(torch.utils.data.Dataset): + + def __init__(self, args, split, tokenizer): + self.args = args + task, data_dir = args.task.lower(), args.data_dir + self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length + self.split = split + self.tokenizer = tokenizer + if split == 'train': + filename = 'train' + elif split == 'dev': + filename = 'valid' + elif split == 'test': + filename = 'test' + else: + raise NotImplementedError(split) + print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') + self.dataset_name = split + source_texts, target_texts = [], [] + with open( + os.path.join(data_dir, f'{filename}.source'), + encoding='utf-8') as file: + for line in file: + line = line.strip() + source_texts.append(line) + with open( + os.path.join(data_dir, f'{filename}.target'), + encoding='utf-8') as file: + for line in file: + line = line.strip() + target_texts.append(line) + self.examples, self.example_list = {}, [] + for idx, (source_text, + target_text) in enumerate(zip(source_texts, target_texts)): + if (idx + 1) % 20000 == 0: + print_rank_0(f'Complete {idx + 1} examples') + guid = '%s-%s' % (split, idx) + meta = {'ref': target_text} + example = InputExample( + guid=guid, text_a=source_text, text_b=target_text, meta=meta) + self.examples[guid] = example + self.example_list.append(example) + print_rank_0(f'Return {len(self.examples)} {split} examples') + + def __len__(self): + return len(self.example_list) + + def __getitem__(self, idx): + example = self.example_list[idx] + source_text, target_text = example.text_a, example.text_b + mask_token = 'MASK' + mask_id = self.tokenizer.get_command(mask_token).Id + sop_id = self.tokenizer.get_command('sop').Id + eop_id = self.tokenizer.get_command('eop').Id + pad_id = self.tokenizer.get_command('pad').Id + + def pad_to(text, max_len, pad_id): + if len(text) > max_len: + text = text[:max_len] + else: + text = text + [pad_id] * (max_len - len(text)) + return text + + source_tokens = self.tokenizer.EncodeAsIds(source_text).tokenization + masked_tgt = target_text.split('|') + source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) + sep = len(source_tokens) + position_ids = list(range(len(source_tokens))) + block_position_ids = [0] * len(source_tokens) + if self.split == 'train': + mask_positions = [ + i for i, x in enumerate(source_tokens) if x == mask_id + ] + assert len(mask_positions) <= len(masked_tgt) + tokens = source_tokens + target_ids = [0] * len(source_tokens) + loss_mask = [0] * len(source_tokens) + for i, mask_pos in enumerate(mask_positions): + tgt_text = masked_tgt[i] + tgt_tokens = self.tokenizer.EncodeAsIds( + ' ' + tgt_text).tokenization + tokens += [sop_id] + tgt_tokens + target_ids += tgt_tokens + [eop_id] + loss_mask += [1] * (len(tgt_tokens) + 1) + position_ids += [mask_pos] * (len(tgt_tokens) + 1) + block_position_ids += [ + i + 1 for i in range(len(tgt_tokens) + 1) + ] + tokens = pad_to(tokens, self.max_src_length + self.max_tgt_length, + pad_id) + target_ids = pad_to(target_ids, + self.max_src_length + self.max_tgt_length, + pad_id) + loss_mask = pad_to(loss_mask, + self.max_src_length + self.max_tgt_length, 0) + position_ids = pad_to(position_ids, + self.max_src_length + self.max_tgt_length, 0) + block_position_ids = pad_to( + block_position_ids, self.max_src_length + self.max_tgt_length, + 0) + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_mask, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + else: + tokens = source_tokens + [sop_id] + mask_pos = source_tokens.index(mask_id) + position_ids = position_ids + [mask_pos] + block_position_ids = block_position_ids + [1] + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + return sample + + +class BlankLMDataset(torch.utils.data.Dataset): + + def __init__(self, args, split, tokenizer): + self.args = args + task, data_dir = args.task.lower(), args.data_dir + self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length + self.split = split + assert args.tokenizer_type == 'BertWordPieceTokenizer' + self.tokenizer = tokenizer + if split == 'train': + filename = 'train' + elif split == 'dev': + filename = 'valid' + elif split == 'test': + filename = 'test' + else: + raise NotImplementedError(split) + print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') + self.dataset_name = split + detokenizer = blanklm_detokenize + source_texts, target_texts = [], [] + with open( + os.path.join(data_dir, f'{filename}.txt'), + encoding='utf-8') as file: + for line in file: + line = line.strip() + line = detokenizer(line) if detokenizer else line + target_texts.append(line) + if split == 'test': + with open( + os.path.join( + data_dir, + f'blank/test.maskratio{args.blank_maskratio:.1f}.blank' + ), + encoding='utf-8') as file: + for line in file: + line = line.strip() + line = detokenizer(line) if detokenizer else line + source_texts.append(line) + else: + source_texts = target_texts + self.examples, self.example_list = {}, [] + for idx, (source_text, + target_text) in enumerate(zip(source_texts, target_texts)): + # if idx > 10000: + # break + if (idx + 1) % 20000 == 0: + print_rank_0(f'Complete {idx + 1} examples') + guid = '%s-%s' % (split, idx) + meta = {'ref': target_text} + example = InputExample( + guid=guid, text_a=source_text, text_b=target_text, meta=meta) + self.examples[guid] = example + self.example_list.append(example) + print_rank_0(f'Return {len(self.examples)} {split} examples') + self.random = random.Random(args.seed) + + def __len__(self): + return len(self.example_list) + + def __getitem__(self, idx): + example = self.example_list[idx] + source_text, target_text = example.text_a, example.text_b # noqa + mask_token = 'gMASK' if self.args.task_mask else 'MASK' + mask_id = self.tokenizer.get_command(mask_token).Id + sop_id = self.tokenizer.get_command('sop').Id + eop_id = self.tokenizer.get_command('eop').Id + pad_id = self.tokenizer.get_command('pad').Id + if self.split in ['train', 'dev']: + masked_src, masked_tgt = self.mask_text(source_text) + source_text = masked_src + + def pad_to(text, max_len, pad_id): + if len(text) > max_len: + text = text[:max_len] + else: + text = text + [pad_id] * (max_len - len(text)) + return text + + source_tokens = self.tokenizer.EncodeAsIds(' ' + + source_text).tokenization + source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) + sep = len(source_tokens) + position_ids = list(range(len(source_tokens))) + block_position_ids = [0] * len(source_tokens) + if self.split in ['train', 'dev']: + mask_positions = [ + i for i, x in enumerate(source_tokens) if x == mask_id + ] + assert len(mask_positions) <= len(masked_tgt) + tokens = source_tokens + target_ids = [0] * len(source_tokens) + loss_mask = [0] * len(source_tokens) + for i, mask_pos in enumerate(mask_positions): + tgt_text = masked_tgt[i] + tgt_tokens = self.tokenizer.EncodeAsIds( + ' ' + tgt_text).tokenization + tokens += [sop_id] + tgt_tokens + target_ids += tgt_tokens + [eop_id] + loss_mask += [1] * (len(tgt_tokens) + 1) + position_ids += [mask_pos] * (len(tgt_tokens) + 1) + block_position_ids += [ + i + 1 for i in range(len(tgt_tokens) + 1) + ] + max_length = self.max_src_length + int( + self.max_src_length * self.args.blank_maskratio) + tokens = pad_to(tokens, max_length, pad_id) + target_ids = pad_to(target_ids, max_length, pad_id) + loss_mask = pad_to(loss_mask, max_length, 0) + position_ids = pad_to(position_ids, max_length, 0) + block_position_ids = pad_to(block_position_ids, max_length, 0) + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_mask, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + else: + tokens = source_tokens + [sop_id] + mask_pos = source_tokens.index(mask_id) + position_ids = position_ids + [mask_pos] + block_position_ids = block_position_ids + [1] + position_ids = [position_ids, block_position_ids] + sample = { + 'text': np.array(tokens, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'uid': example.guid + } + return sample + + def mask_text(self, text): + tokens = text.split() + mask_ratio = self.args.blank_maskratio + n = len(tokens) + indices = sorted(self.random.sample(range(n), int(n * mask_ratio))) + masked_src, masked_tgt = '', [] + for i, idx in enumerate(indices): + if i == 0 or idx != indices[i - 1] + 1: + masked_tgt.append('') + masked_tgt[-1] += ' ' + tokens[idx] + tokens[idx] = '[MASK]' + for i, token in enumerate(tokens): + if i != 0 and token == '[MASK]' and tokens[i - 1] == '[MASK]': + continue + masked_src += ' ' + token + return masked_src, masked_tgt diff --git a/modelscope/models/nlp/mglm/tasks/seq2seq/evaluate.py b/modelscope/models/nlp/mglm/tasks/seq2seq/evaluate.py new file mode 100644 index 00000000..5fd28b89 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/seq2seq/evaluate.py @@ -0,0 +1,538 @@ +# Copyright (c) 2022 Zhipu.AI + +import datetime +import random +import string + +import mpu +import torch +import torch.nn.functional as F +from generation_utils import (BeamSearchScorer, LogitsProcessorList, + MinLengthLogitsProcessor, + NoRepeatNGramLogitsProcessor) +from rouge_score import rouge_scorer +from utils import print_rank_0 + + +def _is_digit(w): + for ch in w: + if not (ch.isdigit() or ch == ','): + return False + return True + + +gigaword_tok_dict = { + '(': '-lrb-', + ')': '-rrb-', + '[': '-lsb-', + ']': '-rsb-', + '{': '-lcb-', + '}': '-rcb-', + '[UNK]': 'UNK', + '&': '&', + '<': '<', + '>': '>' +} + +cnndm_tok_dict = { + '(': '-LRB-', + ')': '-RRB-', + '[': '-LSB-', + ']': '-RSB-', + '{': '-LCB-', + '}': '-RCB-' +} + + +def fix_tokenization(text, dataset): + if dataset == 'cnn_dm_org': + return text + if dataset == 'gigaword': + text = text.replace('[UNK]', 'UNK') + return text + input_tokens = text.split() + output_tokens = [] + has_left_quote = False + has_left_single_quote = False + + i = 0 + prev_dash = False + while i < len(input_tokens): + tok = input_tokens[i] + flag_prev_dash = False + if tok == "\"": + if has_left_quote: + output_tokens.append("''") + else: + output_tokens.append('``') + has_left_quote = not has_left_quote + i += 1 + elif tok == "'" and len( + output_tokens) > 0 and output_tokens[-1].endswith( + 'n') and i < len(input_tokens) - 1 and input_tokens[ + i + 1] == 't': # noqa + output_tokens[-1] = output_tokens[-1][:-1] + output_tokens.append("n't") + i += 2 + elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[ + i + 1] in ('s', 'd', 'll'): + output_tokens.append("'" + input_tokens[i + 1]) + i += 2 + elif tok == "'": + if has_left_single_quote: + output_tokens.append("'") + else: + output_tokens.append('`') + has_left_single_quote = not has_left_single_quote + i += 1 + elif tok == '.' and i < len(input_tokens) - 2 and input_tokens[ + i + 1] == '.' and input_tokens[i + 2] == '.': + output_tokens.append('...') + i += 3 + elif tok == ',' and len(output_tokens) > 0 and _is_digit( + output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit( + input_tokens[i + 1]): + # $ 3 , 000 -> $ 3,000 + output_tokens[-1] += ',' + input_tokens[i + 1] + i += 2 + elif tok == '.' and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ + input_tokens[i + 1].isdigit(): + # 3 . 03 -> $ 3.03 + output_tokens[-1] += '.' + input_tokens[i + 1] + i += 2 + elif tok == '.' and len(output_tokens) > 0 and len( + output_tokens[-1]) == 1 and output_tokens[-1].isalpha( # noqa + ) and i < len(input_tokens) - 2 and len( # noqa + input_tokens[i + 1]) == 1 and input_tokens[ + i + 1].isalpha( # noqa + ) and input_tokens[i + 2] == '.': # noqa + # U . N . -> U.N. + k = i + 3 + while k + 2 < len(input_tokens): + if len(input_tokens[k + 1]) == 1 and input_tokens[ + k + 1].isalpha() and input_tokens[k + 2] == '.': + k += 2 + else: + break + output_tokens[-1] += ''.join(input_tokens[i:k]) + i = k + elif tok == '-': + if i < len(input_tokens) - 1 and input_tokens[i + 1] == '-': + output_tokens.append('--') + i += 2 + elif i == len(input_tokens) - 1 or i == 0: + output_tokens.append('-') + i += 1 + elif output_tokens[-1] not in string.punctuation and input_tokens[ + i + 1][0] not in string.punctuation: + output_tokens[-1] += '-' + i += 1 + flag_prev_dash = True + else: + output_tokens.append('-') + i += 1 + elif prev_dash and len( + output_tokens) > 0 and tok[0] not in string.punctuation: + output_tokens[-1] += tok + i += 1 + else: + output_tokens.append(tok) + i += 1 + prev_dash = flag_prev_dash + return ' '.join(output_tokens) + + +def count_tokens(tokens): + counter = {} + for t in tokens: + if t in counter.keys(): + counter[t] += 1 + else: + counter[t] = 1 + return counter + + +def get_f1(text_a, text_b): + tokens_a = text_a.lower().split() + tokens_b = text_b.lower().split() + if len(tokens_a) == 0 or len(tokens_b) == 0: + return 1 if len(tokens_a) == len(tokens_b) else 0 + set_a = count_tokens(tokens_a) + set_b = count_tokens(tokens_b) + match = 0 + for token in set_a.keys(): + if token in set_b.keys(): + match += min(set_a[token], set_b[token]) + p = match / len(tokens_a) + r = match / len(tokens_b) + return 2.0 * p * r / (p + r + 1e-5) + + +def remove_duplicate(l_list, duplicate_rate): + tk_list = [l.lower().split() for l in l_list] # noqa + r_list = [] + history_set = set() + for i, w_list in enumerate(tk_list): + w_set = set(w_list) + if len(w_set & history_set) / len(w_set) <= duplicate_rate: + r_list.append(l_list[i]) + history_set |= w_set + return r_list + + +def rouge_metric(predictions, + labels, + examples, + metric='rouge-1', + duplicate_rate=0.7, + dataset='cnn_dm'): + metric_dict = { + 'rouge-1': 'rouge1', + 'rouge-2': 'rouge2', + 'rouge-l': 'rougeLsum' + } + refs = [example.meta['ref'] for example in examples] + ref_list = [] + for ref in refs: + ref = ref.strip().split('[SEP]') + ref = [fix_tokenization(sentence, dataset=dataset) for sentence in ref] + ref = '\n'.join(ref) + ref_list.append(ref) + pred_list = [] + for prediction in predictions: + buf = [] + for sentence in prediction.strip().split('[SEP]'): + sentence = fix_tokenization(sentence, dataset=dataset) + if any(get_f1(sentence, s) > 1.0 for s in buf): + continue + s_len = len(sentence.split()) + if s_len <= 4: + continue + buf.append(sentence) + if duplicate_rate and duplicate_rate < 1: + buf = remove_duplicate(buf, duplicate_rate) + line = '\n'.join(buf) + pred_list.append(line) + if torch.distributed.get_rank() == 0: + import json + with open('./results.json', 'w') as output: + for ref, pred in zip(ref_list, pred_list): + output.write(json.dumps({'ref': ref, 'pred': pred}) + '\n') + scorer = rouge_scorer.RougeScorer([metric_dict[metric]], use_stemmer=True) + scores = [ + scorer.score(pred, ref) for pred, ref in zip(pred_list, ref_list) + ] + scores = [score[metric_dict[metric]].fmeasure for score in scores] + scores = sum(scores) / len(scores) + return scores + + +def process_batch(batch, args): + """Process batch and produce inputs for the model.""" + tokens = batch['text'].long().cuda() + attention_mask = batch['attention_mask'].long().cuda() + position_ids = batch['position_id'].long().cuda() + return tokens, attention_mask, position_ids + + +class DecoderEvaluater: + + def __init__(self, args, tokenizer): + self.tokenizer = tokenizer + self.start_token = tokenizer.get_command('sop').Id + self.end_token = tokenizer.get_command('eop').Id + self.mask_token = tokenizer.get_command( + 'sMASK').Id if args.task_mask else tokenizer.get_command('MASK').Id + self.pad_token = tokenizer.get_command('pad').Id + self.processors = LogitsProcessorList() + if args.min_tgt_length > 0: + processor = MinLengthLogitsProcessor(args.min_tgt_length, + self.end_token) + self.processors.append(processor) + if args.no_repeat_ngram_size > 0: + processor = NoRepeatNGramLogitsProcessor(args.no_repeat_ngram_size) + self.processors.append(processor) + + def evaluate(self, model, dataloader, example_dict, args): + """Calculate correct over total answers and return prediction if the + `output_predictions` is true.""" + model.eval() + store = torch.distributed.TCPStore(args.master_ip, + 18931 + random.randint(0, 10000), + mpu.get_data_parallel_world_size(), + torch.distributed.get_rank() == 0, + datetime.timedelta(seconds=30)) + print_rank_0('Distributed store created') + with torch.no_grad(): + # For all the batches in the dataset. + for idx, data in enumerate(dataloader): + tokens, attention_mask, position_ids = process_batch( + data, args) + batch_size = tokens.size(0) + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=args.out_seq_length, + num_beams=args.num_beams, + device=tokens.device, + length_penalty=args.length_penalty, + do_early_stopping=False, + ) + beam_scores = torch.zeros((batch_size, args.num_beams), + dtype=torch.float, + device=tokens.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * args.num_beams, )) + # Run the model forward. + counter = 0 + while counter < args.tgt_seq_length: + if counter == 0: + next_token_logits, *mems = model( + tokens, + position_ids, + attention_mask, + return_memory=True) + seq_length = next_token_logits.size(1) + next_token_logits = next_token_logits[:, -1] + next_token_logits = next_token_logits.unsqueeze( + 1).repeat(1, args.num_beams, + 1).view(batch_size * args.num_beams, -1) + mems = [ + mem.unsqueeze(1).repeat( + 1, args.num_beams, 1, + 1).view(batch_size * args.num_beams, + seq_length, -1) for mem in mems + ] + position_ids = tokens.new_ones(batch_size, + args.num_beams, 2, 1) + for i, text in enumerate(tokens.tolist()): + mask_pos = text.index(self.mask_token) + position_ids[i, :, 0] = mask_pos + position_ids = position_ids.reshape( + batch_size * args.num_beams, 2, 1) + tokens = tokens.new_zeros(batch_size * args.num_beams, + 0) + attention_mask = tokens.new_zeros( + [batch_size * args.num_beams]) + else: + if not args.no_block_position: + position_ids[:, 1] = counter + 1 + last_token = tokens[:, -1:] + next_token_logits, *mems = model( + last_token, + position_ids, + attention_mask, + *mems, + return_memory=True) + next_token_logits = next_token_logits[:, -1] + next_token_scores = F.log_softmax( + next_token_logits, dim=-1) + next_token_scores = self.processors( + tokens, next_token_scores) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as( + next_token_scores) + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view( + batch_size, args.num_beams * vocab_size) + + probs = F.softmax(next_token_scores, dim=-1) + if args.select_topk: + _, next_tokens = torch.topk( + probs, k=2 * args.num_beams, dim=-1, largest=True) + else: + next_tokens = torch.multinomial( + probs, num_samples=2 * args.num_beams) + next_token_scores = torch.gather(next_token_scores, -1, + next_tokens) + next_token_scores, _indices = torch.sort( + next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + # stateless + beam_outputs = beam_scorer.process( + tokens, + next_token_scores, + next_tokens, + next_indices, + eos_token_id=self.end_token, + pad_token_id=self.pad_token) + beam_scores = beam_outputs['next_beam_scores'] + beam_next_tokens = beam_outputs['next_beam_tokens'] + beam_idx = beam_outputs['next_beam_indices'] + beam_next_tokens = beam_next_tokens.unsqueeze(-1) + tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], + dim=-1) + mems = [mem[beam_idx] for mem in mems] if mems else [] + if beam_scorer.is_done: + break + counter += 1 + tokens, _ = beam_scorer.finalize( + tokens, + beam_scores, + next_tokens, + next_indices, + eos_token_id=self.end_token, + pad_token_id=self.pad_token) + predictions = [] + for text in tokens.tolist(): + text = [ + token for token in text + if token not in [self.end_token, self.pad_token] + ] + text = self.tokenizer.DecodeIds(text) + predictions.append(text) + uid_list = data['uid'] + if isinstance(uid_list, torch.Tensor): + uid_list = uid_list.cpu().numpy().tolist() + for uid, prediction in zip(uid_list, predictions): + store.set(uid, prediction) + if (idx + 1) % args.log_interval == 0: + print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') + model.train() + torch.distributed.barrier() + print_rank_0('Evaluation completed') + predictions, examples = [], [] + for uid, example in example_dict.items(): + predictions.append(store.get(uid).decode('utf-8')) + examples.append(example) + torch.distributed.barrier() + return predictions, [], examples + + +def blanklm_fix_tokenization(text): + text = text.replace('` `', '``') + text = text.replace("\' \'", "\'\'") + text = text.replace("n \' t", "n\'t") + text = text.replace("\' s", "\'s") + text = text.replace("\' m", "\'m") + text = text.replace("\' re", "\'re") + text = text.replace('. . .', '...') + text = text.replace(' . .', ' ..') + text = text.replace('- -', '--') + text = text.replace('u . s .', 'u.s.') + text = text.replace('u . k .', 'u.k.') + text = text.replace('e . g .', 'e.g.') + return text + + +class BlankLMEvaluater(DecoderEvaluater): + + def evaluate(self, model, dataloader, example_dict, args): + model.eval() + store = torch.distributed.TCPStore(args.master_ip, + 18931 + random.randint(0, 10000), + mpu.get_data_parallel_world_size(), + torch.distributed.get_rank() == 0, + datetime.timedelta(seconds=30)) + print_rank_0('Distributed store created') + + with torch.no_grad(): + for idx, data in enumerate(dataloader): + tokens, attention_mask, position_ids = process_batch( + data, args) + src_tokens = tokens + batch_size = tokens.size(0) + mask_positions = [] + current_mask = [] + for text in tokens.tolist(): + mask_positions.append([ + i for i, x in enumerate(text) if x == self.mask_token + ]) + current_mask.append(0) + # print(self.tokenizer.DecodeIds(text)) + # print(mask_positions[-1]) + counter = 0 + done = [False] * batch_size + while counter < args.tgt_seq_length: + if counter == 0: + # print(tokens) + # print(position_ids) + next_token_logits, *mems = model( + tokens, + position_ids, + attention_mask, + return_memory=True) + next_token_logits = next_token_logits[:, -1] + position_ids = tokens.new_ones(batch_size, 2, 1) + for i, text in enumerate(tokens.tolist()): + mask_pos = mask_positions[i][current_mask[i]] + position_ids[i, 0] = mask_pos + tokens = tokens.new_zeros(batch_size, 0) + attention_mask = tokens.new_zeros(batch_size) + else: + position_ids[:, 1] = position_ids[:, 1] + 1 + last_token = tokens[:, -1:] + next_token_logits, *mems = model( + last_token, + position_ids, + attention_mask, + *mems, + return_memory=True) + next_token_logits = next_token_logits[:, -1] + next_token_scores = F.log_softmax( + next_token_logits, dim=-1) + next_token_scores = self.processors( + tokens, next_token_scores) + next_tokens = next_token_scores.max(dim=-1)[1] + # print(self.tokenizer.DecodeIds(next_tokens.tolist())) + for i, next_token in enumerate(next_tokens.tolist()): + if next_token == self.end_token: + if current_mask[i] + 1 < len(mask_positions[i]): + current_mask[i] += 1 + next_tokens[i] = self.start_token + position_ids[i, 0] = mask_positions[i][ + current_mask[i]] + position_ids[i, 1] = 0 + else: + done[i] = True + if done[i]: + next_tokens[i] = self.pad_token + if all(done): + break + tokens = torch.cat( + [tokens, next_tokens.unsqueeze(-1)], dim=-1) + counter += 1 + predictions = [] + for i, text in enumerate(tokens.tolist()): + text = [ + token for token in text + if token not in [self.end_token, self.pad_token] + ] + blanks = [[]] + for token in text: + if token == self.start_token: + blanks.append([]) + else: + blanks[-1].append(token) + output_tokens = [] + current_blank = 0 + for token in src_tokens[i].tolist(): + if token == self.mask_token: + if current_blank < len(blanks): + output_tokens += blanks[current_blank] + current_blank += 1 + else: + if token not in [self.pad_token]: + output_tokens.append(token) + text = self.tokenizer.DecodeIds(output_tokens[:-1]) + text = blanklm_fix_tokenization(text) + predictions.append(text) + # print(text) + uid_list = data['uid'] + if isinstance(uid_list, torch.Tensor): + uid_list = uid_list.cpu().numpy().tolist() + for uid, prediction in zip(uid_list, predictions): + store.set(uid, prediction) + if (idx + 1) % args.log_interval == 0: + print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') + + model.train() + torch.distributed.barrier() + print_rank_0('Evaluation completed') + predictions, examples = [], [] + for uid, example in example_dict.items(): + predictions.append(store.get(uid).decode('utf-8')) + examples.append(example) + torch.distributed.barrier() + return predictions, [], examples diff --git a/modelscope/models/nlp/mglm/tasks/seq2seq/finetune.py b/modelscope/models/nlp/mglm/tasks/seq2seq/finetune.py new file mode 100644 index 00000000..4c0c28e7 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/seq2seq/finetune.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Race.""" +import functools +from collections import OrderedDict + +import mpu +import torch +from finetune_glm import finetune +from pretrain_glm import get_batch +from tasks.eval_utils import accuracy_func_provider +from tasks.seq2seq.dataset import (BlankLMDataset, ExtractionDataset, + Seq2SeqDataset) +from tasks.seq2seq.evaluate import (BlankLMEvaluater, DecoderEvaluater, + rouge_metric) + +global_tokenizer = None + + +def seq2seq_forward_step(data, model, args, timers, mems): + """Forward step.""" + + # Get the batch. + if timers is not None: + timers('batch generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data, args) + if timers is not None: + timers('batch generator').stop() + # Forward model. + logits, *mems = model(tokens, position_ids, attention_mask, *mems) + # logits, loss_mask = logits[:, args.src_seq_length:], loss_mask[:, args.src_seq_length:] + # target_ids = target_ids[:, args.src_seq_length:] + losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), + labels) + if args.label_smoothing > 0.0: + epsilon = args.label_smoothing + smooth_loss = -torch.nn.functional.log_softmax( + logits, dim=-1).mean(dim=-1) + losses = (1 - epsilon) * losses + epsilon * smooth_loss + loss_mask = loss_mask.reshape(-1) + # The loss is not normalized for fair comparison + loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum() + return loss, mems, 'bert' + + +def train_valid_datasets_provider(args, tokenizer): + """Provide train and validation datasets.""" + if args.task.lower() == 'blank': + train_dataset = BlankLMDataset( + args, split='train', tokenizer=tokenizer) + valid_dataset = None + elif args.task.lower() == 'extraction': + train_dataset = ExtractionDataset( + args, split='train', tokenizer=tokenizer) + valid_dataset = None + else: + train_dataset = Seq2SeqDataset( + args, split='train', tokenizer=tokenizer) + valid_dataset = None + global global_tokenizer + global_tokenizer = tokenizer + return train_dataset, valid_dataset + + +def metrics_func_provider(args, tokenizer, is_test): + """Provide metrics callback function.""" + + def single_dataset_provider(split): + if args.task.lower() == 'blank': + return BlankLMDataset(args, split=split, tokenizer=tokenizer) + elif args.task.lower() == 'extraction': + return ExtractionDataset(args, split=split, tokenizer=tokenizer) + else: + return Seq2SeqDataset(args, split=split, tokenizer=tokenizer) + + if args.task.lower() in ['blank', 'extraction']: + evaluater = BlankLMEvaluater(args, tokenizer) + eval_func = evaluater.evaluate + metric_dict = {} + else: + evaluater = DecoderEvaluater(args, tokenizer) + eval_func = evaluater.evaluate + if args.tokenizer_type == 'BertWordPieceTokenizer': + dataset = 'cnn_dm' + elif args.task.lower() == 'gigaword': + dataset = 'gigaword' + else: + dataset = 'cnn_dm_org' + metric_dict = OrderedDict({ + 'rouge-1': + functools.partial(rouge_metric, metric='rouge-1', dataset=dataset), + 'rouge-2': + functools.partial(rouge_metric, metric='rouge-2', dataset=dataset), + 'rouge-l': + functools.partial(rouge_metric, metric='rouge-l', dataset=dataset) + }) + + def output_func(predictions, examples, output_file): + with open(output_file + '.hyps', 'w', encoding='utf-8') as output: + for prediction in predictions: + output.write(prediction) + output.write('\n') + with open(output_file + '.refs', 'w', encoding='utf-8') as output: + for example in examples: + output.write(example.meta['ref']) + output.write('\n') + if args.task.lower() == 'squad_generation': + with open( + output_file + '.source', 'w', encoding='utf-8') as output: + for example in examples: + output.write( + example.text_a.replace('\n', ' ') + ' Answer: ' + + example.meta['answer']) + output.write('\n') + + return accuracy_func_provider( + single_dataset_provider, + metric_dict, + args, + is_test=is_test, + eval_func=eval_func, + output_func=output_func, + only_rank0=False) + + +def main(args): + if args.src_seq_length > args.max_position_embeddings: + args.max_position_embeddings = args.src_seq_length + if args.task.lower() in [ + 'cnn_dm', 'cnn_dm_original', 'gigaword', 'blank', + 'squad_generation', 'xsum', 'extraction' + ]: + finetune( + args, + train_valid_datasets_provider, {}, + end_of_epoch_callback_provider=metrics_func_provider, + forward_step=seq2seq_forward_step) + else: + raise NotImplementedError(args.task) diff --git a/modelscope/models/nlp/mglm/tasks/superglue/README.md b/modelscope/models/nlp/mglm/tasks/superglue/README.md new file mode 100644 index 00000000..94aab0e9 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/superglue/README.md @@ -0,0 +1,137 @@ +# Use GLM for your NLU tasks +To use GLM for your own NLU tasks, you should implement a subclass of `DataProcessor` in [tasks/superglue/dataset.py](dataset.py) and a subclass of `PVP` in [tasks/superglue/pvp.py](pvp.py). You should also specify the We will take the RTE and ReCoRD tasks in SuperGLUE as an example. + +## 1. Design your patterns +RTE is an NLI task in which the model is required to predict text entailment between a premise and a hypothesis. The label can be `entailment` or `not_entailment` One sample from the training set is +``` +premise: No Weapons of Mass Destruction Found in Iraq Yet. +hypothesis: Weapons of Mass Destruction Found in Iraq. +label: not_entailment +``` +We design the pattern as +``` +"`hypothesis`"?, [MASK], "`premise`" +``` +GLM predicts "Yes" for `entailment` and "No" for `not_entailment`. "Yes" and "No" are called verbalizers for `entailment` and `not_entailment`. + +ReCoRD is a multi-choice QA task. Each example consists of a news article and a Cloze-style question about the article in which one entity is masked out. The system must predict the masked out entity from a list of possible entities in the provided passage. We directly adopt the cloze-style question as our pattern and use GLM to predict the masked entity. + +## 2. Implement subclass of `DataProcessor` +A subclass of `DataProcessor` should implement `get_train_examples`, `get_dev_examples` and `get_test_examples`, which return the examples of the train, dev, and test sets. The returned value is a list of `InputExample`. It should also implement `get_labels` to return the list of possible labels. Hete we take the `RTEProcessor` as an example: +```python +class RteProcessor(DataProcessor): + """Processor for the RTE data set.""" + + def get_train_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") + + def get_test_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") + + def get_unlabeled_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") + + def get_labels(self): + return ["entailment", "not_entailment"] + + def _create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis", + premise_name: str = "premise") -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line_idx, line in enumerate(f): + example_json = json.loads(line) + idx = example_json['idx'] + if isinstance(idx, str): + try: + idx = int(idx) + except ValueError: + idx = line_idx + label = example_json.get('label') + guid = "%s-%s" % (set_type, idx) + text_a = example_json[premise_name] + text_b = example_json[hypothesis_name] + + example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx) + examples.append(example) + + return examples +``` +After that, you should add the implemented class to ``PROCESSORS`` at the end of [tasks/superglue/dataset.py](dataset.py): +```python +PROCESSORS = { + ... + "rte": RteProcessor +} +``` + +## 3. Implement subclass of `PVP` +To implement a subclass of `PVP`, you should first decide your verbalizers is single-token or multi-token. The verbalizers in RTE, "Yes" and "No" are single-token. Instead, the verbalizers in ReCoRD are multi-token, as one entity can be tokenized into multiple tokens with WordPiece or BPE tokenizer. + +For single-token task, you should set `is_multi_token=False` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `verbalize` to return the verbalizer given a label. Take `RTEPVP` as an example: +```python +class RtePVP(PVP): + is_multi_token = False + VERBALIZER = { + "not_entailment": [" No"], + "entailment": [" Yes"] + } + + @property + def spell_length(self): + return self.pattern_id + + def get_parts(self, example: InputExample) -> FilledPattern: + # switch text_a and text_b to get the correct order + text_a = example.text_a + text_b = example.text_b.rstrip(string.punctuation) + return ['"', self.shortenable(text_b), '" ?'], [[self.mask], ', "', self.shortenable(text_a), '"'] + + def verbalize(self, label) -> List[str]: + return RtePVP.VERBALIZER[label] +``` +We use `PvP.shortenable` to mark the segments that can be truncated when exceeding the maximum sequence length. + +For multi-token task, you should set `is_multi_token=True` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `get_answers` to return the candidates. Take `ReCoRDPVP` as an example: +```python +class RecordPVP(PVP): + is_multi_token = True + + def get_answers(self, example: InputExample): + choices = example.meta['candidates'] + choices = [" " + choice for choice in choices] + return choices + + def get_parts(self, example: InputExample) -> FilledPattern: + premise = self.shortenable(example.text_a) + + assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token' + question_a, question_b = example.text_b.split('@placeholder') + return [premise, " " + question_a.rstrip(), [self.mask], question_b], [] +``` +After that, you should implement the class to `PVPS` at the end of [tasks/superglue/pvp.py](pvp.py): +```python +PVPS = { + ... + 'rte': RtePVP, + 'record': RecordPVP +} +``` +## 4. Run the experiment +To run the experiment for your new task, you should create a config file like [config_tasks/task_rte.sh](/config_tasks/task_rte.sh). You should also specify the evaluation metrics for the task in `DEFAULT_METRICS` of [tasks/superglue/finetune.py](finetune.py): +```python +DEFAULT_METRICS = { + ... + "record": [("EM", qa_exact_match), ("F1", qa_f1)], + "rte": [("accuracy", accuracy_metric)] +} +``` +Then you can run the experiment with [finetune_superglue.sh](/scripts/finetune_superglue.sh): +```shell +bash scripts/finetune_superglue.sh \ + config_tasks/model_blocklm_large.sh \ + config_tasks/task_rte.sh +``` diff --git a/modelscope/models/nlp/mglm/tasks/superglue/__init__.py b/modelscope/models/nlp/mglm/tasks/superglue/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/mglm/tasks/superglue/dataset.py b/modelscope/models/nlp/mglm/tasks/superglue/dataset.py new file mode 100644 index 00000000..36367671 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/superglue/dataset.py @@ -0,0 +1,1475 @@ +# Copyright (c) 2022 Zhipu.AI +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the logic for loading training and test data for all tasks. +""" + +import copy +import csv +import glob +import os +import random +import re +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from typing import Callable, Dict, List + +import json +import numpy as np +import pandas as pd +from data_utils import (build_input_from_ids, build_sample, + num_special_tokens_to_add) +from data_utils.corpora import punctuation_standardization +from torch.utils.data import Dataset +from tqdm import tqdm +from utils import print_rank_0 + +from modelscope.models.nlp.mglm.tasks.data_utils import InputExample +from modelscope.models.nlp.mglm.tasks.superglue.pvp import PVPS + +TRAIN_SET = 'train' +DEV_SET = 'dev' +TEST_SET = 'test' +TRUE_DEV_SET = 'true_dev' +UNLABELED_SET = 'unlabeled' + +SPLIT_TYPES = [TRAIN_SET, DEV_SET, TEST_SET, TRUE_DEV_SET, UNLABELED_SET] + + +def get_output_func(task_name, args): + return PROCESSORS[task_name](args).output_prediction + + +def read_tsv(path, **kwargs): + return pd.read_csv( + path, + sep='\t', + quoting=csv.QUOTE_NONE, + dtype=str, + na_filter=False, + **kwargs) + + +class SuperGlueDataset(Dataset): + + def __init__(self, + args, + task_name, + data_dir, + seq_length, + split, + tokenizer, + for_train=False, + pattern_ensemble=False, + pattern_text=False): + self.processor = PROCESSORS[task_name](args) + args.variable_num_choices = self.processor.variable_num_choices + print_rank_0( + f'Creating {task_name} dataset from file at {data_dir} (split={split})' + ) + self.dataset_name = f'{task_name}-{split}' + self.cloze_eval = args.cloze_eval + self.seq_length = seq_length + self.tokenizer = tokenizer + self.pattern_ensemble = pattern_ensemble + self.pattern_text = pattern_text + if pattern_text: + assert self.cloze_eval, 'Labeled examples only exist in cloze evaluation' + self.args = args + if split == DEV_SET: + example_list = self.processor.get_dev_examples( + data_dir, for_train=for_train) + elif split == TEST_SET: + example_list = self.processor.get_test_examples(data_dir) + elif split == TRUE_DEV_SET: + example_list = self.processor.get_true_dev_examples(data_dir) + elif split == TRAIN_SET: + if task_name == 'wsc': + example_list = self.processor.get_train_examples( + data_dir, cloze_eval=args.cloze_eval) + else: + example_list = self.processor.get_train_examples(data_dir) + elif split == UNLABELED_SET: + example_list = self.processor.get_unlabeled_examples(data_dir) + for example in example_list: + example.label = self.processor.get_labels()[0] + else: + raise ValueError( + f"'split' must be one of {SPLIT_TYPES}, got '{split}' instead") + if split == TEST_SET: + self.labeled = False + else: + self.labeled = True + + label_distribution = Counter(example.label for example in example_list) + print_rank_0( + f'Returning {len(example_list)} {split} examples with label dist.: {list(label_distribution.items())}' + ) + self.samples = [] + example_list.sort(key=lambda x: x.num_choices) + self.example_list = example_list + if self.cloze_eval: + if self.pattern_ensemble: + pattern_ids = PVPS[task_name].available_patterns() + self.pvps = [] + for pattern_id in pattern_ids: + self.pvps.append(PVPS[task_name]( + args, + tokenizer, + self.processor.get_labels(), + seq_length, + pattern_id=pattern_id, + num_prompt_tokens=args.num_prompt_tokens, + is_multi_token=args.multi_token, + max_segment_length=args.segment_length, + fast_decode=args.fast_decode, + split=split)) + else: + self.pvp = PVPS[task_name]( + args, + tokenizer, + self.processor.get_labels(), + seq_length, + pattern_id=args.pattern_id, + num_prompt_tokens=args.num_prompt_tokens, + is_multi_token=args.multi_token, + max_segment_length=args.segment_length, + fast_decode=args.fast_decode, + split=split) + self.examples = {example.guid: example for example in example_list} + + def __len__(self): + if self.cloze_eval and self.pattern_ensemble: + return len(self.example_list) * len(self.pvps) + else: + return len(self.example_list) + + def __getitem__(self, idx): + sample_idx = idx % len(self.example_list) + example = self.example_list[sample_idx] + if self.cloze_eval: + kwargs = {} + if self.pattern_text: + kwargs = {'labeled': True, 'priming': True} + if self.pattern_ensemble: + pvp_idx = idx // len(self.example_list) + sample = self.pvps[pvp_idx].encode(example, **kwargs) + else: + sample = self.pvp.encode(example, **kwargs) + if self.pattern_text: + eos_id = self.tokenizer.get_command('eos').Id + cls_id = self.tokenizer.get_command('ENC').Id + input_ids = [cls_id] + sample + [eos_id] + sample = { + 'text': input_ids, + 'loss_mask': np.array([1] * len(input_ids)) + } + else: + sample = self.processor.encode(example, self.tokenizer, + self.seq_length, self.args) + return sample + + +class DataProcessor(ABC): + """ + Abstract class that provides methods for loading training, testing, development and unlabeled examples for a given + task + """ + + def __init__(self, args): + self.args = args + self.num_truncated = 0 + + def output_prediction(self, predictions, examples, output_file): + with open(output_file, 'w') as output: + for prediction, example in zip(predictions, examples): + prediction = self.get_labels()[prediction] + data = {'idx': example.idx, 'label': prediction} + output.write(json.dumps(data) + '\n') + + @property + def variable_num_choices(self): + return False + + @abstractmethod + def get_train_examples(self, data_dir) -> List[InputExample]: + """Get a collection of `InputExample`s for the train set.""" + pass + + @abstractmethod + def get_dev_examples(self, + data_dir, + for_train=False) -> List[InputExample]: + """Get a collection of `InputExample`s for the dev set.""" + pass + + def get_test_examples(self, data_dir) -> List[InputExample]: + """Get a collection of `InputExample`s for the test set.""" + return [] + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + """Get a collection of `InputExample`s for the unlabeled set.""" + return [] + + @abstractmethod + def get_labels(self) -> List[str]: + """Get the list of labels for this data set.""" + pass + + def get_classifier_input(self, example: InputExample, tokenizer): + return example.text_a, example.text_b + + def encode(self, example: InputExample, tokenizer, seq_length, args): + text_a, text_b = self.get_classifier_input(example, tokenizer) + tokens_a = tokenizer.EncodeAsIds(text_a).tokenization + tokens_b = tokenizer.EncodeAsIds(text_b).tokenization + num_special_tokens = num_special_tokens_to_add( + tokens_a, + tokens_b, + None, + add_cls=True, + add_sep=True, + add_piece=False) + if len(tokens_a) + len(tokens_b) + num_special_tokens > seq_length: + self.num_truncated += 1 + data = build_input_from_ids( + tokens_a, + tokens_b, + None, + seq_length, + tokenizer, + args=args, + add_cls=True, + add_sep=True, + add_piece=False) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + label = 0 + if example.label is not None: + label = example.label + label = self.get_labels().index(label) + if args.pretrained_bert: + sample = build_sample( + ids, + label=label, + types=types, + paddings=paddings, + unique_id=example.guid) + else: + sample = build_sample( + ids, + positions=position_ids, + masks=sep, + label=label, + unique_id=example.guid) + return sample + + +class SuperGLUEProcessor(DataProcessor): + + def __init__(self, args): + super(SuperGLUEProcessor, self).__init__(args) + self.few_superglue = args.few_superglue + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.jsonl'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + if self.few_superglue: + return self._create_examples( + os.path.join(data_dir, 'dev32.jsonl'), 'dev') + else: + return self._create_examples( + os.path.join(data_dir, 'val.jsonl'), 'dev') + + def get_test_examples(self, data_dir): + if self.few_superglue: + return self._create_examples( + os.path.join(data_dir, 'val.jsonl'), 'test') + else: + return self._create_examples( + os.path.join(data_dir, 'test.jsonl'), 'test') + + def get_unlabeled_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'unlabeled.jsonl'), 'unlabeled') + + def _create_examples(self, *args, **kwargs): + pass + + +class RteProcessor(SuperGLUEProcessor): + """Processor for the RTE data set.""" + + def get_labels(self): + return ['entailment', 'not_entailment'] + + def _create_examples(self, + path: str, + set_type: str, + hypothesis_name: str = 'hypothesis', + premise_name: str = 'premise') -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line_idx, line in enumerate(f): + example_json = json.loads(line) + idx = example_json['idx'] + if isinstance(idx, str): + try: + idx = int(idx) + except ValueError: + idx = line_idx + label = example_json.get('label') + guid = '%s-%s' % (set_type, idx) + text_a = punctuation_standardization( + example_json[premise_name]) + text_b = punctuation_standardization( + example_json[hypothesis_name]) + + example = InputExample( + guid=guid, + text_a=text_a, + text_b=text_b, + label=label, + idx=idx) + examples.append(example) + + return examples + + +class AxGProcessor(RteProcessor): + """Processor for the AX-G diagnostic data set.""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'AX-g.jsonl'), 'train') + + def get_test_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'AX-g.jsonl'), 'test') + + +class AxBProcessor(RteProcessor): + """Processor for the AX-B diagnostic data set.""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'AX-b.jsonl'), 'train') + + def get_test_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'AX-b.jsonl'), 'test') + + def _create_examples(self, + path, + set_type, + hypothesis_name='sentence2', + premise_name='sentence1'): + return super()._create_examples(path, set_type, hypothesis_name, + premise_name) + + +class CbProcessor(RteProcessor): + """Processor for the CB data set.""" + + def get_labels(self): + return ['entailment', 'contradiction', 'neutral'] + + +class WicProcessor(SuperGLUEProcessor): + """Processor for the WiC data set.""" + + def get_labels(self): + return ['false', 'true'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + idx = example_json['idx'] + if isinstance(idx, str): + idx = int(idx) + label = 'true' if example_json.get('label') else 'false' + guid = '%s-%s' % (set_type, idx) + text_a = punctuation_standardization(example_json['sentence1']) + text_b = punctuation_standardization(example_json['sentence2']) + meta = {'word': example_json['word']} + example = InputExample( + guid=guid, + text_a=text_a, + text_b=text_b, + label=label, + idx=idx, + meta=meta) + examples.append(example) + return examples + + def get_classifier_input(self, example: InputExample, tokenizer): + text_a = example.meta['word'] + ': ' + example.text_a + return text_a, example.text_b + + +class WscProcessor(SuperGLUEProcessor): + """Processor for the WSC data set.""" + + @property + def variable_num_choices(self): + return self.args.wsc_negative + + def get_train_examples(self, data_dir, cloze_eval=True): + return self._create_examples( + os.path.join(data_dir, 'train.jsonl'), + 'train', + cloze_eval=cloze_eval) + + def get_labels(self): + return ['False', 'True'] + + def get_classifier_input(self, example: InputExample, tokenizer): + target = example.meta['span1_text'] + pronoun_idx = example.meta['span2_index'] + + # mark the pronoun with asterisks + words_a = example.text_a.split() + words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*' + text_a = ' '.join(words_a) + text_b = target + return text_a, text_b + + def _create_examples(self, + path: str, + set_type: str, + cloze_eval=True) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + idx = example_json['idx'] + label = str( + example_json['label']) if 'label' in example_json else None + guid = '%s-%s' % (set_type, idx) + text_a = punctuation_standardization(example_json['text']) + meta = { + 'span1_text': example_json['target']['span1_text'], + 'span2_text': example_json['target']['span2_text'], + 'span1_index': example_json['target']['span1_index'], + 'span2_index': example_json['target']['span2_index'] + } + if 'candidates' in example_json: + candidates = [ + cand['text'] for cand in example_json['candidates'] + ] + # candidates = list(set(candidates)) + filtered = [] + for i, cand in enumerate(candidates): + if cand not in candidates[:i]: + filtered.append(cand) + candidates = filtered + + # the indices in the dataset are wrong for some examples, so we manually fix them + span1_index, span1_text = meta['span1_index'], meta[ + 'span1_text'] + span2_index, span2_text = meta['span2_index'], meta[ + 'span2_text'] + words_a = text_a.split() + words_a_lower = text_a.lower().split() + words_span1_text = span1_text.lower().split() + span1_len = len(words_span1_text) + + if words_a_lower[span1_index:span1_index + + span1_len] != words_span1_text: + for offset in [-1, +1]: + if words_a_lower[span1_index + offset:span1_index + + span1_len + + offset] == words_span1_text: + span1_index += offset + + # if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text: + # print_rank_0(f"Got '{words_a_lower[span1_index:span1_index + span1_len]}' but expected " + # f"'{words_span1_text}' at index {span1_index} for '{words_a}'") + + if words_a[span2_index] != span2_text: + for offset in [-1, +1]: + if words_a[span2_index + offset] == span2_text: + span2_index += offset + + if words_a[span2_index] != span2_text and words_a[ + span2_index].startswith(span2_text): + words_a = words_a[:span2_index] \ + + [words_a[span2_index][:len(span2_text)], words_a[span2_index][len(span2_text):]] + words_a[span2_index + 1:] # noqa + + assert words_a[span2_index] == span2_text, \ + f"Got '{words_a[span2_index]}' but expected '{span2_text}' at index {span2_index} for '{words_a}'" + + text_a = ' '.join(words_a) + meta['span1_index'], meta[ + 'span2_index'] = span1_index, span2_index + + if self.args.task == 'wsc1': + example = InputExample( + guid=guid, + text_a=text_a, + text_b=span1_text, + label=label, + meta=meta, + idx=idx) + examples.append(example) + if set_type == 'train' and label == 'True': + for cand in candidates: + example = InputExample( + guid=guid, + text_a=text_a, + text_b=cand, + label='False', + meta=meta, + idx=idx) + examples.append(example) + continue + + if cloze_eval and set_type == 'train' and label != 'True': + continue + if set_type == 'train' and 'candidates' in example_json and len( + candidates) > 9: + for i in range(0, len(candidates), 9): + _meta = copy.deepcopy(meta) + _meta['candidates'] = candidates[i:i + 9] + if len(_meta['candidates']) < 9: + _meta['candidates'] += candidates[:9 - len( + _meta['candidates'])] + example = InputExample( + guid=guid, + text_a=text_a, + label=label, + meta=_meta, + idx=idx) + examples.append(example) + else: + if 'candidates' in example_json: + meta['candidates'] = candidates + example = InputExample( + guid=guid, + text_a=text_a, + label=label, + meta=meta, + idx=idx) + examples.append(example) + + return examples + + +class BoolQProcessor(SuperGLUEProcessor): + """Processor for the BoolQ data set.""" + + def get_labels(self): + return ['false', 'true'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + idx = example_json['idx'] + label = str(example_json['label']).lower( + ) if 'label' in example_json else None + guid = '%s-%s' % (set_type, idx) + text_a = punctuation_standardization(example_json['passage']) + text_b = punctuation_standardization(example_json['question']) + example = InputExample( + guid=guid, + text_a=text_a, + text_b=text_b, + label=label, + idx=idx) + examples.append(example) + + return examples + + +class CopaProcessor(SuperGLUEProcessor): + """Processor for the COPA data set.""" + + def get_labels(self): + return [0, 1] + + def encode(self, example: InputExample, tokenizer, seq_length, args): + if args.pretrained_bert: + ids_list, types_list, paddings_list = [], [], [] + else: + ids_list, positions_list, sep_list = [], [], [] + question = example.meta['question'] + joiner = 'because' if question == 'cause' else 'so' + text_a = punctuation_standardization(example.text_a) + ' ' + joiner + tokens_a = tokenizer.EncodeAsIds(text_a).tokenization + for choice in [example.meta['choice1'], example.meta['choice2']]: + choice = punctuation_standardization(choice) + tokens_b = tokenizer.EncodeAsIds(choice).tokenization + num_special_tokens = num_special_tokens_to_add( + tokens_a, + tokens_b, + None, + add_cls=True, + add_sep=True, + add_piece=False) + if len(tokens_a) + len(tokens_b) + num_special_tokens > seq_length: + self.num_truncated += 1 + data = build_input_from_ids( + tokens_a, + tokens_b, + None, + seq_length, + tokenizer, + args, + add_cls=True, + add_sep=True, + add_piece=False) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + if args.pretrained_bert: + ids_list.append(ids) + types_list.append(types) + paddings_list.append(paddings) + else: + ids_list.append(ids) + positions_list.append(position_ids) + sep_list.append(sep) + label = 0 + if example.label is not None: + label = example.label + label = self.get_labels().index(label) + if args.pretrained_bert: + sample = build_sample( + ids_list, + label=label, + types=types_list, + paddings=paddings_list, + unique_id=example.guid) + else: + sample = build_sample( + ids_list, + positions=positions_list, + masks=sep_list, + label=label, + unique_id=example.guid) + return sample + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + label = example_json[ + 'label'] if 'label' in example_json else None + idx = example_json['idx'] + guid = '%s-%s' % (set_type, idx) + text_a = example_json['premise'] + meta = { + 'choice1': example_json['choice1'], + 'choice2': example_json['choice2'], + 'question': example_json['question'] + } + example = InputExample( + guid=guid, text_a=text_a, label=label, meta=meta, idx=idx) + examples.append(example) + + if set_type == 'train' or set_type == 'unlabeled': + mirror_examples = [] + for ex in examples: + label = 1 if ex.label == 0 else 0 + meta = { + 'choice1': ex.meta['choice2'], + 'choice2': ex.meta['choice1'], + 'question': ex.meta['question'] + } + mirror_example = InputExample( + guid=ex.guid + 'm', + text_a=ex.text_a, + label=label, + meta=meta) + mirror_examples.append(mirror_example) + examples += mirror_examples + print_rank_0( + f'Added {len(mirror_examples)} mirror examples, total size is {len(examples)}...' + ) + return examples + + +class MultiRcProcessor(SuperGLUEProcessor): + """Processor for the MultiRC data set.""" + + def get_labels(self): + return [0, 1] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + + passage_idx = example_json['idx'] + text = punctuation_standardization( + example_json['passage']['text']) + questions = example_json['passage']['questions'] + for question_json in questions: + question = punctuation_standardization( + question_json['question']) + question_idx = question_json['idx'] + answers = question_json['answers'] + for answer_json in answers: + label = answer_json[ + 'label'] if 'label' in answer_json else None + answer_idx = answer_json['idx'] + guid = f'{set_type}-p{passage_idx}-q{question_idx}-a{answer_idx}' + meta = { + 'passage_idx': + passage_idx, + 'question_idx': + question_idx, + 'answer_idx': + answer_idx, + 'answer': + punctuation_standardization(answer_json['text']) + } + idx = [passage_idx, question_idx, answer_idx] + example = InputExample( + guid=guid, + text_a=text, + text_b=question, + label=label, + meta=meta, + idx=idx) + examples.append(example) + + question_indices = list( + set(example.meta['question_idx'] for example in examples)) + label_distribution = Counter(example.label for example in examples) + print_rank_0( + f'Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label ' + f'distribution {list(label_distribution.items())}') + return examples + + def output_prediction(self, predictions, examples, output_file): + with open(output_file, 'w') as output: + passage_dict = defaultdict(list) + for prediction, example in zip(predictions, examples): + passage_dict[example.meta['passage_idx']].append( + (prediction, example)) + for passage_idx, data in passage_dict.items(): + question_dict = defaultdict(list) + passage_data = { + 'idx': passage_idx, + 'passage': { + 'questions': [] + } + } + for prediction, example in data: + question_dict[example.meta['question_idx']].append( + (prediction, example)) + for question_idx, data in question_dict.items(): + question_data = {'idx': question_idx, 'answers': []} + for prediction, example in data: + prediction = self.get_labels()[prediction] + question_data['answers'].append({ + 'idx': + example.meta['answer_idx'], + 'label': + prediction + }) + passage_data['passage']['questions'].append(question_data) + output.write(json.dumps(passage_data) + '\n') + + def get_classifier_input(self, example: InputExample, tokenizer): + text_a = example.text_a + text_b = ' '.join([example.text_b, 'answer:', example.meta['answer']]) + return text_a, text_b + + +class RaceProcessor(DataProcessor): + + @property + def variable_num_choices(self): + return True + + def get_labels(self): + return ['A', 'B', 'C', 'D'] + + def get_train_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, 'train'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples( + os.path.join(data_dir, 'dev'), 'dev', for_train=for_train) + + def get_test_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, 'test'), 'test') + + @staticmethod + def _create_examples(path, + set_type, + for_train=False) -> List[InputExample]: + examples = [] + + def clean_text(text): + """Remove new lines and multiple spaces and adjust end of sentence dot.""" + + text = text.replace('\n', ' ') + text = re.sub(r'\s+', ' ', text) + for _ in range(3): + text = text.replace(' . ', '. ') + + return text + + filenames = glob.glob(os.path.join( + path, 'middle', '*.txt')) + glob.glob( + os.path.join(path, 'high', '*.txt')) + for filename in filenames: + with open(filename, 'r') as f: + for line in f: + data = json.loads(line) + idx = data['id'] + context = data['article'] + questions = data['questions'] + choices = data['options'] + answers = data['answers'] + # Check the length. + assert len(questions) == len(answers) + assert len(questions) == len(choices) + + context = clean_text(context) + for question_idx, question in enumerate(questions): + answer = answers[question_idx] + choice = choices[question_idx] + guid = f'{set_type}-p{idx}-q{question_idx}' + ex_idx = [set_type, idx, question_idx] + meta = {'choices': choice} + example = InputExample( + guid=guid, + text_a=context, + text_b=question, + label=answer, + meta=meta, + idx=ex_idx) + examples.append(example) + return examples + + +class RecordProcessor(SuperGLUEProcessor): + """Processor for the ReCoRD data set.""" + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples( + os.path.join(data_dir, 'val.jsonl'), 'dev', for_train=for_train) + + @property + def variable_num_choices(self): + return True + + def get_labels(self): + return ['0', '1'] + + def output_prediction(self, predictions, examples, output_file): + with open(output_file, 'w') as output: + for prediction, example in zip(predictions, examples): + prediction = example.meta['candidates'][prediction] + data = {'idx': example.idx, 'label': prediction} + output.write(json.dumps(data) + '\n') + + def encode(self, example: InputExample, tokenizer, seq_length, args): + if args.pretrained_bert: + ids_list, types_list, paddings_list = [], [], [] + else: + ids_list, positions_list, sep_list = [], [], [] + tokens_a = tokenizer.EncodeAsIds(example.text_a).tokenization + tokens_b = tokenizer.EncodeAsIds( + example.text_b).tokenization if example.text_b else None + for answer in example.meta['candidates']: + answer_ids = tokenizer.EncodeAsIds(answer).tokenization + total_length = len(tokens_a) + len(tokens_b) + len(answer_ids) + total_length += num_special_tokens_to_add( + tokens_a, + tokens_b + answer_ids, + None, + add_cls=True, + add_sep=True, + add_piece=False) + if total_length > seq_length: + self.num_truncated += 1 + data = build_input_from_ids( + tokens_a, + tokens_b + answer_ids, + None, + seq_length, + tokenizer, + args, + add_cls=True, + add_sep=True, + add_piece=False) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + if args.pretrained_bert: + ids_list.append(ids) + types_list.append(types) + paddings_list.append(paddings) + else: + ids_list.append(ids) + positions_list.append(position_ids) + sep_list.append(sep) + label = example.label + label = self.get_labels().index(label) + if args.pretrained_bert: + sample = build_sample( + ids_list, + label=label, + types=types_list, + paddings=paddings_list, + unique_id=example.guid) + else: + sample = build_sample( + ids_list, + positions=positions_list, + masks=sep_list, + label=label, + unique_id=example.guid) + return sample + + @staticmethod + def _create_examples(path, + set_type, + seed=42, + max_train_candidates_per_question: int = 10, + for_train=False) -> List[InputExample]: + examples = [] + + entity_shuffler = random.Random(seed) + + with open(path, encoding='utf8') as f: + for idx, line in enumerate(f): + example_json = json.loads(line) + + idx = example_json['idx'] + text = punctuation_standardization( + example_json['passage']['text']) + entities = set() + + for entity_json in example_json['passage']['entities']: + start = entity_json['start'] + end = entity_json['end'] + entity = punctuation_standardization(text[start:end + 1]) + entities.add(entity) + + entities = list(entities) + entities.sort() + + text = text.replace( + '@highlight\n', '- ' + ) # we follow the GPT-3 paper wrt @highlight annotations + questions = example_json['qas'] + + for question_json in questions: + question = punctuation_standardization( + question_json['query']) + question_idx = question_json['idx'] + answers = set() + + for answer_json in question_json.get('answers', []): + answer = punctuation_standardization( + answer_json['text']) + answers.add(answer) + + answers = list(answers) + + if set_type == 'train' or for_train: + # create a single example per *correct* answer + for answer_idx, answer in enumerate(answers): + candidates = [ + ent for ent in entities if ent not in answers + ] + if len(candidates + ) > max_train_candidates_per_question - 1: + entity_shuffler.shuffle(candidates) + candidates = candidates[: + max_train_candidates_per_question + - 1] + + guid = f'{set_type}-p{idx}-q{question_idx}-a{answer_idx}' + meta = { + 'passage_idx': idx, + 'question_idx': question_idx, + 'candidates': [answer] + candidates, + 'answers': [answer] + } + ex_idx = [idx, question_idx, answer_idx] + example = InputExample( + guid=guid, + text_a=text, + text_b=question, + label='0', + meta=meta, + idx=ex_idx, + num_choices=len(candidates) + 1) + examples.append(example) + + else: + # create just one example with *all* correct answers and *all* answer candidates + guid = f'{set_type}-p{idx}-q{question_idx}' + meta = { + 'passage_idx': idx, + 'question_idx': question_idx, + 'candidates': entities, + 'answers': answers + } + example = InputExample( + guid=guid, + text_a=text, + text_b=question, + label='1', + meta=meta, + idx=question_idx, + num_choices=len(entities)) + examples.append(example) + + question_indices = list( + set(example.meta['question_idx'] for example in examples)) + label_distribution = Counter(example.label for example in examples) + print_rank_0( + f'Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label ' + f'distribution {list(label_distribution.items())}') + return examples + + +class MnliProcessor(DataProcessor): + """Processor for the MultiNLI data set (GLUE version).""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.tsv'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples( + os.path.join(data_dir, 'dev_matched.tsv'), 'dev_matched') + + def get_test_examples(self, data_dir) -> List[InputExample]: + return self._create_examples( + os.path.join(data_dir, 'test_matched.tsv'), 'test_matched') + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + return self.get_train_examples(data_dir) + + def get_labels(self): + return ['contradiction', 'entailment', 'neutral'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + text_a = punctuation_standardization(row['sentence1']) + text_b = punctuation_standardization(row['sentence2']) + label = row.get('gold_label', None) + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class MnliMismatchedProcessor(MnliProcessor): + """Processor for the MultiNLI mismatched data set (GLUE version).""" + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples( + os.path.join(data_dir, 'dev_mismatched.tsv'), 'dev_mismatched') + + def get_test_examples(self, data_dir) -> List[InputExample]: + return self._create_examples( + os.path.join(data_dir, 'test_mismatched.tsv'), 'test_mismatched') + + +class AgnewsProcessor(DataProcessor): + """Processor for the AG news data set.""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.csv'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, 'test.csv'), 'dev') + + def get_test_examples(self, data_dir) -> List[InputExample]: + raise NotImplementedError() + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + return self.get_train_examples(data_dir) + + def get_labels(self): + return ['1', '2', '3', '4'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path) as f: + reader = csv.reader(f, delimiter=',') + for idx, row in enumerate(reader): + label, headline, body = row + guid = '%s-%s' % (set_type, idx) + text_a = punctuation_standardization( + headline.replace('\\', ' ')) + text_b = punctuation_standardization(body.replace('\\', ' ')) + + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class YahooAnswersProcessor(DataProcessor): + """Processor for the Yahoo Answers data set.""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.csv'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, 'test.csv'), 'dev') + + def get_test_examples(self, data_dir) -> List[InputExample]: + raise NotImplementedError() + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + return self.get_train_examples(data_dir) + + def get_labels(self): + return ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + reader = csv.reader(f, delimiter=',') + for idx, row in enumerate(reader): + label, question_title, question_body, answer = row + guid = '%s-%s' % (set_type, idx) + text_a = ' '.join([ + question_title.replace('\\n', ' ').replace('\\', ' '), + question_body.replace('\\n', ' ').replace('\\', ' ') + ]) + text_a = punctuation_standardization(text_a) + text_b = answer.replace('\\n', ' ').replace('\\', ' ') + text_b = punctuation_standardization(text_b) + + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class YelpPolarityProcessor(DataProcessor): + """Processor for the YELP binary classification set.""" + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.csv'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, 'test.csv'), 'dev') + + def get_test_examples(self, data_dir) -> List[InputExample]: + raise NotImplementedError() + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + return self.get_train_examples(data_dir) + + def get_labels(self): + return ['1', '2'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + + with open(path) as f: + reader = csv.reader(f, delimiter=',') + for idx, row in enumerate(reader): + label, body = row + guid = '%s-%s' % (set_type, idx) + text_a = body.replace('\\n', ' ').replace('\\', ' ') + text_a = punctuation_standardization(text_a) + + example = InputExample(guid=guid, text_a=text_a, label=label) + examples.append(example) + + return examples + + +class YelpFullProcessor(YelpPolarityProcessor): + """Processor for the YELP full classification set.""" + + def get_test_examples(self, data_dir) -> List[InputExample]: + raise NotImplementedError() + + def get_labels(self): + return ['1', '2', '3', '4', '5'] + + +class XStanceProcessor(DataProcessor): + """Processor for the X-Stance data set.""" + + def __init__(self, args, language: str = None): + super().__init__(args) + if language is not None: + assert language in ['de', 'fr'] + self.language = language + + def get_train_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, 'train.jsonl')) + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, 'test.jsonl')) + + def get_test_examples(self, data_dir) -> List[InputExample]: + raise NotImplementedError() + + def get_unlabeled_examples(self, data_dir) -> List[InputExample]: + return self.get_train_examples(data_dir) + + def get_labels(self): + return ['FAVOR', 'AGAINST'] + + def _create_examples(self, path: str) -> List[InputExample]: + examples = [] + + with open(path, encoding='utf8') as f: + for line in f: + example_json = json.loads(line) + label = example_json['label'] + id_ = example_json['id'] + text_a = punctuation_standardization(example_json['question']) + text_b = punctuation_standardization(example_json['comment']) + language = example_json['language'] + + if self.language is not None and language != self.language: + continue + + example = InputExample( + guid=id_, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class Sst2Processor(DataProcessor): + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train.tsv'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples(os.path.join(data_dir, 'dev.tsv'), 'dev') + + def get_test_examples(self, data_dir) -> List[InputExample]: + return self._create_examples( + os.path.join(data_dir, 'test.tsv'), 'test') + + def get_labels(self): + return ['0', '1'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + text_a = punctuation_standardization(row['sentence']) + label = row.get('label', None) + example = InputExample(guid=guid, text_a=text_a, label=label) + examples.append(example) + + return examples + + +class ColaProcessor(Sst2Processor): + + def get_labels(self): + return ['0', '1'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + if set_type != 'test': + df = read_tsv(path, header=None) + else: + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + if set_type != 'test': + text_a = punctuation_standardization(row[3]) + label = row[1] + else: + text_a = punctuation_standardization(row['sentence']) + label = None + example = InputExample(guid=guid, text_a=text_a, label=label) + examples.append(example) + + return examples + + +class MrpcProcessor(Sst2Processor): + + def get_labels(self): + return ['0', '1'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + text_a = punctuation_standardization(row['#1 String']) + text_b = punctuation_standardization(row['#2 String']) + label = row.get('Quality', None) + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class QqpProcessor(Sst2Processor): + + def get_labels(self): + return ['0', '1'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + text_a = punctuation_standardization(row['question1']) + text_b = punctuation_standardization(row['question2']) + label = row.get('is_duplicate', None) + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class QnliProcessor(Sst2Processor): + + def get_labels(self): + return ['entailment', 'not_entailment'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + df = read_tsv(path) + + for idx, row in df.iterrows(): + guid = f'{set_type}-{idx}' + text_a = punctuation_standardization(row['question']) + text_b = punctuation_standardization(row['sentence']) + label = row.get('label', None) + example = InputExample( + guid=guid, text_a=text_a, text_b=text_b, label=label) + examples.append(example) + + return examples + + +class SquadProcessor(DataProcessor): + + def get_train_examples(self, data_dir): + return self._create_examples( + os.path.join(data_dir, 'train-v2.0.json'), 'train') + + def get_dev_examples(self, data_dir, for_train=False): + return self._create_examples( + os.path.join(data_dir, 'dev-v2.0.json'), 'dev') + + def get_labels(self): + return ['0'] + + @staticmethod + def _create_examples(path: str, set_type: str) -> List[InputExample]: + examples = [] + with open(path) as f: + data = json.load(f)['data'] + + for idx, passage in enumerate(data): + for pid, paragraph in enumerate(passage['paragraphs']): + context = paragraph['context'] + for qid, qas in enumerate(paragraph['qas']): + if len(qas['answers']) == 0: + continue + guid = f'{set_type}-{idx}-{pid}-{qid}' + example = InputExample( + guid=guid, + text_a=context, + text_b=qas['question'], + label='0', + meta={'answer': qas['answers'][0]}) + examples.append(example) + + return examples + + +CLASSIFICATION_DATASETS = {'wic', 'rte', 'cb', 'boolq', 'multirc', 'wsc'} +MULTI_CHOICE_DATASETS = {'copa', 'record'} + +PROCESSORS = { + 'mnli': MnliProcessor, + 'mnli-mm': MnliMismatchedProcessor, + 'agnews': AgnewsProcessor, + 'yahoo': YahooAnswersProcessor, + 'yelp-polarity': YelpPolarityProcessor, + 'yelp-full': YelpFullProcessor, + 'xstance-de': lambda: XStanceProcessor('de'), + 'xstance-fr': lambda: XStanceProcessor('fr'), + 'xstance': XStanceProcessor, + 'wic': WicProcessor, + 'rte': RteProcessor, + 'cb': CbProcessor, + 'wsc': WscProcessor, + 'wsc1': WscProcessor, + 'boolq': BoolQProcessor, + 'copa': CopaProcessor, + 'multirc': MultiRcProcessor, + 'record': RecordProcessor, + 'ax-g': AxGProcessor, + 'ax-b': AxBProcessor, + 'sst2': Sst2Processor, + 'cola': ColaProcessor, + 'mrpc': MrpcProcessor, + 'qqp': QqpProcessor, + 'qnli': QnliProcessor, + 'squad': SquadProcessor, + 'race': RaceProcessor, + 'squad': SquadProcessor +} # type: Dict[str,Callable[[1],DataProcessor]] diff --git a/modelscope/models/nlp/mglm/tasks/superglue/evaluate.py b/modelscope/models/nlp/mglm/tasks/superglue/evaluate.py new file mode 100644 index 00000000..145fb45b --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/superglue/evaluate.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 Zhipu.AI +""" +Official evaluation script for ReCoRD v1.0. +(Some functions are adopted from the SQuAD evaluation script.) +""" + +from __future__ import print_function +import functools +import re +import string +from collections import Counter, defaultdict +from typing import List + +from tasks.data_utils import InputExample + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + if not ground_truths: + return 0.0 + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def qa_evaluate(predictions, labels, examples: List[InputExample], metric): + assert len(examples) == len(predictions) + score = 0.0 + for example, prediction in zip(examples, predictions): + ground_truths = example.meta['answers'] + prediction = example.meta['candidates'][prediction] + if ground_truths: + score += metric_max_over_ground_truths(metric, prediction, + ground_truths) + score = 100.0 * score / len(predictions) + return score + + +def multirc_em(predictions, labels, examples: List[InputExample]): + """Compute the exact match (EM) for a sequence of predictions and actual labels""" + question_ids = [example.meta['question_idx'] for example in examples] + unique_questions = set(question_ids) + + q_actuals = list(zip(question_ids, labels)) + q_predictions = list(zip(question_ids, predictions)) + + actuals_per_question = defaultdict(list) + predictions_per_question = defaultdict(list) + + for qid, val in q_actuals: + actuals_per_question[qid].append(val) + for qid, val in q_predictions: + predictions_per_question[qid].append(val) + + em = 0 + for qid in unique_questions: + if actuals_per_question[qid] == predictions_per_question[qid]: + em += 1 + em /= len(unique_questions) + return em + + +qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score) +qa_f1 = functools.partial(qa_evaluate, metric=f1_score) diff --git a/modelscope/models/nlp/mglm/tasks/superglue/finetune.py b/modelscope/models/nlp/mglm/tasks/superglue/finetune.py new file mode 100644 index 00000000..371705ff --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/superglue/finetune.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Race.""" + +from collections import OrderedDict + +from finetune_glm import finetune +from tasks.eval_utils import (accuracy_func_provider, accuracy_metric, + f1_macro_metric, f1_metric) +from tasks.superglue.dataset import (CLASSIFICATION_DATASETS, + MULTI_CHOICE_DATASETS, PROCESSORS, + SuperGlueDataset, get_output_func) +from tasks.superglue.evaluate import multirc_em, qa_exact_match, qa_f1 +from tasks.superglue.pvp import PVPS + +DEFAULT_METRICS = { + 'record': [('EM', qa_exact_match), ('F1', qa_f1)], + 'copa': [('accuracy', accuracy_metric)], + 'rte': [('accuracy', accuracy_metric)], + 'boolq': [('accuracy', accuracy_metric)], + 'wic': [('accuracy', accuracy_metric)], + 'wsc': [('accuracy', accuracy_metric)], + 'cb': [('accuracy', accuracy_metric), ('f1-macro', f1_macro_metric)], + 'multirc': [('f1a', f1_metric), ('em', multirc_em), + ('acc', accuracy_metric)], + 'mnli': [('accuracy', accuracy_metric)], + 'sst2': [('accuracy', accuracy_metric)], + 'qnli': [('accuracy', accuracy_metric)], + 'qqp': [('accuracy', accuracy_metric)], + 'mrpc': [('accuracy', accuracy_metric)], + 'cola': [('accuracy', accuracy_metric)], + 'squad': [('accuracy', accuracy_metric)], +} + + +def train_valid_datasets_provider(args, tokenizer, pattern_text=False): + """Provide train and validation datasets.""" + task_name = args.task.lower() + data_dir = args.data_dir + train_dataset = SuperGlueDataset( + args, + task_name, + data_dir, + args.seq_length, + 'train', + tokenizer, + pattern_text=pattern_text) + valid_dataset = SuperGlueDataset( + args, + task_name, + data_dir, + args.seq_length, + 'dev', + tokenizer, + for_train=True, + pattern_text=pattern_text) + + return train_dataset, valid_dataset + + +def metrics_func_provider(args, tokenizer, is_test): + """Privde metrics callback function.""" + + def single_dataset_provider(split): + return SuperGlueDataset(args, args.task.lower(), args.data_dir, + args.seq_length, split, tokenizer) + + output_func = get_output_func(args.task.lower(), args) + eval_func = None + if args.task.lower() in ['wsc', 'squad' + ] and args.cloze_eval and not args.wsc_negative: + from tasks.language_model.finetune import classify_evaluate + eval_func = classify_evaluate + metric_dict = OrderedDict(DEFAULT_METRICS[args.task.lower()]) + return accuracy_func_provider( + single_dataset_provider, + metric_dict, + args, + is_test=is_test, + eval_func=eval_func, + output_func=output_func, + only_rank0=False, + tokenizer=tokenizer) + + +def main(args): + model_kwargs = {} + processor = PROCESSORS[args.task.lower()](args) + pvp = PVPS[args.task.lower()]( + args, + None, + processor.get_labels(), + args.seq_length, + pattern_id=args.pattern_id, + is_multi_token=args.multi_token, + num_prompt_tokens=args.num_prompt_tokens) + if args.continuous_prompt: + model_kwargs['spell_length'] = pvp.spell_length + if args.task.lower() in ['wsc', 'squad' + ] and args.cloze_eval and not args.wsc_negative: + from tasks.language_model.finetune import lm_forward_step + finetune( + args, + train_valid_datasets_provider, + model_kwargs, + end_of_epoch_callback_provider=metrics_func_provider, + forward_step=lm_forward_step) + else: + if args.cloze_eval: + multi_token = pvp.is_multi_token + else: + multi_token = args.task.lower() in MULTI_CHOICE_DATASETS + args.multi_token = multi_token + if not multi_token: + model_kwargs[ + 'model_type'] = 'multiple_choice' if args.cloze_eval else 'classification' + model_kwargs['multi_token'] = False + model_kwargs['num_labels'] = len(processor.get_labels()) + else: + model_kwargs['model_type'] = 'multiple_choice' + model_kwargs['multi_token'] = True + model_kwargs['num_labels'] = 1 + finetune( + args, + train_valid_datasets_provider, + model_kwargs, + end_of_epoch_callback_provider=metrics_func_provider) diff --git a/modelscope/models/nlp/mglm/tasks/superglue/pvp.py b/modelscope/models/nlp/mglm/tasks/superglue/pvp.py new file mode 100644 index 00000000..ff394172 --- /dev/null +++ b/modelscope/models/nlp/mglm/tasks/superglue/pvp.py @@ -0,0 +1,1541 @@ +# Copyright (c) 2022 Zhipu.AI +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains the pattern-verbalizer pairs (PVPs) for all tasks. +""" +import copy +import math +import random +import string +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Dict, List, Tuple, Union + +import numpy as np +from tasks.data_utils import (InputExample, build_decoder_input, + build_decoder_sample, build_input_from_ids, + build_sample, num_special_tokens_to_add) +from utils import print_rank_0 + +FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], + List[Union[str, Tuple[str, bool]]]] + + +class PVP(ABC): + """ + This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its own + custom implementation of a PVP. + """ + + def __init__(self, + args, + tokenizer, + label_list, + max_seq_length, + pattern_id: int = 0, + verbalizer_file: str = None, + seed: int = 42, + is_multi_token=False, + max_segment_length=0, + fast_decode: bool = False, + split='train', + num_prompt_tokens=0): + """ + Create a new PVP. + + :param args: the args + :param tokenizer: the tokenizer + :param label_list: the list of labels + :param max_seq_length: the maximum length of the sequence + :param pattern_id: the pattern id to use + :param seed: a seed to be used for generating random numbers if necessary + :param is_multi_token: if the verbalizers contain multiple tokens + :param fast_decode: whether to use the fast decode mode for multi-token tasks + :param continuous_prompt: whether to use continuous prompt optimization + """ + self.args = args + self.tokenizer = tokenizer + self.label_list = label_list + self.max_seq_length = max_seq_length + self.pattern_id = pattern_id + self.num_prompt_tokens = num_prompt_tokens + self.rng = random.Random(seed) + self.num_truncated = 0 + self.fast_decode = fast_decode + self.split = split + self.max_dec_seq_length = 16 + self._is_multi_token = is_multi_token + self.max_segment_length = max_segment_length + self.task_mask = args.task_mask + self.continuous_prompt = args.continuous_prompt + self.prefix_prompt = args.prefix_prompt + if self.continuous_prompt: + print_rank_0( + f'Prompt tokens in pvp {self.num_prompt_tokens} spell length {self.spell_length}' + ) + + if verbalizer_file: + self.verbalize = PVP._load_verbalizer_from_file( + verbalizer_file, self.pattern_id) + + @property + def is_multi_token(self): + return self._is_multi_token + + @property + def spell_length(self): + return 0 + + @property + def mask(self) -> str: + """Return the underlying LM's mask token""" + return self.tokenizer.get_command('MASK').Id + + @property + def mask_id(self) -> int: + """Return the underlying LM's mask id""" + return self.tokenizer.get_command('MASK').Id + + @property + def max_num_verbalizers(self) -> int: + """Return the maximum number of verbalizers across all labels""" + return max(len(self.verbalize(label)) for label in self.label_list) + + @staticmethod + def shortenable(s): + """Return an instance of this string that is marked as shortenable""" + return s, True + + @staticmethod + def remove_final_punc(s: Union[str, Tuple[str, bool]]): + """Remove the final punctuation mark""" + if isinstance(s, tuple): + return PVP.remove_final_punc(s[0]), s[1] + return s.rstrip(string.punctuation) + + @staticmethod + def lowercase_first(s: Union[str, Tuple[str, bool]]): + """Lowercase the first character""" + if isinstance(s, tuple): + return PVP.lowercase_first(s[0]), s[1] + return s[0].lower() + s[1:] + + @staticmethod + def uppercase_first(s: Union[str, Tuple[str, bool]]): + """Lowercase the first character""" + if isinstance(s, tuple): + return PVP.uppercase_first(s[0]), s[1] + return s[0].upper() + s[1:] + + @staticmethod + def available_patterns(): + return [0] + + def replace_prompt_tokens(self, parts_a, parts_b): + if not self.continuous_prompt: + parts_a = [part for part in parts_a if part is not None] + parts_b = [part for part in parts_b if part is not None] + return parts_a, parts_b + num_prompt_tokens = self.num_prompt_tokens + num_pos = 0 + for parts in (parts_a, parts_b): + for part in parts: + if part is None: + num_pos += 1 + avg_prompt_tokens = math.ceil(num_prompt_tokens / num_pos) + new_parts_a, new_parts_b = [], [] + for part in parts_a: + if part is None: + if num_prompt_tokens > 0: + if num_prompt_tokens >= avg_prompt_tokens: + new_parts_a.append(avg_prompt_tokens) + num_prompt_tokens -= avg_prompt_tokens + else: + new_parts_a.append(num_prompt_tokens) + num_prompt_tokens = 0 + else: + new_parts_a.append(part) + for part in parts_b: + if part is None: + if num_prompt_tokens > 0: + if num_prompt_tokens >= avg_prompt_tokens: + new_parts_b.append(avg_prompt_tokens) + num_prompt_tokens -= avg_prompt_tokens + else: + new_parts_b.append(num_prompt_tokens) + num_prompt_tokens = 0 + else: + new_parts_b.append(part) + return new_parts_a, new_parts_b + + def encode(self, + example: InputExample, + priming: bool = False, + labeled: bool = False): + """ + Encode an input example using this pattern-verbalizer pair. + + :param example: the input example to encode + :param priming: whether to use this example for priming + :param labeled: if ``priming=True``, whether the label should be appended to this example + :return: A tuple, consisting of a list of input ids and a list of token type ids + """ + + if not priming: + assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true" + + tokenizer = self.tokenizer + raw_parts_a, raw_parts_b = self.get_parts(example) + + raw_parts_a = [ + x if isinstance(x, tuple) else (x, False) for x in raw_parts_a + ] + prompt_id = tokenizer.num_tokens + + def encode_input(raw_parts): + parts = [] + for x, s in raw_parts: + if isinstance(x, str): + x = tokenizer.EncodeAsIds(x) + elif isinstance(x, int): + x = [prompt_id] * x + else: + pass + parts.append((x, s)) + return parts + + parts_a = encode_input(raw_parts_a) + if self.prefix_prompt > 0: + parts_a = [([prompt_id] * self.prefix_prompt, False)] + parts_a + + parts_b = None + if raw_parts_b: + raw_parts_b = [ + x if isinstance(x, tuple) else (x, False) for x in raw_parts_b + ] + parts_b = encode_input(raw_parts_b) + + if self.is_multi_token: + answers = self.get_answers(example) + if example.label is not None: + label = self.label_list.index(example.label) + else: + label = 0 + + if not self.fast_decode: + ids_list, positions_list, sep_list, mask_list, target_list, prompt_list = [], [], [], [], [], [] + segment_id_list = [] + if priming: + answer = answers[label] + answer_ids = get_verbalization_ids( + answer, tokenizer, force_single_token=False) + self.num_truncated += self.truncate( + parts_a, + parts_b, + answer_ids, + max_length=self.max_seq_length) + tokens_a = [ + token_id for part, _ in parts_a for token_id in part + ] + tokens_b = [ + token_id for part, _ in parts_b for token_id in part + ] if parts_b else None + input_ids = tokens_a + if tokens_b: + input_ids += tokens_b + if labeled: + mask_idx = input_ids.index(self.mask_id) + input_ids = input_ids[: + mask_idx] + answer_ids + input_ids[ + mask_idx + 1:] + return input_ids + else: + for idx, answer in enumerate(answers): + this_parts_a, this_parts_b = copy.deepcopy( + parts_a), copy.deepcopy(parts_b) + answer_ids = get_verbalization_ids( + answer, tokenizer, force_single_token=False) + answer_ids = answer_ids + [ + tokenizer.get_command('eop').Id + ] + self.num_truncated += self.truncate( + this_parts_a, + this_parts_b, + answer_ids, + max_length=self.max_seq_length) + tokens_a = [ + token_id for part, _ in this_parts_a + for token_id in part + ] + tokens_b = [ + token_id for part, _ in this_parts_b + for token_id in part + ] if parts_b else None + if self.max_segment_length > 0: + num_segments = (len(answer_ids) + - 1) // self.max_segment_length + 1 + segments = [ + answer_ids[index + * self.max_segment_length:(index + + 1) + * self.max_segment_length] + for index in range(num_segments) + ] + segment_id_list += [idx] * len(segments) + else: + segments = [answer_ids] + for segment in segments: + data = build_input_from_ids( + tokens_a, + tokens_b, + segment, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True, + mask_id=self.mask_id) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + prompt_pos = [ + idx for idx, token in enumerate(ids) + if token == prompt_id + ] + ids = [ + idx if idx != prompt_id else 0 for idx in ids + ] + prompt_list.append(prompt_pos) + ids_list.append(ids) + positions_list.append(position_ids) + sep_list.append(sep) + target_list.append(target_ids) + mask_list.append(loss_masks) + if self.mask in tokens_a: + mask_pos = tokens_a.index(self.mask) + tokens_a = tokens_a[: + mask_pos] + segment + tokens_a[ + mask_pos:] + else: + mask_pos = tokens_b.index(self.mask) + tokens_b = tokens_b[: + mask_pos] + segment + tokens_b[ + mask_pos:] + segment_id_list = segment_id_list if segment_id_list else None + sample = build_sample( + ids_list, + positions=positions_list, + masks=sep_list, + label=label, + logit_mask=mask_list, + target=target_list, + unique_id=example.guid, + segment_ids=segment_id_list, + prompt_ids=prompt_list) + return sample + else: + this_parts_a, this_parts_b = copy.deepcopy( + parts_a), copy.deepcopy(parts_b) + self.num_truncated += self.truncate( + this_parts_a, + this_parts_b, + None, + max_length=self.max_seq_length) + tokens_a = [ + token_id for part, _ in this_parts_a for token_id in part + ] + tokens_b = [ + token_id for part, _ in this_parts_b for token_id in part + ] if parts_b else None + data = build_input_from_ids( + tokens_a, + tokens_b, + None, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=False) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + sample = build_sample( + ids, + positions=position_ids, + masks=sep, + label=label, + unique_id=example.guid) + + ids_list, positions_list, mask_list, target_list, logit_mask_list = [], [], [], [], [] + for answer in answers: + answer_ids = get_verbalization_ids( + answer, tokenizer, force_single_token=False) + answer_ids = answer_ids + [tokenizer.get_command('eop').Id] + answer_ids = answer_ids[:self.max_dec_seq_length] + data = build_decoder_input(ids, answer_ids, + self.max_seq_length, + self.max_dec_seq_length, + tokenizer) + dec_ids, _, _, dec_position_ids, _, dec_target_ids, dec_loss_masks = data + ids_list.append(dec_ids) + positions_list.append(dec_position_ids) + mask_list.append(sep) + target_list.append(dec_target_ids) + logit_mask_list.append(dec_loss_masks) + + sample = build_decoder_sample(sample, ids_list, positions_list, + mask_list, target_list, + logit_mask_list) + return sample + + else: + self.num_truncated += self.truncate( + parts_a, parts_b, [], max_length=self.max_seq_length) + + tokens_a = [token_id for part, _ in parts_a for token_id in part] + tokens_b = [token_id for part, _ in parts_b + for token_id in part] if parts_b else None + if priming: + input_ids = tokens_a + if tokens_b: + input_ids += tokens_b + if labeled: + mask_idx = input_ids.index(self.mask_id) + verbalizer = self.verbalize(example.label) + assert len( + verbalizer + ) == 1, 'priming only supports one verbalization per label' + verbalizer = verbalizer[0] + verbalizer_id = get_verbalization_ids( + verbalizer, self.tokenizer, force_single_token=True) + input_ids[mask_idx] = verbalizer_id + return input_ids + data = build_input_from_ids( + tokens_a, + tokens_b, + None, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + prompt_pos = [ + idx for idx, token in enumerate(ids) if token == prompt_id + ] + ids = [token if token != prompt_id else 0 for token in ids] + target_ids = self.get_verbalizer_ids() + if example.label is not None: + label = self.label_list.index(example.label) + else: + label = 0 + sample = build_sample( + ids=ids, + positions=position_ids, + target=target_ids, + masks=sep, + logit_mask=loss_masks, + label=label, + unique_id=example.guid, + prompt_ids=prompt_pos) + return sample + + @staticmethod + def _seq_length(parts: List[Tuple[List[int], bool]], + only_shortenable: bool = False): + return sum([ + len(x) for x, shortenable in parts + if not only_shortenable or shortenable + ]) if parts else 0 + + @staticmethod + def _remove_last(parts: List[Tuple[List[int], bool]]): + last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) + if shortenable and seq) + parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1]) + + def truncate(self, parts_a: List[Tuple[List[int], bool]], + parts_b: List[Tuple[List[int], bool]], answer: List[int], + max_length: int): + """Truncate two sequences of text to a predefined total maximum length""" + total_len = self._seq_length(parts_a) + self._seq_length(parts_b) + if answer: + total_len += len(answer) + total_len += num_special_tokens_to_add( + parts_a, + parts_b, + answer, + add_cls=True, + add_sep=False, + add_piece=True) + num_tokens_to_remove = total_len - max_length + + if num_tokens_to_remove <= 0: + return False + + for _ in range(num_tokens_to_remove): + if self._seq_length( + parts_a, only_shortenable=True) > self._seq_length( + parts_b, only_shortenable=True): + self._remove_last(parts_a) + else: + self._remove_last(parts_b) + return True + + @abstractmethod + def get_parts(self, example: InputExample) -> FilledPattern: + """ + Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one + mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a + single sequence of text, the second sequence should be an empty list. + + :param example: the input example to process + :return: Two sequences of text. All text segments can optionally be marked as being shortenable. + """ + pass + + def get_answers(self, example: InputExample): + return [self.verbalize(label)[0] for label in self.label_list] + + def get_verbalizer_ids(self): + target_ids = [] + for label in self.label_list: + verbalizer = self.verbalize(label)[0] + verbalizer_id = get_verbalization_ids( + verbalizer, self.tokenizer, force_single_token=True) + target_ids.append(verbalizer_id) + return target_ids + + @abstractmethod + def verbalize(self, label) -> List[str]: + """ + Return all verbalizations for a given label. + + :param label: the label + :return: the list of verbalizations + """ + pass + + def get_mask_positions(self, input_ids: List[int]) -> List[int]: + label_idx = input_ids.index(self.mask_id) + labels = [-1] * len(input_ids) + labels[label_idx] = 1 + return labels + + @staticmethod + def _load_verbalizer_from_file(path: str, pattern_id: int): + + verbalizers = defaultdict( + dict) # type: Dict[int, Dict[str, List[str]]] + current_pattern_id = None + + with open(path, 'r') as fh: + for line in fh.read().splitlines(): + if line.isdigit(): + current_pattern_id = int(line) + elif line: + label, *realizations = line.split() + verbalizers[current_pattern_id][label] = realizations + + print_rank_0( + 'Automatically loaded the following verbalizer: \n {}'.format( + verbalizers[pattern_id])) + + def verbalize(label) -> List[str]: + return verbalizers[pattern_id][label] + + return verbalize + + +class CopaPVP(PVP): + + @staticmethod + def available_patterns(): + return [0, 1] + + @property + def is_multi_token(self): + return True + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + @property + def mask(self) -> str: + """Return the underlying LM's mask token""" + mask_token = 'MASK' + return self.tokenizer.get_command(mask_token).Id + + @property + def mask_id(self) -> int: + """Return the underlying LM's mask id""" + mask_token = 'MASK' + return self.tokenizer.get_command(mask_token).Id + + def get_answers(self, example: InputExample): + choice1 = ' ' + self.remove_final_punc( + self.lowercase_first(example.meta['choice1'])) + choice2 = ' ' + self.remove_final_punc( + self.lowercase_first(example.meta['choice2'])) + return [choice1, choice2] + + def get_parts(self, example: InputExample) -> FilledPattern: + assert self.pattern_id in [0, 1, 2, 3] + premise = self.remove_final_punc( + self.shortenable(' ' + example.text_a)) + choice1 = self.remove_final_punc( + self.lowercase_first(example.meta['choice1'])) + choice2 = self.remove_final_punc( + self.lowercase_first(example.meta['choice2'])) + + question = example.meta['question'] + assert question in ['cause', 'effect'] + if question == 'cause': + joiner = ' because' + else: + joiner = ', so' + if self.pattern_id == 0: + parts_a, parts_b = [ + None, '"', choice1, '" or "', choice2, '"?', None, premise, + joiner, None, [self.mask], '.' + ], [] + elif self.pattern_id == 1: + parts_a, parts_b = [ + None, choice1, ' or', ' ' + choice2, '?', None, premise, + joiner, None, [self.mask], '.' + ], [] + elif self.pattern_id == 2: + parts_a, parts_b = [ + None, '"', choice1, '" or "', choice2, '"', None, premise, + joiner, [self.mask], '.', None + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def verbalize(self, label) -> List[str]: + return [] + + def encode(self, + example: InputExample, + priming: bool = False, + labeled: bool = False): + """ + Encode an input example using this pattern-verbalizer pair. + + :param example: the input example to encode + :param priming: whether to use this example for priming + :param labeled: if ``priming=True``, whether the label should be appended to this example + :return: A tuple, consisting of a list of input ids and a list of token type ids + """ + if self.continuous_prompt or self.pattern_id < 2: + return super().encode(example, priming=priming, labeled=labeled) + if not priming: + assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true" + + tokenizer = self.tokenizer + premise = self.remove_final_punc(self.shortenable(example.text_a)) + choice1 = ' ' + self.remove_final_punc( + self.lowercase_first(example.meta['choice1'])) + choice2 = ' ' + self.remove_final_punc( + self.lowercase_first(example.meta['choice2'])) + question = example.meta['question'] + assert question in ['cause', 'effect'] + answer = ' because' if question == 'cause' else ' so' + answer_ids = [ + get_verbalization_ids(answer, tokenizer, force_single_token=True) + ] + if self.is_multi_token: + answer_ids.append(tokenizer.get_command('eop').Id) + + ids_list, positions_list, sep_list, mask_list, target_list = [], [], [], [], [] + + for choice in [choice1, choice2]: + parts = [ + '"', choice1[1:], '" or "', choice2[1:], '"?', premise, + [self.mask], choice + ] + parts = [x if isinstance(x, tuple) else (x, False) for x in parts] + parts = [(tokenizer.EncodeAsIds(x).tokenization if isinstance( + x, str) else x, s) for x, s in parts if x] + self.num_truncated += self.truncate( + parts, None, answer_ids, max_length=self.max_seq_length) + tokens_a = [token_id for part, _ in parts for token_id in part] + data = build_input_from_ids( + tokens_a, + None, + answer_ids, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + ids_list.append(ids) + positions_list.append(position_ids) + sep_list.append(sep) + target_list.append(target_ids) + mask_list.append(loss_masks) + if example.label is not None: + label = self.label_list.index(example.label) + else: + label = 0 + sample = build_sample( + ids_list, + positions=positions_list, + masks=sep_list, + label=label, + logit_mask=mask_list, + target=target_list, + unique_id=example.guid) + return sample + + +class WscPVP(PVP): + + @staticmethod + def available_patterns(): + return [0, 1, 2] + + @property + def is_multi_token(self): + return True + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + def get_answers(self, example: InputExample): + target = ' ' + example.meta['span1_text'] + answers = [target] + if 'candidates' in example.meta: + candidates = example.meta['candidates'] + # if len(candidates) > 10: + # random.shuffle(candidates) + # candidates = candidates[:10] + answers += [' ' + cand for cand in candidates] + return answers + + def get_parts(self, example: InputExample) -> FilledPattern: + pronoun = example.meta['span2_text'] + pronoun_idx = example.meta['span2_index'] + + words_a = example.text_a.split() + words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*' + text_a = ' '.join(words_a) + text_a = self.shortenable(text_a) + + if self.pattern_id == 0: + parts_a, parts_b = [ + None, text_a, + None, " The pronoun '*" + pronoun + "*' refers to", None, + [self.mask], '.' + ], [] + elif self.pattern_id == 1: + parts_a, parts_b = [ + None, text_a, None, " In the previous sentence, the pronoun '*" + + pronoun + "*' refers to", None, [self.mask], '.' + ], [] + elif self.pattern_id == 2: + parts_a, parts_b = [ + None, text_a, None, + " Question: In the passage above, what does the pronoun '*" + + pronoun + "*' refer to?", None, ' Answer:', [self.mask], '.' + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def encode(self, + example: InputExample, + priming: bool = False, + labeled: bool = False): + """ + Encode an input example using this pattern-verbalizer pair. + + :param example: the input example to encode + :param priming: whether to use this example for priming + :param labeled: if ``priming=True``, whether the label should be appended to this example + :return: A tuple, consisting of a list of input ids and a list of token type ids + """ + if self.args.loss_func in ['generative', 'mix']: + sample = super().encode(example, priming=priming, labeled=labeled) + if self.split == 'train': + sample['label'] = 0 + return sample + + if not priming: + assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true" + + tokenizer = self.tokenizer + prompt_id = tokenizer.num_tokens + raw_parts_a, raw_parts_b = self.get_parts(example) + + raw_parts_a = [ + x if isinstance(x, tuple) else (x, False) for x in raw_parts_a + ] + + def encode_input(raw_parts): + parts = [] + for x, s in raw_parts: + if isinstance(x, str): + x = tokenizer.EncodeAsIds(x) + elif isinstance(x, int): + x = [prompt_id] * x + else: + pass + parts.append((x, s)) + return parts + + parts_a = encode_input(raw_parts_a) + if self.prefix_prompt > 0: + parts_a = [([prompt_id] * self.prefix_prompt, False)] + parts_a + parts_b = None + if raw_parts_b: + raw_parts_b = [ + x if isinstance(x, tuple) else (x, False) for x in raw_parts_b + ] + parts_b = encode_input(raw_parts_b) + answer = self.get_answers(example)[0] + answer_ids = get_verbalization_ids( + answer, tokenizer, force_single_token=False) + answer_ids = answer_ids + [tokenizer.get_command('eop').Id] + self.num_truncated += self.truncate( + parts_a, parts_b, answer_ids, max_length=self.max_seq_length) + tokens_a = [token_id for part, _ in parts_a for token_id in part] + tokens_b = [token_id for part, _ in parts_b + for token_id in part] if parts_b else None + data = build_input_from_ids( + tokens_a, + tokens_b, + answer_ids, + self.max_seq_length, + self.tokenizer, + args=self.args, + add_cls=True, + add_sep=False, + add_piece=True) + ids, types, paddings, position_ids, sep, target_ids, loss_masks = data + prompt_pos = [ + idx for idx, token in enumerate(ids) if token == prompt_id + ] + ids = [token if token != prompt_id else 0 for token in ids] + if example.label is not None: + label = self.label_list.index(example.label) + else: + label = 0 + return { + 'text': np.array(ids, dtype=np.int64), + 'target': np.array(target_ids, dtype=np.int64), + 'attention_mask': np.array(sep, dtype=np.int64), + 'loss_mask': np.array(loss_masks, dtype=np.int64), + 'position_id': np.array(position_ids, dtype=np.int64), + 'prompt_pos': np.array(prompt_pos, dtype=np.int64), + 'label': label, + 'uid': example.guid + } + + def verbalize(self, label) -> List[str]: + return [] + + +class RecordPVP(PVP): + + @property + def is_multi_token(self): + return True + + def get_answers(self, example: InputExample): + choices = example.meta['candidates'] + choices = [' ' + choice for choice in choices] + return choices + + def get_parts(self, example: InputExample) -> FilledPattern: + premise = self.shortenable(example.text_a) + + assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token' + question_a, question_b = example.text_b.split('@placeholder') + return [premise, ' ' + question_a.rstrip(), [self.mask], + question_b], [] + + def verbalize(self, label) -> List[str]: + return [] + + +class RacePVP(PVP): + + @property + def is_multi_token(self): + return True + + @staticmethod + def available_patterns(): + return [0, 1] + + def get_answers(self, example: InputExample): + choices = example.meta['choices'] + choices = [' ' + choice for choice in choices] + return choices + + def get_parts(self, example: InputExample) -> FilledPattern: + context = self.shortenable(example.text_a) + question = ' ' + example.text_b + + if '_' in question: + left, right = question.split('_', maxsplit=1) + if self.pattern_id == 0: + return [context], [ + self.shortenable(left.rstrip()), [self.mask], + self.shortenable(right) + ] + else: + left = left.rstrip() + if left: + left = self.lowercase_first(left) + return [context], [ + ' Based on the previous passage,', + self.shortenable(left), [self.mask], + self.shortenable(right) + ] + else: + if self.pattern_id == 0: + return [context], [ + ' Question:', + self.shortenable(question), ' Answer:', [self.mask] + ] + else: + return [context], [ + ' Based on the previous passage,', + self.shortenable(question), [self.mask] + ] + + def verbalize(self, label) -> List[str]: + return [] + + +class RtePVP(PVP): + VERBALIZER = {'not_entailment': [' No'], 'entailment': [' Yes']} + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4] + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + def get_parts(self, example: InputExample) -> FilledPattern: + # switch text_a and text_b to get the correct order + text_a = example.text_a + text_b = example.text_b.rstrip(string.punctuation) + if self.pattern_id == 0: + parts_a, parts_b = [None, '"', + self.shortenable(text_b), '" ?'], [ + None, [self.mask], ',', None, ' "', + self.shortenable(text_a), '"' + ] # noqa + elif self.pattern_id == 1: + parts_a, parts_b = [None, self.shortenable(text_b), '?'], [ + None, [self.mask], ',', None, + self.shortenable(' ' + text_a) + ] + elif self.pattern_id == 2: + parts_a, parts_b = [None, '"', + self.shortenable(text_b), '" ?'], [ + None, [self.mask], '. "', None, + self.shortenable(text_a), '"' + ] # noqa + elif self.pattern_id == 3: + parts_a, parts_b = [None, self.shortenable(text_b), '?'], [ + None, [self.mask], '.', None, + self.shortenable(' ' + text_a) + ] + elif self.pattern_id == 4: + parts_a, parts_b = [ + None, + self.shortenable(text_a), None, ' question:', + self.shortenable(' ' + text_b), ' True or False?', None, + ' answer:', [self.mask] + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 4: + return [' true'] if label == 'entailment' else [' false'] + return RtePVP.VERBALIZER[label] + + +class CbPVP(RtePVP): + VERBALIZER = { + 'contradiction': [' No'], + 'entailment': [' Yes'], + 'neutral': [' Maybe'] + } + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4] + + def get_parts(self, example: InputExample) -> FilledPattern: + if self.pattern_id == 4: + text_a = self.shortenable(example.text_a) + text_b = self.shortenable(' ' + example.text_b) + parts_a, parts_b = [ + None, text_a, None, ' question:', text_b, + ' true, false or neither?', None, ' answer:', [self.mask] + ], [] + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + return super().get_parts(example) + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 4: + return [' true'] if label == 'entailment' else [ + ' false' + ] if label == 'contradiction' else [' neither'] + return CbPVP.VERBALIZER[label] + + +class BoolQPVP(PVP): + VERBALIZER_A = {'false': [' No'], 'true': [' Yes']} + + VERBALIZER_B = {'false': [' false'], 'true': [' true']} + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4, 5] + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + def get_parts(self, example: InputExample) -> FilledPattern: + passage = example.text_a + question = example.text_b + + if self.pattern_id < 2: + parts_a, parts_b = [ + None, + self.shortenable(passage), None, ' Question:', + self.shortenable(' ' + question), '? Answer:', None, + [self.mask], '.' + ], [] + elif self.pattern_id < 4: + parts_a, parts_b = [ + None, + self.shortenable(passage), ' Based on the previous passage,', + None, + self.shortenable(' ' + question), '?', None, [self.mask], '.' + ], [] + elif self.pattern_id < 6: + parts_a, parts_b = [ + 'Based on the following passage', None, + self.shortenable(' ' + question), '?', None, [self.mask], '.', + None, + self.shortenable(' ' + passage) + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4: + return BoolQPVP.VERBALIZER_A[label] + else: + return BoolQPVP.VERBALIZER_B[label] + + +class MultiRcPVP(PVP): + VERBALIZER = {0: [' No'], 1: [' Yes']} + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4] + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + def get_parts(self, example: InputExample) -> FilledPattern: + passage = self.remove_final_punc( + self.shortenable(example.text_a.rstrip())) + question = self.remove_final_punc(example.text_b.rstrip()) + answer = example.meta['answer'] + if self.pattern_id == 0: + parts_a, parts_b = [ + passage, '.', None, ' Question:', ' ' + question + '?', None, + ' Is it', ' ' + answer, '?', None, [self.mask], '.' + ], [] + elif self.pattern_id == 1: + parts_a, parts_b = [ + passage, '.', None, ' Question:', ' ' + question, '?', + None, ' Is the correct answer "', answer, '"?', None, + [self.mask], '.' + ], [] + elif self.pattern_id == 2: + parts_a, parts_b = [ + passage, '. Based on the previous passage,', None, + ' ' + question, '?', None, ' Is "', answer, + '" a correct answer?', None, [self.mask], '.' + ], [] + elif self.pattern_id == 3: + parts_a, parts_b = [ + None, passage, None, ' ' + question, '- [', [self.mask], ']', + None, answer + ], [] + elif self.pattern_id == 4: + parts_a, parts_b = [ + passage, '.', None, ' Question:', ' ' + question, '?', None, + ' ' + answer, '?', None, [self.mask], '.' + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 3: + return [' False'] if label == 0 else [' True'] + return MultiRcPVP.VERBALIZER[label] + + +class WicPVP(PVP): + VERBALIZER_A = {'false': [' No'], 'true': [' Yes']} + VERBALIZER_B = {'false': ['2'], 'true': ['b']} + + @staticmethod + def available_patterns(): + return [0, 1, 2] + + @property + def spell_length(self): + return self.num_prompt_tokens + self.prefix_prompt + + def get_parts(self, example: InputExample) -> FilledPattern: + text_a = example.text_a + text_b = example.text_b + word = example.meta['word'] + + if self.pattern_id == 0: + parts_a, parts_b = [ + None, + self.shortenable('"' + text_a + '" / "' + text_b + '"'), None, + ' Similar sense of "' + word + '"?', None, [self.mask], '.' + ], [] + elif self.pattern_id == 1: + parts_a, parts_b = [ + self.shortenable(text_a), None, + self.shortenable(' ' + text_b), None, + ' Does ' + word + ' have the same meaning in both sentences?', + None, [self.mask] + ], [] + elif self.pattern_id == 2: + parts_a, parts_b = [ + None, word, ' .', None, ' Sense (1) (a) "', + self.shortenable(text_a), '"', None, ' (', [self.mask], ') "', + text_b, '"' + ], [] + else: + raise NotImplementedError(self.pattern_id) + parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b) + return parts_a, parts_b + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 2: + return WicPVP.VERBALIZER_B[label] + return WicPVP.VERBALIZER_A[label] + + +class AgnewsPVP(PVP): + VERBALIZER = { + '1': [' World'], + '2': [' Sports'], + '3': [' Business'], + '4': [' Tech'] + } + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4, 5] + + def get_parts(self, example: InputExample) -> FilledPattern: + + text_a = self.shortenable(example.text_a) + text_b = self.shortenable(example.text_b) + + if self.pattern_id == 0: + return [[self.mask], ':', text_a, text_b], [] + elif self.pattern_id == 1: + return [[self.mask], ' News:', text_a, text_b], [] + elif self.pattern_id == 2: + return [text_a, '(', [self.mask], ')', text_b], [] + elif self.pattern_id == 3: + return [text_a, text_b, '(', [self.mask], ')'], [] + elif self.pattern_id == 4: + return ['[ Category:', [self.mask], ']', text_a, text_b], [] + elif self.pattern_id == 5: + return [[self.mask], '-', text_a, text_b], [] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return AgnewsPVP.VERBALIZER[label] + + +class YahooPVP(PVP): + VERBALIZER = { + '1': [' Society'], + '2': [' Science'], + '3': [' Health'], + '4': [' Education'], + '5': [' Computer'], + '6': [' Sports'], + '7': [' Business'], + '8': [' Entertainment'], + '9': [' Relationship'], + '10': [' Politics'], + } + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4, 5] + + def get_parts(self, example: InputExample) -> FilledPattern: + + text_a = self.shortenable(example.text_a) + text_b = self.shortenable(example.text_b) + + if self.pattern_id == 0: + return [[self.mask], ':', text_a, text_b], [] + elif self.pattern_id == 1: + return [[self.mask], ' Question:', text_a, text_b], [] + elif self.pattern_id == 2: + return [text_a, '(', [self.mask], ')', text_b], [] + elif self.pattern_id == 3: + return [text_a, text_b, '(', [self.mask], ')'], [] + elif self.pattern_id == 4: + return ['[ Category:', [self.mask], ']', text_a, text_b], [] + elif self.pattern_id == 5: + return [[self.mask], '-', text_a, text_b], [] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return YahooPVP.VERBALIZER[label] + + +class MnliPVP(PVP): + VERBALIZER_A = { + 'contradiction': [' Wrong'], + 'entailment': [' Right'], + 'neutral': [' Maybe'] + } + VERBALIZER_B = { + 'contradiction': [' No'], + 'entailment': [' Yes'], + 'neutral': [' Maybe'] + } + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3] + + def get_parts(self, example: InputExample) -> FilledPattern: + text_a = self.shortenable(self.remove_final_punc(example.text_a)) + text_b = self.shortenable(example.text_b) + + if self.pattern_id == 0 or self.pattern_id == 2: + return ['"', text_a, '" ?'], [[self.mask], ', "', text_b, '"'] + elif self.pattern_id == 1 or self.pattern_id == 3: + return [text_a, '?'], [[self.mask], ',', text_b] + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 0 or self.pattern_id == 1: + return MnliPVP.VERBALIZER_A[label] + return MnliPVP.VERBALIZER_B[label] + + +class YelpPolarityPVP(PVP): + VERBALIZER = {'1': [' bad'], '2': [' good']} + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3] + + def get_parts(self, example: InputExample) -> FilledPattern: + text = self.shortenable(example.text_a) + + if self.pattern_id == 0: + return ['It was', [self.mask], '.', text], [] + elif self.pattern_id == 1: + return [text, '. All in all, it was', [self.mask], '.'], [] + elif self.pattern_id == 2: + return ['Just', [self.mask], '!'], [text] + elif self.pattern_id == 3: + return [text], [' In summary, the restaurant is', [self.mask], '.'] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return YelpPolarityPVP.VERBALIZER[label] + + +class YelpFullPVP(YelpPolarityPVP): + VERBALIZER = { + '1': [' terrible'], + '2': [' bad'], + '3': [' okay'], + '4': [' good'], + '5': [' great'] + } + + def verbalize(self, label) -> List[str]: + return YelpFullPVP.VERBALIZER[label] + + +class XStancePVP(PVP): + VERBALIZERS = { + 'en': { + 'FAVOR': ['Yes'], + 'AGAINST': ['No'] + }, + 'de': { + 'FAVOR': ['Ja'], + 'AGAINST': ['Nein'] + }, + 'fr': { + 'FAVOR': ['Oui'], + 'AGAINST': ['Non'] + } + } + + @staticmethod + def available_patterns(): + return [0, 1, 2, 3, 4, 5] + + def get_parts(self, example: InputExample) -> FilledPattern: + + text_a = self.shortenable(example.text_a) + text_b = self.shortenable(example.text_b) + + if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4: + return ['"', text_a, '"'], [[self.mask], '. "', text_b, '"'] + elif self.pattern_id == 1 or self.pattern_id == 3 or self.pattern_id == 5: + return [text_a], [[self.mask], '.', text_b] + + def verbalize(self, label) -> List[str]: + lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr' + return XStancePVP.VERBALIZERS[lang][label] + + +class Sst2PVP(PVP): + VERBALIZER_A = {'0': [' terrible'], '1': [' great']} + + VERBALIZER_B = {'0': [' bad'], '1': [' good']} + + @staticmethod + def available_patterns(): + return [0, 1] + + def get_parts(self, example: InputExample) -> FilledPattern: + text = self.shortenable(example.text_a) + if self.pattern_id == 0 or self.pattern_id == 1: + return [text, ' It was', [self.mask], '.'], [] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + if self.pattern_id == 0: + return Sst2PVP.VERBALIZER_A[label] + else: + return Sst2PVP.VERBALIZER_B[label] + + +class ColaPVP(PVP): + VERBALIZER = {'0': [' incorrect'], '1': [' correct']} + + def get_parts(self, example: InputExample) -> FilledPattern: + text = self.shortenable(example.text_a) + if self.pattern_id == 0: + return ['"', text, '"', ' This is', [self.mask], '.'], [] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return ColaPVP.VERBALIZER[label] + + +class MrpcPVP(PVP): + VERBALIZER = {'0': [' No'], '1': [' Yes']} + + @staticmethod + def available_patterns(): + return [0, 1] + + def get_parts(self, example: InputExample) -> FilledPattern: + text_a = self.shortenable(example.text_a) + if self.pattern_id == 0: + text_b = self.shortenable(self.lowercase_first(example.text_b)) + return [text_a], [[self.mask], ', ', text_b] + elif self.pattern_id == 1: + text_b = self.shortenable( + self.remove_final_punc(self.lowercase_first(example.text_b))) + return [text_a], [' Does it mean that', text_b, '?', [self.mask]] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return MrpcPVP.VERBALIZER[label] + + +class QqpPVP(PVP): + VERBALIZER = {'0': [' No'], '1': [' Yes']} + + @staticmethod + def available_patterns(): + return [0, 1] + + def get_parts(self, example: InputExample) -> FilledPattern: + text_a = self.shortenable(example.text_a) + text_b = self.shortenable(self.lowercase_first(example.text_b)) + if self.pattern_id == 0: + return [text_a], [' Do you mean ', text_b, [self.mask], '.'] + elif self.pattern_id == 1: + return [text_a], [[self.mask], ', ', text_b] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return QqpPVP.VERBALIZER[label] + + +class QnliPVP(PVP): + VERBALIZER = {'not_entailment': [' No'], 'entailment': [' Yes']} + + @staticmethod + def available_patterns(): + return [0, 1, 2] + + def get_parts(self, example: InputExample) -> FilledPattern: + question = self.remove_final_punc(example.text_a) + passage = example.text_b + if self.pattern_id == 0: + return [ + self.shortenable(passage), ' Question:', + self.shortenable(' ' + question), '? Do you know the answer?', + [self.mask], '.' + ], [] + elif self.pattern_id == 1: + return [ + self.shortenable(passage), + ' Based on the previous passage, do you know the answer', + self.shortenable(' ' + question), '?', [self.mask], '.' + ], [] + elif self.pattern_id == 2: + return [ + 'Based on the following passage, do you know the answer', + self.shortenable(' ' + question), '?', [self.mask], '.', + self.shortenable(' ' + passage) + ], [] + else: + raise ValueError('No pattern implemented for id {}'.format( + self.pattern_id)) + + def verbalize(self, label) -> List[str]: + return QnliPVP.VERBALIZER[label] + + +class SquadPVP(PVP): + + @property + def is_multi_token(self): + return True + + def get_answers(self, example: InputExample): + target = ' ' + example.meta['answer']['text'] + answers = [target] + return answers + + def get_parts(self, example: InputExample) -> FilledPattern: + context = self.shortenable(example.text_a) + question = example.text_b + return [context, ' ' + question, [self.mask], '.'], [] + + def verbalize(self, label) -> List[str]: + return [] + + +def get_verbalization_ids(word: str, tokenizer, + force_single_token: bool) -> Union[int, List[int]]: + """ + Get the token ids corresponding to a verbalization + + :param word: the verbalization + :param tokenizer: the tokenizer to use + :param force_single_token: whether it should be enforced that the verbalization corresponds to a single token. + If set to true, this method returns a single int instead of a list and throws an error if the word + corresponds to multiple tokens. + :return: either the list of token ids or the single token id corresponding to this word + """ + ids = tokenizer.EncodeAsIds(word).tokenization + if not force_single_token: + return ids + assert len(ids) == 1, \ + f'Verbalization "{word}" does not correspond to a single token, got {tokenizer.DecodeIds(ids)}' + verbalization_id = ids[0] + assert verbalization_id not in tokenizer.command_id_map, \ + f'Verbalization {word} is mapped to a special token {tokenizer.IdToToken(verbalization_id)}' + return verbalization_id + + +PVPS = { + 'agnews': AgnewsPVP, + 'mnli': MnliPVP, + 'yelp-polarity': YelpPolarityPVP, + 'yelp-full': YelpFullPVP, + 'yahoo': YahooPVP, + 'xstance': XStancePVP, + 'xstance-de': XStancePVP, + 'xstance-fr': XStancePVP, + 'rte': RtePVP, + 'wic': WicPVP, + 'cb': CbPVP, + 'wsc': WscPVP, + 'boolq': BoolQPVP, + 'copa': CopaPVP, + 'multirc': MultiRcPVP, + 'record': RecordPVP, + 'ax-b': RtePVP, + 'ax-g': RtePVP, + 'sst2': Sst2PVP, + 'cola': ColaPVP, + 'mrpc': MrpcPVP, + 'qqp': QqpPVP, + 'qnli': QnliPVP, + 'squad': SquadPVP, + 'race': RacePVP, +} diff --git a/modelscope/models/nlp/mglm/test/__init__.py b/modelscope/models/nlp/mglm/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/mglm/test/test_block.py b/modelscope/models/nlp/mglm/test/test_block.py new file mode 100644 index 00000000..ed4225da --- /dev/null +++ b/modelscope/models/nlp/mglm/test/test_block.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 Zhipu.AI + +import random +from argparse import Namespace + +import numpy as np +from blocklm_utils import ConstructBlockStrategy + + +# rng = random.Random() +# span_lengths = [2, 3, 4, 2, 3, 4] +# length = 100 +# +# counts = np.array([0] * length) +# for _ in range(10000): +# rng.shuffle(span_lengths) +# spans = ConstructBlockStrategy.sample_spans(span_lengths, length, rng) +# for start, end in spans: +# counts[start: end] += 1 +# print(counts) +def main(): + args = Namespace() + args.seq_length = 10 + args.eod_token = 0 + + strategy = ConstructBlockStrategy( + args, None, bert_ratio=0.4, max_seq_length=128) + counts = np.array([0] * 10) + for _ in range(10000): + spans = strategy.sample_span_in_document( + np.array([1, 2, 3, 0, 4, 5, 6, 7, 9, 0], dtype=np.long), [1, 1], + random.Random()) + for start, end in spans: + counts[start:end] += 1 + + print(counts) diff --git a/modelscope/models/nlp/mglm/test/test_rel_shift.py b/modelscope/models/nlp/mglm/test/test_rel_shift.py new file mode 100644 index 00000000..00cbb9fe --- /dev/null +++ b/modelscope/models/nlp/mglm/test/test_rel_shift.py @@ -0,0 +1,27 @@ +# Copyright (c) 2022 Zhipu.AI + +import matplotlib.pyplot as plt +import numpy as np +from learning_rates import AnnealingLR +from torch.nn.modules import Linear +from torch.optim import Adam + + +def main(): + model = Linear(10, 10) + optimizer = Adam(model.parameters()) + lr_scheduler = AnnealingLR( + optimizer, + start_lr=0.00015, + warmup_iter=3000, + num_iters=300000, + decay_style='cosine', + decay_ratio=0.1) + steps = np.arange(0, 400000, 10, dtype=np.long) + rates = [] + for step in steps: + lr_scheduler.num_iters = step + rates.append(lr_scheduler.get_lr()) + print(rates) + plt.plot(steps, rates) + plt.savefig('lr.pdf', format='pdf') diff --git a/modelscope/models/nlp/mglm/train_utils.py b/modelscope/models/nlp/mglm/train_utils.py new file mode 100644 index 00000000..c9c0de8e --- /dev/null +++ b/modelscope/models/nlp/mglm/train_utils.py @@ -0,0 +1,472 @@ +# Copyright (c) 2022 Zhipu.AI + +import deepspeed +import torch +from apex.optimizers import FusedAdam as Adam +from torch import distributed as dist + +from . import mpu +from .fp16 import DynamicLossScaler, FP16_Module, FP16_Optimizer +from .model import DistributedDataParallel as LocalDDP +from .model import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, + GLMForSequenceClassification, GLMForSingleTokenCloze, + GLMModel) +from .model import PyTorchDistributedDataParallel as TorchDDP +from .model import glm_get_params_for_weight_decay_optimization +from .utils import get_checkpoint_iteration, get_checkpoint_name, print_rank_0 + + +def load_pretrained(model, checkpoint_path, args, task_tokens=None): + load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path) + checkpoint_name = get_checkpoint_name(load_dir, tag, release) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading pretrained model {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + # Load the checkpoint. + sd = torch.load(checkpoint_name, map_location='cpu') + if args.deepspeed: + model = model.module + if isinstance(model, TorchDDP): + model = model.module + if isinstance(model, FP16_Module): + model = model.module + if hasattr(model, 'model'): + model = model.model + + # Model. + def extend_embedding_weights(state_weights, model_weights): + original_length = state_weights.shape[0] + assert original_length <= args.max_position_embeddings + 1 + new_weights = model_weights.clone() + new_weights[:original_length] = state_weights + return new_weights + + if args.block_lm: + if 'transformer.block_position_embeddings.weight' in sd['module']: + position_weights = sd['module'][ + 'transformer.position_embeddings.weight'] + if args.max_position_embeddings + 1 > position_weights.shape[0]: + sd['module'][ + 'transformer.position_embeddings.weight'] = extend_embedding_weights( + position_weights, + model.state_dict() + ['transformer.position_embeddings.weight'].data) + print_rank_0( + f'Extend position embedding to {args.max_position_embeddings + 1}' + ) + if 'transformer.block_position_embeddings.weight' in sd['module']: + block_position_weights = sd['module'][ + 'transformer.block_position_embeddings.weight'] + if args.max_position_embeddings + 1 > block_position_weights.shape[ + 0]: + sd['module'][ + 'transformer.block_position_embeddings.weight'] = extend_embedding_weights( + block_position_weights, + model.state_dict() + ['transformer.block_position_embeddings.weight'].data) + print_rank_0( + f'Extend block position embedding to {args.max_position_embeddings + 1}' + ) + for key in list(model.state_dict().keys()): + print(key) + model.state_dict()[key.replace( + 'mixins.block_position_embedding.block_position_embeddings.weight', + 'transformer.block_position_embeddings.weight').replace( + 'transformer.word_embeddings.weight', + 'word_embeddings.weight')] = model.state_dict().pop(key) + + missing_keys, unexpected_keys = model.load_state_dict( + sd['module'], strict=False) + if missing_keys or unexpected_keys: + print_rank_0( + f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}') + if args.continuous_prompt and args.prompt_init: + model.prompt_spell.init_embedding(model.word_embeddings.weight.data, + task_tokens) + + +def get_model(args, + model_type=None, + multi_token=True, + num_labels=None, + spell_length=None): + """Build the model.""" + print_rank_0('building GPT2 model ...') + if args.pretrained_bert: + if model_type == 'multiple_choice': + model = BertForMultipleChoice.from_pretrained( + args.tokenizer_model_type, + cache_dir=args.cache_dir, + fp32_layernorm=args.fp32_layernorm, + fp32_embedding=args.fp32_embedding, + layernorm_epsilon=args.layernorm_epsilon) + elif model_type == 'classification': + model = BertForSequenceClassification.from_pretrained( + args.tokenizer_model_type, + cache_dir=args.cache_dir, + fp32_layernorm=args.fp32_layernorm, + fp32_embedding=args.fp32_embedding, + layernorm_epsilon=args.layernorm_epsilon, + num_labels=num_labels) + else: + raise NotImplementedError + else: + output_predict, paralle_output = True, True + if (model_type == 'multiple_choice' + or model_type == 'classification') and not args.cloze_eval: + output_predict = False + if model_type is not None: + paralle_output = False + if spell_length is not None: + print_rank_0(f'Continuous spell length {spell_length}') + model = GLMModel( + num_layers=args.num_layers, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + embedding_dropout_prob=args.hidden_dropout, + attention_dropout_prob=args.attention_dropout, + output_dropout_prob=args.hidden_dropout, + max_sequence_length=args.max_position_embeddings, + max_memory_length=args.mem_length, + checkpoint_activations=args.checkpoint_activations, + checkpoint_num_layers=args.checkpoint_num_layers, + parallel_output=paralle_output, + relative_encoding=args.transformer_xl, + block_position_encoding=args.block_lm and not args.masked_lm, + output_predict=output_predict, + spell_length=spell_length, + spell_func=args.prompt_func, + attention_scale=args.attention_scale) + if args.freeze_transformer: + model.freeze_transformer( + tune_prefix_layers=args.tune_prefix_layers) + if model_type is not None: + if model_type == 'multiple_choice': + if args.cloze_eval: + if multi_token: + if args.fast_decode: + model = GLMForMultiTokenClozeFast( + model, length_penalty=args.length_penalty) + else: + model = GLMForMultiTokenCloze( + model, length_penalty=args.length_penalty) + else: + model = GLMForSingleTokenCloze( + model, take_softmax=args.adapet) + else: + model = GLMForSequenceClassification( + model, + args.hidden_size, + args.output_dropout, + args.pool_token, + num_class=num_labels) + elif model_type == 'classification': + model = GLMForSequenceClassification( + model, + args.hidden_size, + args.output_dropout, + args.pool_token, + num_class=num_labels) + elif model_type == 'generation': + pass + else: + raise NotImplementedError(model_type) + + if mpu.get_data_parallel_rank() == 0: + print( + ' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()])), + flush=True) + + # To prevent OOM for model sizes that cannot fit in GPU memory in full precision + if args.fp16: + model.half() + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if args.fp16: + model = FP16_Module(model) + + # Wrap model for distributed training. + if not args.deepspeed and (args.train_iters or args.epochs): + if args.DDP_impl == 'torch': + i = torch.cuda.current_device() + model = TorchDDP( + model, + device_ids=[i], + output_device=i, + process_group=mpu.get_data_parallel_group()) + elif args.DDP_impl == 'local': + model = LocalDDP(model) + else: + print_rank_0('Skip DDP model') + return model + + +def get_optimizer_param_groups(model): + # Build parameter groups (weight decay and non-decay). + while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)): + model = model.module + param_groups = glm_get_params_for_weight_decay_optimization(model) + + # Add model parallel attribute if it is not set. + for param_group in param_groups: + # print('## param_group', len(param_group['params'])) + for param in param_group['params']: + if not hasattr(param, 'model_parallel'): + param.model_parallel = False + + return param_groups + + +def get_optimizer(param_groups, args): + """Set up the optimizer.""" + if args.cpu_optimizer: + # Apex FusedAdam uses decoupled weight decay so use the same here + if args.cpu_torch_adam: + cpu_adam_optimizer = torch.optim.AdamW + else: + from deepspeed.ops.adam import DeepSpeedCPUAdam + cpu_adam_optimizer = DeepSpeedCPUAdam + optimizer = cpu_adam_optimizer( + param_groups, lr=args.lr, weight_decay=args.weight_decay) + else: + # Use FusedAdam. + if args.optimizer == 'adam': + optimizer = Adam( + param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps) + elif args.optimizer == 'adafactor': + from transformers import Adafactor + optimizer = Adafactor( + param_groups, + lr=args.lr, + relative_step=False, + warmup_init=False) + else: + raise NotImplementedError + + print(f'Optimizer = {optimizer.__class__.__name__}') + if hasattr(args, 'deepspeed') and args.deepspeed: + raise NotImplementedError + # fp16 wrapper is not required for DeepSpeed. + # return optimizer + + # Wrap into fp16 optimizer. + if args.fp16: + optimizer = FP16_Optimizer( + optimizer, + static_loss_scale=args.loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={ + 'scale_window': args.loss_scale_window, + 'min_scale': args.min_scale, + 'delayed_shift': args.hysteresis + }) + + return optimizer + + +def get_learning_rate_scheduler(optimizer, args): + """Build the learning rate scheduler.""" + + # Add linear learning rate scheduler. + if args.lr_decay_iters is not None: + num_iters = args.lr_decay_iters + else: + num_iters = args.train_iters + if args.finetune: + num_iters = num_iters // args.gradient_accumulation_steps + num_iters = max(1, num_iters) + init_step = -1 + warmup_iter = args.warmup * num_iters + lr_scheduler = AnnealingLR( + optimizer, + start_lr=args.lr, + warmup_iter=warmup_iter, + num_iters=num_iters - warmup_iter, + decay_style=args.lr_decay_style, + last_iter=init_step, + decay_ratio=args.lr_decay_ratio) + + return lr_scheduler + + +def setup_model_and_optimizer(args, + model_type=None, + multi_token=True, + num_labels=None, + spell_length=None): + """Setup model and optimizer.""" + + model = get_model( + args, + model_type=model_type, + multi_token=multi_token, + num_labels=num_labels, + spell_length=spell_length) + param_groups = get_optimizer_param_groups(model) + + if args.train_data is not None or args.data_dir is not None and ( + args.epochs > 0 or args.train_iters > 0): + if args.deepspeed: + print_rank_0('DeepSpeed is enabled.') + + model, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=param_groups, + args=args, + mpu=mpu, + dist_init_required=False) + else: + optimizer = get_optimizer(param_groups, args) + lr_scheduler = get_learning_rate_scheduler(optimizer, args) + else: + optimizer, lr_scheduler = None, None + + return model, optimizer, lr_scheduler + + +def backward_step(optimizer, model, lm_loss, args, timers): + """Backward step.""" + + # Total loss. + loss = lm_loss + + # Backward pass. + if args.deepspeed: + model.backward(loss) + else: + # optimizer.zero_grad() + if args.fp16: + optimizer.backward(loss, update_master_grads=False) + else: + loss.backward() + + if args.deepspeed or args.DDP_impl == 'torch': + # DeepSpeed backward propagation already addressed all reduce communication. + # Reset the timer to avoid breaking timer logs below. + timers('allreduce').reset() + else: + timers('allreduce').start() + model.allreduce_params( + reduce_after=False, fp32_allreduce=args.fp32_allreduce) + timers('allreduce').stop() + + # Update master gradients. + if not args.deepspeed: + if args.fp16: + optimizer.update_master_grads() + + # Clipping gradients helps prevent the exploding gradient. + if args.clip_grad > 0: + if not args.fp16: + mpu.clip_grad_norm(model.parameters(), args.clip_grad) + else: + optimizer.clip_master_grads(args.clip_grad) + + return lm_loss + + +def see_memory_usage(message, force=False): + if not force: + return + dist.barrier() + if dist.get_rank() == 0: + print(message) + print('Memory Allocated ', + torch.cuda.memory_allocated() / (1024 * 1024 * 1024), + 'GigaBytes') + print('Max Memory Allocated ', + torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), + 'GigaBytes') + print('Cache Allocated ', + torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') + print('Max cache Allocated ', + torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), + 'GigaBytes') + print(' ') + # input("Press Any Key To Continue ..") + + +def train_step(data_iterator, + model, + optimizer, + lr_scheduler, + args, + timers, + forward_step_func, + mems=None, + single_step=False): + """Single training step.""" + lm_loss_total, count = 0.0, 0 + mems = [] if mems is None else mems + if not args.deepspeed: + optimizer.zero_grad() + while True: + skipped_iter, complete = 0, False + # Forward model for one step. + timers('forward').start() + lm_loss, mems, _ = forward_step_func(data_iterator, model, args, + timers, mems) + timers('forward').stop() + # print_rank_0("Forward step") + if not args.deepspeed: + lm_loss /= args.gradient_accumulation_steps + + reduced_loss = lm_loss.detach().clone().view(1) + torch.distributed.all_reduce( + reduced_loss.data, group=mpu.get_data_parallel_group()) + reduced_loss.data = reduced_loss.data / ( + args.world_size / args.model_parallel_size) + + if not DynamicLossScaler._has_inf_or_nan(reduced_loss): + lm_loss_total += reduced_loss + count += 1 + + # Calculate gradients, reduce across processes, and clip. + timers('backward').start() + backward_step(optimizer, model, lm_loss, args, timers) + timers('backward').stop() + # print_rank_0("Backward step") + # Update parameters. + timers('optimizer').start() + if args.deepspeed: + if model.is_gradient_accumulation_boundary(): + model.step() + complete = True + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() + else: + skipped_iter = 1 + else: + model.step() + else: + if count == args.gradient_accumulation_steps: + optimizer.step() + complete = True + # Update learning rate. + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() + else: + skipped_iter = 1 + # print_rank_0("Optimizer step") + timers('optimizer').stop() + if complete: + break + else: + print_rank_0('Found NaN loss, skip backward') + del lm_loss, reduced_loss + mems = [] + if single_step: + break + if args.deepspeed: + lm_loss_total = lm_loss_total / count + return lm_loss_total, skipped_iter, mems diff --git a/modelscope/models/nlp/mglm/utils.py b/modelscope/models/nlp/mglm/utils.py new file mode 100644 index 00000000..2bfcf8c0 --- /dev/null +++ b/modelscope/models/nlp/mglm/utils.py @@ -0,0 +1,529 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for logging and serialization""" + +import os +import random +import subprocess +import time + +import json +import numpy as np +import torch + +from . import mpu +from .fp16 import FP16_Optimizer + +SUMMARY_WRITER_DIR_NAME = 'runs' + + +def get_log_dir(name, base): + return os.path.join(base, SUMMARY_WRITER_DIR_NAME, name) + + +def print_rank_0(message): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def get_hostname(): + hostname_cmd = ['hostname -I'] + result = subprocess.check_output(hostname_cmd, shell=True) + master_addr = result.decode('utf-8').split()[0] + return master_addr + + +def get_spare_port(args): + if torch.distributed.get_rank() == 0: + port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], + shell=True) + port = int(port.strip()) + if port == args.master_port: + port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], + shell=True) + port = int(port.strip()) + port = torch.cuda.LongTensor([port]) + else: + port = torch.cuda.LongTensor([0]) + torch.distributed.broadcast(port, 0) + port = port.item() + return port + + +def print_and_save_args(args, verbose=True, log_dir=None): + """Print arguments.""" + if verbose: + print('arguments:', flush=True) + for arg in vars(args): + dots = '.' * (29 - len(arg)) + print( + ' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) + if log_dir is not None: + json_file = os.path.join(log_dir, 'config.json') + with open(json_file, 'w') as output: + json.dump(vars(args), output, sort_keys=True) + if args.deepspeed and args.deepspeed_config is not None: + with open(args.deepspeed_config) as file: + deepspeed_config = json.load(file) + deepspeed_json_file = os.path.join(log_dir, + 'config_gpt_large.json') + with open(deepspeed_json_file, 'w') as output: + json.dump(deepspeed_config, output) + + +def print_params_min_max_norm(optimizer, iteration): + """Print min, max, and norm of all parameters.""" + index = 0 + rank = torch.distributed.get_rank() + string = 'iteration, rank, index, model-parallel,min, max, norm\n' + optimizer_ = optimizer + if isinstance(optimizer, FP16_Optimizer): + optimizer_ = optimizer.optimizer + for param_group in optimizer_.param_groups: + for param in param_group['params']: + index += 1 + min_ = param.data.min() + max_ = param.data.max() + norm = param.data.norm() + string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( + iteration, rank, index, int(param.model_parallel)) + string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) + print(string, flush=True) + + +class Timers: + """Group of timers.""" + + class Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + assert not self.started_, 'timer has already been started' + torch.cuda.synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + torch.cuda.synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = self.Timer(name) + return self.timers[name] + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0 / normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + print_rank_0(string) + + +def report_memory(name): + """Simple GPU memory report.""" + + mega_bytes = 1024.0 * 1024.0 + string = name + ' memory (MB)' + string += ' | allocated: {}'.format(torch.cuda.memory_allocated() + / mega_bytes) + string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() + / mega_bytes) + string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) + string += ' | max cached: {}'.format(torch.cuda.memory_reserved() + / mega_bytes) + print_rank_0(string) + + +def get_checkpoint_name(checkpoints_path, + iteration, + release=False, + zero=False): + if release: + d = 'release' + else: + d = '{}'.format(iteration) + if zero: + dp_rank = mpu.get_data_parallel_rank() + d += '_zero_dp_rank_{}'.format(dp_rank) + return os.path.join( + checkpoints_path, d, + 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) + + +def ensure_directory_exists(filename): + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + + +def get_checkpoint_tracker_filename(checkpoints_path): + return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') + + +def save_zero_checkpoint(args, iteration, optimizer): + zero_sd = { + 'iteration': iteration, + 'optimizer_state_dict': optimizer.state_dict() + } + zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) + ensure_directory_exists(zero_checkpoint_name) + torch.save(zero_sd, zero_checkpoint_name) + print(' successfully saved {}'.format(zero_checkpoint_name)) + + +def save_checkpoint(iteration, + model, + optimizer, + lr_scheduler, + args, + tag=None, + barrier=True, + only_changed_parameters=False, + no_deepspeed=False, + no_save_optim=False): + """Save a model checkpoint.""" + if tag is None: + tag = str(iteration) + if args.deepspeed and not no_deepspeed: + save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag) + else: + # Only rank zer0 of the data parallel writes to the disk. + + if mpu.get_data_parallel_rank() == 0: + checkpoint_name = get_checkpoint_name(args.save, tag) + print( + 'global rank {} is saving checkpoint at iteration {:7d} to {}'. + format(torch.distributed.get_rank(), iteration, + checkpoint_name)) + sd = {'iteration': iteration} + if args.deepspeed: + model = model.module + state_dict = model.state_dict() + if only_changed_parameters: + requires_grad_dict = {} + for name, parameter in model.named_parameters(): + requires_grad_dict[name] = parameter.requires_grad + state_dict = { + key: value + for key, value in state_dict.items() + if requires_grad_dict[key] + } + sd['module'] = state_dict + + # Optimizer stuff. + if not args.no_save_optim and not no_save_optim: + if optimizer is not None: + sd['optimizer'] = optimizer.state_dict() + if lr_scheduler is not None: + sd['lr_scheduler'] = lr_scheduler.state_dict() + + # rng states. + if not args.no_save_rng: + sd['random_rng_state'] = random.getstate() + sd['np_rng_state'] = np.random.get_state() + sd['torch_rng_state'] = torch.get_rng_state() + sd['cuda_rng_state'] = torch.cuda.get_rng_state() + sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker( + ).get_states() + + ensure_directory_exists(checkpoint_name) + torch.save(sd, checkpoint_name) + print(' successfully saved {}'.format(checkpoint_name)) + + # Wait so everyone is done (necessary) + if barrier: + torch.distributed.barrier() + # And update the latest iteration + if torch.distributed.get_rank() == 0: + tracker_filename = get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(tag) + + +def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag): + """Save a model checkpoint.""" + + sd = {} + sd['iteration'] = iteration + if lr_scheduler is not None: + sd['client_lr_scheduler'] = lr_scheduler.state_dict() + # rng states. + if not args.no_save_rng: + sd['random_rng_state'] = random.getstate() + sd['np_rng_state'] = np.random.get_state() + sd['torch_rng_state'] = torch.get_rng_state() + sd['cuda_rng_state'] = torch.cuda.get_rng_state() + sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + model.save_checkpoint(args.save, tag, client_state=sd) + + +def get_checkpoint_iteration(load_path): + # Read the tracker file and set the iteration. + tracker_filename = get_checkpoint_tracker_filename(load_path) + if not os.path.isfile(tracker_filename): + print_rank_0('WARNING: could not find the metadata file {} '.format( + tracker_filename)) + if os.path.isdir(load_path): + path = os.path.normpath(load_path) + load_dir, tag = os.path.split(path) + print_rank_0( + 'Try to directly load the checkpoint from the directory') + return load_dir, tag, False, True + print_rank_0(' will not load any checkpoints and will start from ' + 'random') + return load_path, 0, False, False + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + release = metastring == 'release' + # try: + # iteration = int(metastring) + # except ValueError: + # release = metastring == 'release' + # if not release: + # print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( + # tracker_filename)) + # exit() + + # assert iteration > 0 or release, 'error parsing metadata file {}'.format( + # tracker_filename) + + return load_path, metastring, release, True + + +def load_checkpoint(model, + optimizer, + lr_scheduler, + args, + no_deepspeed=False, + no_load_optim=False): + """Load a model checkpoint.""" + + load_dir, tag, release, success = get_checkpoint_iteration(args.load) + + if not success: + return 0 + + if args.deepspeed and not no_deepspeed: + + checkpoint_name, sd = model.load_checkpoint( + load_dir, + tag, + load_optimizer_states=not args.no_load_optim and not no_load_optim, + load_lr_scheduler_states=not args.no_load_lr_scheduler) + if not args.no_load_lr_scheduler and 'client_lr_scheduler' in sd: + lr_scheduler.load_state_dict(sd['client_lr_scheduler']) + print_rank_0('Load lr scheduler state') + if checkpoint_name is None: + if mpu.get_data_parallel_rank() == 0: + print('Unable to load checkpoint.') + return tag + + else: + + # Checkpoint. + checkpoint_name = get_checkpoint_name(load_dir, tag, release) + + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + # Load the checkpoint. + sd = torch.load(checkpoint_name, map_location='cpu') + + # Model. + if args.deepspeed: + model = model.module + missing_keys, unexpected_keys = model.load_state_dict( + sd['module'], strict=False) + if missing_keys or unexpected_keys: + print_rank_0( + f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}' + ) + + # Optimizer. + if not release and not args.finetune and not args.no_load_optim and not no_load_optim: + try: + if optimizer is not None: + optimizer.load_state_dict(sd['optimizer']) + if lr_scheduler is not None: + lr_scheduler.load_state_dict(sd['lr_scheduler']) + except KeyError: + print_rank_0( + 'Unable to load optimizer from checkpoint {}, exiting. ' + 'Specify --no-load-optim or --finetune to prevent ' + 'attempting to load the optimizer ' + 'state.'.format(checkpoint_name)) + + # Iterations. + if args.finetune or release: + iteration = 0 + else: + try: + iteration = sd['iteration'] + except KeyError: + try: # Backward compatible with older checkpoints + iteration = sd['total_iters'] + except KeyError: + print_rank_0( + 'A metadata file exists but Unable to load iteration ' + ' from checkpoint {}, starting from 0 iteration'.format( + checkpoint_name)) + iteration = 0 + + # rng states. + if not release and not args.finetune and not args.no_load_rng: + try: + random.setstate(sd['random_rng_state']) + np.random.set_state(sd['np_rng_state']) + torch.set_rng_state(sd['torch_rng_state']) + torch.cuda.set_rng_state(sd['cuda_rng_state']) + mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) + except KeyError: + print_rank_0( + 'Unable to load random state from checkpoint {}, exiting. ' + 'Specify --no-load-rng or --finetune to prevent ' + 'attempting to load the random ' + 'state.'.format(checkpoint_name)) + + if mpu.get_data_parallel_rank() == 0: + print(' successfully loaded {}'.format(checkpoint_name)) + + return iteration + + +def load_weights(src, dst, dst2src=False): + """ + Loads weights from src to dst via in place copy. + src is a huggingface gpt2model, while dst is one of our models. + dst2src=True loads parameters from our models into huggingface's. + ^dst2src is still untested + """ + conv_layer = 'Conv1D' in str(type(src)) + for n, p in src.named_parameters(): + if dst2src: + data = dst._parameters[n].data + load = p.data + else: + data = p.data + load = dst._parameters[n].data + if conv_layer and 'weight' in n: + data = data.t().contiguous() + load.copy_(data) + + +# dst._parameters[n].data.copy_(data) + + +def load_mlp(our, oai, dst2src=False): + load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) + load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) + + +def load_attention(our, oai, dst2src=False): + load_weights(oai.c_attn, our.query_key_value, dst2src) + load_weights(oai.c_proj, our.dense, dst2src) + + +def load_transformer_layer(our, oai, dst2src=False): + load_weights(oai.ln_1, our.input_layernorm, dst2src) + load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) + load_mlp(our.mlp, oai.mlp, dst2src) + load_attention(our.attention, oai.attn, dst2src) + + +def move_weights(our, oai, dst2src=False): + """ + Loads weights from `oai` to `our` via in place copy. + `oai` is a huggingface gpt2model, while `our` is one of our models. + dst2src=True loads parameters from our models into huggingface's. + ^dst2src=True is still untested + """ + # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): + # our=our.module + transformer_model = oai.transformer + load_weights(transformer_model.ln_f, our.transformer.final_layernorm, + dst2src) + load_weights(transformer_model.wte, our.word_embeddings, dst2src) + load_weights(transformer_model.wpe, our.position_embeddings, dst2src) + + for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): + load_transformer_layer(our_layer, oai_layer, dst2src) + + +def debug_finetune_data(local_vars, batch_id, tokenizer): + tokens, target_ids = local_vars['tokens'], local_vars['target_ids'] + attention_mask, logit_mask, position_ids = local_vars[ + 'attention_mask'], local_vars['logit_mask'], local_vars['position_ids'] + output_tokens = [] + sep = attention_mask[batch_id].item() + for i, token in enumerate(tokens[batch_id][:sep].tolist()): + token = tokenizer.IdToToken(token) + if token == '[MASK]': + token = f'[{position_ids[batch_id][0, i].item()}]' + output_tokens.append(token) + print(' '.join(output_tokens)) + target_positions = [] + for i in range(sep, tokens.size(-1)): + if logit_mask[batch_id][i]: + target_positions.append(i) + print(target_positions) + print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist())) + if len(target_ids.shape) > 2: + print( + tokenizer.DecodeIds( + target_ids[batch_id][target_positions].tolist())) + else: + print(tokenizer.DecodeIds(target_ids[batch_id].tolist())) + print(position_ids[batch_id][:, target_positions]) diff --git a/modelscope/models/nlp/palm_v2/__init__.py b/modelscope/models/nlp/palm_v2/__init__.py new file mode 100644 index 00000000..45ab6621 --- /dev/null +++ b/modelscope/models/nlp/palm_v2/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import PalmConfig + from .backbone import ( + AbsSummarizer, + PalmForConditionalGeneration, + Translator, + ) + from .text_generation import PalmForTextGeneration +else: + _import_structure = { + 'configuration': ['PalmConfig'], + 'backbone': + ['AbsSummarizer', 'PalmForConditionalGeneration', 'Translator'], + 'text_generation': ['PalmForTextGeneration'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/palm_v2/backbone.py b/modelscope/models/nlp/palm_v2/backbone.py new file mode 100644 index 00000000..afee2e3f --- /dev/null +++ b/modelscope/models/nlp/palm_v2/backbone.py @@ -0,0 +1,1327 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import codecs +import copy +import math +import os +import subprocess +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import addict +import json +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.init import xavier_uniform_ +from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig, + RobertaModel, RobertaTokenizer) +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel + +from modelscope.utils import logger as logging +from .configuration import PalmConfig +from .dureader_eval import compute_bleu_rouge, normalize + +CONFIG_NAME = 'config.json' +WEIGHTS_NAME = 'pytorch_model.bin' + + +class MultiHeadedAttention(nn.Module): # SelfAttention + """ + Multi-Head Attention module from + "Attention is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. + + Similar to standard `dot` attention but uses + multiple attention distributions simulataneously + to select relevant items. + + .. mermaid:: + + graph BT + A[key] + B[value] + C[query] + O[output] + subgraph Attn + D[Attn 1] + E[Attn 2] + F[Attn N] + end + A --> D + C --> D + A --> E + C --> E + A --> F + C --> F + D --> O + E --> O + F --> O + B --> O + + Also includes several additional tricks. + + Args: + head_count (int): number of parallel heads + model_dim (int): the dimension of keys/values/queries, + must be divisible by head_count + dropout (float): dropout parameter + """ + + def __init__(self, + head_count, + model_dim, + dropout=0.1, + use_final_linear=True): + assert model_dim % head_count == 0 + self.dim_per_head = model_dim // head_count + self.model_dim = model_dim + + super().__init__() + self.head_count = head_count + + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values = nn.Linear(model_dim, + head_count * self.dim_per_head) + self.linear_query = nn.Linear(model_dim, + head_count * self.dim_per_head) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.use_final_linear = use_final_linear + if (self.use_final_linear): + self.final_linear = nn.Linear(model_dim, model_dim) + + def forward(self, + key, + value, + query, + mask=None, + layer_cache=None, + type=None, + predefined_graph_1=None, + return_attn=False): + """ + Compute the context vector and the attention vectors. + + Args: + key (`FloatTensor`): set of `key_len` + key vectors `[batch, key_len, dim]` + value (`FloatTensor`): set of `key_len` + value vectors `[batch, key_len, dim]` + query (`FloatTensor`): set of `query_len` + query vectors `[batch, query_len, dim]` + mask: binary mask indicating which keys have + non-zero attention `[batch, query_len, key_len]` + Returns: + (`FloatTensor`, `FloatTensor`) : + + * output context vectors `[batch, query_len, dim]` + * one of the attention vectors `[batch, query_len, key_len]` + """ + + batch_size = key.size(0) + dim_per_head = self.dim_per_head + head_count = self.head_count + + def shape(x): + """ projection """ + return x.view(batch_size, -1, head_count, dim_per_head) \ + .transpose(1, 2) + + def unshape(x): + """ compute context """ + return x.transpose(1, 2).contiguous() \ + .view(batch_size, -1, head_count * dim_per_head) + + # 1) Project key, value, and query. + if layer_cache is not None: + if type == 'self': + query, key, value = self.linear_query(query), self.linear_keys( + query), self.linear_values(query) + + key = shape(key) + value = shape(value) + + device = key.device + if layer_cache['self_keys'] is not None: + key = torch.cat((layer_cache['self_keys'].to(device), key), + dim=2) + if layer_cache['self_values'] is not None: + value = torch.cat( + (layer_cache['self_values'].to(device), value), dim=2) + layer_cache['self_keys'] = key + layer_cache['self_values'] = value + elif type == 'context': + query = self.linear_query(query) + if layer_cache['memory_keys'] is None: + key, value = self.linear_keys(key), self.linear_values( + value) + key = shape(key) + value = shape(value) + else: + key, value = layer_cache['memory_keys'], layer_cache[ + 'memory_values'] + layer_cache['memory_keys'] = key + layer_cache['memory_values'] = value + else: + key = self.linear_keys(key) + value = self.linear_values(value) + query = self.linear_query(query) + key = shape(key) + value = shape(value) + + query = shape(query) + + # 2) Calculate and scale scores. + query = query / math.sqrt(dim_per_head) + scores = torch.matmul(query, key.transpose(2, 3)) + + if mask is not None: + mask = mask.unsqueeze(1).expand_as(scores) + scores = scores.masked_fill(mask, -1e18) + + # 3) Apply attention dropout and compute context vectors. + + attn = self.softmax(scores) + + if predefined_graph_1 is not None: + attn_masked = attn[:, -1] * predefined_graph_1 + attn_masked = attn_masked / ( + torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) + + attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) + + drop_attn = self.dropout(attn) + if self.use_final_linear: + context = unshape(torch.matmul(drop_attn, value)) + output = self.final_linear(context) + if return_attn: + return output, attn + else: + return output + else: + context = torch.matmul(drop_attn, value) + if return_attn: + return context, attn + else: + return context + + +class PositionwiseFeedForward(nn.Module): # Output + """ A two-layer Feed-Forward-Network with residual layer norm. + + Args: + d_model (int): the size of input for the first-layer of the FFN. + d_ff (int): the hidden layer size of the second-layer + of the FNN. + dropout (float): dropout probability in :math:`[0, 1)`. + """ + + def __init__(self, d_model, d_ff, dropout=0.1): + super().__init__() + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.w_1 = nn.Linear(d_model, d_ff) + self.actv = ACT2FN['gelu_new'] + self.dropout_1 = nn.Dropout(dropout) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, x): + inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) + output = self.dropout_2(self.w_2(inter)) + return output + x + + +class TransformerDecoderLayer(nn.Module): # Layer + """ + Args: + d_model (int): the dimension of keys/values/queries in + MultiHeadedAttention, also the input size of + the first-layer of the PositionwiseFeedForward. + heads (int): the number of heads for MultiHeadedAttention. + d_ff (int): the second-layer of the PositionwiseFeedForward. + dropout (float): dropout probability(0-1.0). + self_attn_type (string): type of self-attention scaled-dot, average + """ + MAX_SIZE = 5000 + + def __init__(self, d_model, heads, d_ff, dropout): + super().__init__() + + self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) + + self.context_attn = MultiHeadedAttention( + heads, d_model, dropout=dropout) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) + self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) + self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) + self.drop = nn.Dropout(dropout) + mask = self._get_attn_subsequent_mask(self.MAX_SIZE) + # Register self.mask as a buffer in TransformerDecoderLayer, so + # it gets TransformerDecoderLayer's cuda behavior automatically. + self.register_buffer('mask', mask) + + def forward(self, + inputs, + memory_bank, + src_pad_mask, + tgt_pad_mask, + previous_input=None, + layer_cache=None, + step=None): + """ + Args: + inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` + memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` + src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` + tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` + + Returns: + (`FloatTensor`, `FloatTensor`, `FloatTensor`): + + * output `[batch_size x 1 x model_dim]` + * attn `[batch_size x 1 x src_len]` + * all_input `[batch_size x current_step x model_dim]` + + """ + dec_mask = torch.gt( + tgt_pad_mask.type(torch.uint8) + + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)].type( + torch.uint8), 0) + input_norm = self.layer_norm_1(inputs) + all_input = input_norm + if previous_input is not None: + all_input = torch.cat((previous_input, input_norm), dim=1) + dec_mask = None + + query = self.self_attn( + all_input, + all_input, + input_norm, + mask=dec_mask, + layer_cache=layer_cache, + type='self') + + query = self.drop(query) + inputs + + query_norm = self.layer_norm_2(query) + mid, attn = self.context_attn( + memory_bank, + memory_bank, + query_norm, + mask=src_pad_mask, + layer_cache=layer_cache, + type='context', + return_attn=True) + output = self.feed_forward(self.drop(mid) + query) + + return output, attn, all_input + + def _get_attn_subsequent_mask(self, size): + """ + Get an attention mask to avoid using the subsequent info. + + Args: + size: int + + Returns: + (`LongTensor`): + + * subsequent_mask `[1 x size x size]` + """ + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + subsequent_mask = torch.from_numpy(subsequent_mask) + return subsequent_mask + + +class PositionalEncoding(nn.Module): + + def __init__(self, dropout, dim, max_len=5000): + super().__init__() + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) + * -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + self.dropout = nn.Dropout(dropout) + self.dim = dim + + def forward(self, emb, step=None): + emb = emb * math.sqrt(self.dim) + if (step): + emb = emb + self.pe[:, step][:, None, :] + + else: + emb = emb + self.pe[:, :emb.size(1)] + emb = self.dropout(emb) + return emb + + def get_emb(self, emb): + return self.pe[:, :emb.size(1)] + + +class TransformerDecoderState: + + def __init__(self, src: Tensor, cache_num_layers: int = -1): + self.src: Tensor = src + self.previous_input: Tensor = None + self.previous_layer_inputs: Tensor = None + self.cache: Optional[Dict[str, Any]] = None + if cache_num_layers != -1: + self._init_cache(cache_num_layers) + + def update_state(self, new_input, previous_layer_inputs): + self.previous_input = new_input + self.previous_layer_inputs = previous_layer_inputs + self.cache = None + + def _init_cache(self, num_layers): + self.cache = {} + for num in range(num_layers): + layer_cache = {'memory_keys': None, 'memory_values': None} + layer_cache['self_keys'] = None + layer_cache['self_values'] = None + self.cache['layer_{}'.format(num)] = layer_cache + + def map_batch_fn(self, fn): + + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + self.src = fn(self.src, 0) + if self.cache is not None: + _recursive_map(self.cache) + + +class TransformerDecoder(nn.Module): # Decoder + """ + The Transformer decoder from "Attention is All You Need". + + + .. mermaid:: + + graph BT + A[input] + B[multi-head self-attn] + BB[multi-head src-attn] + C[feed forward] + O[output] + A --> B + B --> BB + BB --> C + C --> O + + + Args: + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + dropout (float): dropout parameters + embeddings (:obj:`onmt.modules.Embeddings`): + embeddings to use, should have positional encodings + attn_type (str): if using a seperate copy attention + """ + decoder_type = 'transformer' + + def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings): + super().__init__() + + # Basic attributes. + self.num_layers = num_layers + self.embeddings = embeddings + self.pos_emb = PositionalEncoding(dropout, + self.embeddings.embedding_dim) + + # Build TransformerDecoder. + self.transformer_layers = nn.ModuleList([ + TransformerDecoderLayer(d_model, heads, d_ff, dropout) + for _ in range(num_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.state = None + + def forward(self, + state: TransformerDecoderState, + tgt: Tensor, + memory_bank: Tensor, + step: int = None, + memory_masks: Tensor = None): + src_words = state.src + tgt_words = tgt + src_batch, src_len = src_words.size() + tgt_batch, tgt_len = tgt_words.size() + + # Run the forward pass of the TransformerDecoder. + # emb = self.embeddings(tgt, step=step) + emb = self.embeddings(tgt) + assert emb.dim() == 3 # len x batch x embedding_dim + output = self.pos_emb(emb, step) + + src_memory_bank = memory_bank + padding_idx = self.embeddings.padding_idx + tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \ + .expand(tgt_batch, tgt_len, tgt_len) + + if memory_masks is not None: + src_len = memory_masks.size(-1) + src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len) + else: + src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ + .expand(src_batch, tgt_len, src_len) + + if state.cache is None: + saved_inputs = [] + attns = [] + for i in range(self.num_layers): + prev_layer_input = None + if state.cache is None: + if state.previous_input is not None: + prev_layer_input = state.previous_layer_inputs[i] + output, attn, all_input \ + = self.transformer_layers[i]( + output, src_memory_bank, + src_pad_mask, tgt_pad_mask, + previous_input=prev_layer_input, + layer_cache=state.cache['layer_{}'.format(i)] + if state.cache is not None else None, + step=step) + if state.cache is None: + saved_inputs.append(all_input) + attns.append(attn) + + if state.cache is None: + saved_inputs = torch.stack(saved_inputs) + + output = self.layer_norm(output) + + # Process the result and update the attentions. + if state.cache is None: + state.update_state(tgt, saved_inputs) + + return output, attns, state + + +class PalmPointerGenerator(nn.Module): + + def __init__(self, hidden_size, vocab_size): + super().__init__() + self.dense = nn.Linear(hidden_size, vocab_size) + self.gen_func = nn.LogSoftmax(-1) + + def forward(self, x): + x = self.dense(x) + x = self.gen_func(x) + return x + + +class PalmPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PalmConfig + base_model_prefix = 'palm' + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, + os.PathLike]], + **kwargs): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + config = PalmConfig.from_json_file(config_file) if os.path.isfile( + config_file) else PalmConfig() + config.encoder_pth = os.path.join(pretrained_model_name_or_path, + config.encoder_pth) + checkpoint_file = os.path.join(pretrained_model_name_or_path, + WEIGHTS_NAME) + checkpoint = torch.load(checkpoint_file) if os.path.isfile( + checkpoint_file) else None + return cls(config, checkpoint, **kwargs) + + +class AbsSummarizer(PalmPreTrainedModel): # Model + + def __init__(self, config, checkpoint=None): + super().__init__(config) + self.config = config + if config.encoder == 'bert' or config.encoder == 'zh_bert': + self.bert = BertModel( + BertConfig.from_pretrained(config.encoder_pth)) + elif config.encoder == 'roberta': + self.bert = RobertaModel( + RobertaConfig.from_pretrained(config.encoder_pth)) + + if (config.max_pos > 512): + my_pos_embeddings = nn.Embedding( + config.max_pos, self.bert.model.config.hidden_size) + my_pos_embeddings.weight.data[: + 512] = self.bert.embeddings.position_embeddings.weight.data + my_pos_embeddings.weight.data[ + 512:] = self.bert.embeddings.position_embeddings.weight.data[ + -1][None, :].repeat(config.max_pos - 512, 1) + self.bert.model.embeddings.position_embeddings = my_pos_embeddings + self.vocab_size = self.bert.config.vocab_size + tgt_embeddings = nn.Embedding( + self.vocab_size, + self.bert.config.hidden_size, + padding_idx=1 if config.encoder == 'roberta' else 0) + + if config.share_emb: + tgt_embeddings.weight = copy.deepcopy( + self.bert.model.embeddings.word_embeddings.weight) + self.decoder = TransformerDecoder( + config.dec_layers, + config.dec_hidden_size, + heads=config.dec_heads, + d_ff=config.dec_ff_size, + dropout=config.dec_dropout, + embeddings=tgt_embeddings) + self.generator = PalmPointerGenerator(config.dec_hidden_size, + self.vocab_size) + self.generator.dense.weight = self.decoder.embeddings.weight + + if checkpoint is not None: + if 'model' in checkpoint: + checkpoint = checkpoint['model'] + for key in list(checkpoint.keys()): + checkpoint[key.replace('model.palm.', '')] = checkpoint[key] + self.load_state_dict(checkpoint, strict=False) + else: + for module in self.decoder.modules(): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + for p in self.generator.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + else: + p.data.zero_() + if config.use_bert_emb: + if config.encoder == 'roberta': + tgt_embeddings = nn.Embedding( + self.vocab_size, + self.bert.config.hidden_size, + padding_idx=1) + else: + tgt_embeddings = nn.Embedding( + self.vocab_size, + self.bert.config.hidden_size, + padding_idx=0) + tgt_embeddings.weight = copy.deepcopy( + self.bert.embeddings.word_embeddings.weight) + self.decoder.embeddings = tgt_embeddings + self.generator.dense.weight = self.decoder.embeddings.weight + + def forward(self, src, tgt, mask_src): + top_vec, _ = self.bert(src, mask_src, return_dict=False) + state = TransformerDecoderState(src) + decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec) + return decoder_outputs, attns[-1], top_vec + + +class LabelSmoothingLoss(nn.Module): + """ + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + """ + + def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): + assert 0.0 < label_smoothing <= 1.0 + self.padding_idx = ignore_index + super(LabelSmoothingLoss, self).__init__() + + smoothing_value = label_smoothing / (tgt_vocab_size - 2) + one_hot = torch.full((tgt_vocab_size, ), smoothing_value) + one_hot[self.padding_idx] = 0 + self.register_buffer('one_hot', one_hot.unsqueeze(0)) + self.confidence = 1.0 - label_smoothing + + def forward(self, output, target): + """ + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size + """ + model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob.scatter_(1, target.unsqueeze(1), self.confidence) + model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0) + + return F.kl_div(output, model_prob, reduction='sum') + + +class NMTLossCompute(nn.Module): + """ + Standard NMT Loss Computation. + """ + + def __init__(self, generator, symbols, vocab_size, label_smoothing=0.0): + super().__init__() + self.generator = generator + self.padding_idx = symbols['PAD'] + if label_smoothing > 0: + self.criterion = LabelSmoothingLoss( + label_smoothing, vocab_size, ignore_index=self.padding_idx) + else: + self.criterion = nn.NLLLoss( + ignore_index=self.padding_idx, reduction='sum') + + def _bottle(self, _v): + return _v.view(-1, _v.size(2)) + + def _unbottle(self, _v, batch_size): + return _v.view(-1, batch_size, _v.size(1)) + + def forward(self, tgt, output): + target = tgt[:, 1:] + normalization = target.ne(self.padding_idx).sum() + bottled_output = self._bottle(output) + scores = self.generator(bottled_output) + gtruth = target.contiguous().view(-1) + loss = self.criterion(scores, gtruth) + loss.div(float(normalization)) + return loss + + +class PalmForConditionalGeneration(PalmPreTrainedModel): + + def __init__(self, config, checkpoint=None): + super().__init__(config) + self.config = config + if config.encoder == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained( + config.encoder_pth, do_lower_case=False) + symbols = { + 'BOS': tokenizer.cls_token_id, + 'EOS': tokenizer.sep_token_id, + 'PAD': tokenizer.pad_token_id, + 'EOQ': tokenizer.unk_token_id + } + elif config.encoder == 'bert' or config.encoder == 'zh_bert': + tokenizer = BertTokenizer.from_pretrained( + config.encoder_pth, do_lower_case=True) + symbols = { + 'BOS': tokenizer.vocab['[CLS]'], + 'EOS': tokenizer.vocab['[SEP]'], + 'PAD': tokenizer.vocab['[PAD]'], + 'EOQ': tokenizer.vocab['[unused2]'] + } + self.tokenizer = tokenizer + self.symbols = symbols + self.palm = AbsSummarizer(config, checkpoint) + self.loss = NMTLossCompute(self.palm.generator, symbols, + self.palm.vocab_size, + config.label_smoothing) + + def forward(self, input_ids, attention_mask, labels): + output = self.palm( + src=input_ids, tgt=labels, mask_src=attention_mask)[0] + loss = self.loss(labels, output) + return addict.Dict(loss=loss) + + +class Translator(object): + """ + Uses a model to translate a batch of sentences. + """ + + @dataclass + class Batch: + batch_size: int + src: torch.Tensor + tgt: torch.Tensor + mask_src: torch.Tensor + query_id: List[None] = None + src_str: List[List[str]] = None + tgt_str: List[str] = None + + def __init__(self, + model: PalmForConditionalGeneration, + dataset: str = 'cnn'): + super().__init__() + self.logger = logging.get_logger(__name__) + self.args = model.config + self.args.dataset = dataset + self.model = model.palm + self.generator = self.model.generator + self.vocab = model.tokenizer + self.symbols = model.symbols + self.start_token = self.symbols['BOS'] + self.end_token = self.symbols['EOS'] + self.alpha = self.args.alpha + self.beam_size = self.args.beam_size + self.min_length = self.args.min_length + self.max_length = self.args.max_length + + def from_batch(self, translation_batch): + batch = translation_batch['batch'] + assert (len(translation_batch['gold_score']) == len( + translation_batch['predictions'])) + batch_size = batch.batch_size + + preds, pred_score, tgt_str, src, src_str = translation_batch[ + 'predictions'], translation_batch[ + 'scores'], batch.tgt_str, batch.src, batch.src_str + query_id = batch.query_id + ''' + try: + query_id = batch.query_id + except: + query_id = None + ''' + translations = [] + for b in range(batch_size): + if self.args.dataset == 'qg_ranking_test': + if self.args.encoder == 'bert' or self.args.encoder == 'zh_bert': + pred_sents = [ + ' '.join( + self.vocab.convert_ids_to_tokens( + [int(n) for n in each])).replace(' ##', '') + for each in preds[b] + ] + elif self.args.encoder == 'roberta': + pred_sents = [ + self.vocab.decode([int(n) for n in each + ]).replace('', + '').replace('', '') + for each in preds[b] + ] + elif self.args.encoder == 'roberta': + pred_sents = self.vocab.decode([int(n) + for n in preds[b][0]]).replace( + '', + '').replace('', '') + elif self.args.encoder == 'bert': + pred_sents = self.vocab.convert_ids_to_tokens( + [int(n) for n in preds[b][0]]) + pred_sents = ' '.join(pred_sents).replace(' ##', '') + elif self.args.encoder == 'zh_bert' and self.args.dataset == 'paraphrase': + pred_sents = [ + self.vocab.convert_ids_to_tokens([int(n) for n in pred]) + for pred in preds[b] + ] + pred_sents = [ + ''.join(pred).replace(' ##', '') for pred in pred_sents + ] + elif self.args.encoder == 'zh_bert': + pred_sents = self.vocab.convert_ids_to_tokens( + [int(n) for n in preds[b][0]]) + pred_sents = ''.join(pred_sents).replace('##', '') + gold_sent = tgt_str[b] + + if self.args.encoder == 'roberta': + raw_src = self.vocab.decode([int(t) for t in src[b]]) + raw_src = ' '.join(src_str[b]) + else: + raw_src = [self.vocab.ids_to_tokens[int(t)] + for t in src[b]][:500] + raw_src = ' '.join(raw_src) + if self.args.dataset == 'faq': + translation = (pred_sents, gold_sent, src_str[b], query_id[b], + pred_score[b]) + else: + translation = (pred_sents, gold_sent, raw_src, query_id[b], + pred_score[b]) + # translation = (pred_sents[0], gold_sent) + translations.append(translation) + + return translations + + def translate(self, data_iter, step): + gold_path = self.args.result_path + '.%d.gold' % step + can_path = self.args.result_path + '.%d.candidate' % step + self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8') + self.can_out_file = codecs.open(can_path, 'w', 'utf-8') + self.pred_json_score_out_file = codecs.open(can_path + '.sample', 'w', + 'utf-8') + if self.args.dataset == 'paraphrase' and self.args.encoder == 'roberta': + out = '\t'.join([ + 'query_id', 'source_query', 'target_query', 'predict_query' + ]) + '\n' + self.pred_json_score_out_file.write(out) + + raw_src_path = self.args.result_path + '.%d.raw_src' % step + self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8') + + pred_results, gold_results = [], [] + cnt = 0 + pred_dict, ref_dict = {}, {} + for i, batch in enumerate(data_iter): + self.logger.info(f'data: {i + 1} / {len(data_iter)}') + batch_data = self.translate_batch(batch) + translations = self.from_batch(batch_data) + + for trans in translations: + pred, gold, src, query_id, pred_score = trans + src = src.replace('', '').replace('##', '').strip() + if self.args.dataset == 'qg_ranking_test': + pred_str = '\t'.join([ + each.replace('[unused0]', '').replace( + '[PAD]', '').replace('[unused1]', '').replace( + r' +', ' ').replace('[SEP]', '').replace( + '[unused2]', + '').replace(r' +', ' ').replace( + '', + '').replace('', '').replace( + '', + '').replace('', '').replace( + '', ' ').strip() + for each in pred + ]) + else: + pred_str = pred.replace('[unused0]', '').replace( + '[PAD]', '').replace('[unused1]', '').replace( + r' +', ' ').replace('[SEP]', '').replace( + '[unused2]', '').replace('[CLS]', '').replace( + '[SEP]', '').replace('[UNK]', '').strip() + pred_str = pred_str.replace(r' +', ' ').replace( + '', + '').replace('', '').replace('', '').replace( + '', '').replace('', ' ').strip() + gold_str = gold.replace('', '').strip().replace( + '[UNK]', '').replace('[unused1]', '').replace( + '[unused2]', + '').replace('##', '').replace('[CLS]', '').replace( + '[SEP]', '').strip().replace('', '').replace( + '', '').replace('', ' ').strip() + if (self.args.recall_eval): + _pred_str = '' + for sent in pred_str.split(''): + can_pred_str = _pred_str + '' + sent.strip() + if len(can_pred_str.split()) >= len( + gold_str.split()) + 10: + pred_str = _pred_str + break + else: + _pred_str = can_pred_str + + if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking': + pred_str = pred_str.replace('', ' ') + if query_id is not None: + pred_json = { + 'query_id': query_id, + 'answers': [pred_str] + } + gold_json = { + 'query_id': query_id, + 'answers': [gold_str] + } + pred_json_score = { + 'query_id': query_id, + 'answers': [pred_str], + 'scores': pred_score[0].cpu().numpy().tolist() + } + else: + pred_json = {'query_id': cnt, 'answers': [pred_str]} + gold_json = {'query_id': cnt, 'answers': [gold_str]} + pred_json_score = { + 'query_id': cnt, + 'answers': [pred_str], + 'scores': pred_score[0].cpu().numpy().tolist() + } + json.dump(pred_json, self.can_out_file) + self.can_out_file.write('\n') + json.dump(gold_json, self.gold_out_file) + self.gold_out_file.write('\n') + json.dump(pred_json_score, self.pred_json_score_out_file) + self.pred_json_score_out_file.write('\n') + self.src_out_file.write(src.strip() + '\n') + elif self.args.dataset == 'cnn': + self.can_out_file.write(pred_str + '\n') + self.gold_out_file.write(gold_str + '\n') + self.src_out_file.write(src.strip() + '\n') + elif self.args.dataset == 'dureader': + if query_id is None: + query_id = str(cnt) + pred_results.extend(normalize([pred_str])) + gold_results.extend(normalize([gold_str])) + self.can_out_file.write(pred_str + '\n') + self.gold_out_file.write('\t'.join([src[0], gold_str]) + + '\n') + + elif self.args.dataset == 'paraphrase': + if query_id is None: + query_id = str(cnt) + if self.args.encoder == 'roberta': + pred_str = [pred_str] + pred_dict[query_id] = normalize([pred_str[0]]) + ref_dict[query_id] = normalize([gold_str]) + self.pred_json_score_out_file.write( + '\t'.join([str(query_id), src, gold_str, pred_str[0]]) + + '\n') + elif self.args.dataset == 'faq': + if pred_score[0].cpu().numpy().tolist() < -3.5: + continue + self.can_out_file.write( + '\t'.join([str(query_id), src, pred_str]) + '\n') + self.gold_out_file.write( + '\t'.join([str(query_id), src, gold_str]) + '\n') + # passage, answer, question, score + self.pred_json_score_out_file.write('\t'.join([ + str(query_id), gold_str, src, pred_str, + str(pred_score[0].cpu().numpy().tolist()) + ]) + '\n') + elif self.args.dataset == 'qg_ranking_test': + self.can_out_file.write( + str(query_id) + '\t' + pred_str + '\n') + + cnt += 1 + self.can_out_file.flush() + self.gold_out_file.flush() + self.src_out_file.flush() + self.logger.info('cnt: %s' % cnt) + self.can_out_file.close() + self.gold_out_file.close() + self.src_out_file.close() + + if (step != -1): + if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking': + cnn_results = subprocess.getoutput( + './run.sh %s %s' % (gold_path, can_path)) # run.sh ... + self.logger.info(cnn_results) + elif self.args.dataset == 'cnn': + self.logger.info('Calculating Rouge') + from rouge import Rouge + candidates = [ + line.strip() for line in open(can_path, encoding='utf-8') + ] + references = [ + line.strip() for line in open(gold_path, encoding='utf-8') + ] + rouge_score = Rouge().get_scores( + candidates, references, avg=True) + # self.logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges))) + print(rouge_score) + elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase': + + def postprocess_text(preds, labels): + preds = [pred.strip().replace('.', '') for pred in preds] + labels = [label.strip() for label in labels] + while '' in preds: + idx = preds.index('') + preds[idx] = '。' + return preds, labels + + pred_results, gold_results = postprocess_text( + pred_results, gold_results) + pred_dict = {str(i): tmp for i, tmp in enumerate(pred_results)} + gold_dict = {str(i): tmp for i, tmp in enumerate(gold_results)} + bleu_rouge = compute_bleu_rouge(pred_dict, gold_dict) + print(bleu_rouge) + # unreachable + elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase': + pred_results, gold_results = postprocess_text( + pred_results, gold_results) + bleu_score = cal_bleu(pred_results, gold_results) + from rouge import Rouge + rouge = Rouge() + rouge_score = rouge.get_scores( + pred_results, gold_results, avg=True) + print("'Dev eval result: Bleu-4={}, {}".format( + bleu_score, rouge_score)) + + def translate_batch(self, batch: 'Batch', fast: bool = False): + """ + Translate a batch of sentences. + + Mostly a wrapper around :obj:`Beam`. + + Args: + batch (:obj:`Batch`): a batch from a dataset object + data (:obj:`Dataset`): the dataset object + fast (bool): enables fast beam search (may not support all features) + + Todo: + Shouldn't need the original dataset. + """ + self.model.eval() + with torch.no_grad(): + return self._fast_translate_batch( + batch, self.max_length, min_length=self.min_length) + + def _tile(self, x, count, dim=0): + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x + + def _top_k_top_p_filtering(self, + logits, + top_k=10, + top_p=1.0, + filter_value=-float('Inf'), + min_tokens_to_keep=1): + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), + logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + def _fast_translate_batch(self, + batch: 'Batch', + max_length: int, + min_length: int = 0): + # TODO: faster code path for beam_size == 1. + # TODO: support these blacklisted features. + + beam_size = self.beam_size + batch_size = batch.batch_size + src = batch.src + mask_src = batch.mask_src + + src_features, _ = self.model.bert(src, mask_src, return_dict=False) + state = TransformerDecoderState(src, self.model.decoder.num_layers) + device = src_features.device + + # Tile states and memory beam_size times. + state.map_batch_fn( + lambda state, dim: self._tile(state, beam_size, dim=dim)) + src_features = self._tile(src_features, beam_size, dim=0) + batch_offset = torch.arange( + batch_size, dtype=torch.long, device=device) + beam_offset = torch.arange( + 0, + batch_size * beam_size, + step=beam_size, + dtype=torch.long, + device=device) + alive_seq = torch.full([batch_size * beam_size, 1], + self.start_token, + dtype=torch.long, + device=device) + + # Give full probability to the first beam on the first step. + topk_log_probs = ( + torch.tensor( + [0.0] + [float('-inf')] * (beam_size - 1), + device=device).repeat(batch_size)) + + # Structure that holds finished hypotheses. + hypotheses = [[] for _ in range(batch_size)] # noqa: F812 + + results = {} + results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812 + results['scores'] = [[] for _ in range(batch_size)] # noqa: F812 + results['gold_score'] = [0] * batch_size + results['batch'] = batch + + for step in range(max_length): + decoder_input = alive_seq[:, -1].view(1, -1) + + # Decoder forward. + decoder_input = decoder_input.transpose(0, 1) + dec_out, attns, state = self.model.decoder( + state, decoder_input, src_features, step=step) + + # Generator forward. + log_probs = self.generator.forward( + dec_out.transpose(0, 1).squeeze(0)) + vocab_size = log_probs.size(-1) + + if step < min_length: + log_probs[:, self.end_token] = -1e20 + + # Multiply probs by the beam probability. + + length_penalty = ((5.0 + (step + 1)) / 6.0)**self.alpha + if self.args.sample_topk: + temperature = self.args.temperature + _scores = log_probs / temperature + _scores = self._top_k_top_p_filtering( + _scores, + top_k=self.args.top_k, + top_p=self.args.top_p, + min_tokens_to_keep=1 + ) # (batch_size * num_beams, vocab_size) + # Sample 2 next words for each beam (so we have some spare tokens + # and match output of greedy beam search) + topk_ids = torch.multinomial( + F.softmax(_scores, dim=-1), + num_samples=1) # (batch_size * num_beams, 2) + # Compute next scores + _scores = F.log_softmax( + _scores, dim=1) # (batch_size * num_beams, vocab_size) + + _scores += topk_log_probs.view(-1).unsqueeze(1) + _scores = _scores / length_penalty + topk_scores = torch.gather( + _scores, -1, topk_ids) # (batch_size * num_beams, 2) + # Match shape of greedy beam search + topk_ids = topk_ids.view( + -1, beam_size) # (batch_size, 2 * num_beams) + topk_scores = topk_scores.view( + -1, beam_size) # (batch_size, 2 * num_beams) + else: + log_probs += topk_log_probs.view(-1).unsqueeze(1) + curr_scores = log_probs / length_penalty + + curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) + topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) + if self.args.block_trigram: + cur_len = alive_seq.size(1) + if cur_len > 3: + for i in range(alive_seq.size(0)): + fail = False + words = [int(w) for w in alive_seq[i]] + if self.args.encoder == 'roberta': + words = self.vocab.decode(words).strip().split() + else: + words = [ + self.vocab.ids_to_tokens[w] for w in words + ] + words = ' '.join(words).replace(' ##', '').split() + if len(words) <= 3: + continue + trigrams = [(words[i - 1], words[i], words[i + 1]) + for i in range(1, + len(words) - 1)] + trigram = tuple(trigrams[-1]) + if trigram in trigrams[:-1]: + fail = True + if fail: + curr_scores[i] = -10e20 + # Recover log probs. + topk_log_probs = topk_scores * length_penalty + + # Resolve beam origin and true word ids. + topk_beam_index = topk_ids // vocab_size + topk_ids = topk_ids.fmod(vocab_size) + + # Map beam_index to batch_index in the flat representation. + batch_index = ( + topk_beam_index + + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) + select_indices = batch_index.view(-1) + + # Append last prediction. + alive_seq = torch.cat([ + alive_seq.index_select(0, select_indices), + topk_ids.view(-1, 1) + ], -1) + + is_finished = topk_ids.eq(self.end_token) + if step + 1 == max_length: + is_finished.fill_(self.end_token) + # End condition is top beam is finished. + end_condition = is_finished[:, 0].eq(1) + # Save finished hypotheses. + if is_finished.any(): + predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) + for i in range(is_finished.size(0)): + b = batch_offset[i] + if end_condition[i]: + is_finished[i].fill_(self.end_token) + finished_hyp = is_finished[i].nonzero().view(-1) + # Store finished hypotheses for this batch. + for j in finished_hyp: + hypotheses[b].append( + (topk_scores[i, j], predictions[i, j, 1:])) + # If the batch reached the end, save the n_best hypotheses. + if end_condition[i]: + best_hyp = sorted( + hypotheses[b], key=lambda x: x[0], reverse=True) + if self.args.dataset == 'qg_ranking_test' or ( + self.args.dataset == 'paraphrase' + and not self.args.sample_topk): + for each in best_hyp[:beam_size]: + score, pred = each + results['scores'][b].append(score) + results['predictions'][b].append(pred) + else: + score, pred = best_hyp[0] + results['scores'][b].append(score) + results['predictions'][b].append(pred) + non_finished = end_condition.eq(0).nonzero().view(-1) + # If all sentences are translated, no need to go further. + if len(non_finished) == 0: + break + # Remove finished batches for the next step. + topk_log_probs = topk_log_probs.index_select(0, non_finished) + batch_index = batch_index.index_select(0, non_finished) + batch_offset = batch_offset.index_select(0, non_finished) + alive_seq = predictions.index_select(0, non_finished) \ + .view(-1, alive_seq.size(-1)) + # Reorder states. + select_indices = batch_index.view(-1) + src_features = src_features.index_select(0, select_indices) + state.map_batch_fn( + lambda state, dim: state.index_select(dim, select_indices)) + + return results + + def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, + **kwargs) -> Dict[str, torch.Tensor]: + batch = self.Batch( + batch_size=input_ids.size()[0], + src=input_ids, + tgt=None, + mask_src=attention_mask) + translation_batch = self.translate_batch(batch) + + preds = translation_batch['predictions'] + return {'predictions': preds} diff --git a/modelscope/models/nlp/palm_v2/configuration.py b/modelscope/models/nlp/palm_v2/configuration.py new file mode 100644 index 00000000..3b9e51fb --- /dev/null +++ b/modelscope/models/nlp/palm_v2/configuration.py @@ -0,0 +1,116 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PALM model configuration """ + +from transformers.configuration_utils import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class PalmConfig(PretrainedConfig): + r""" + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layernorm_epsilon (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + dec_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer decoder. + attn_separate (:obj:`bool`, `optional`, defaults to false): + Whether or not to separate the q, k, v of attention. + + Examples:: + + >>> from modelscope.models.nlp.palm_v2 import PalmForConditionalGeneration, PalmConfig + >>> configuration = PalmConfig() + + >>> # Initializing a model from the configuration + >>> model = PalmForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = 'palm' + + def __init__(self, + encoder='roberta', + encoder_pth='roberta-base', + max_pos=512, + share_emb=False, + dec_layers=12, + dec_hidden_size=768, + dec_heads=8, + dec_ff_size=3072, + dec_dropout=0.2, + use_bert_emb=True, + label_smoothing=0.1, + alpha=0.95, + beam_size=5, + min_length=40, + max_length=130, + sample_topk=False, + block_trigram=False, + **kwargs): + super().__init__(**kwargs) + self.encoder = encoder + self.encoder_pth = encoder_pth + self.max_pos = max_pos + self.share_emb = share_emb + self.dec_layers = dec_layers + self.dec_hidden_size = dec_hidden_size + self.dec_heads = dec_heads + self.dec_ff_size = dec_ff_size + self.dec_dropout = dec_dropout + self.use_bert_emb = use_bert_emb + self.label_smoothing = label_smoothing + # Translator + self.alpha = alpha + self.beam_size = beam_size + self.min_length = min_length + self.max_length = max_length + self.sample_topk = sample_topk + self.block_trigram = block_trigram diff --git a/modelscope/models/nlp/palm_v2/dureader_eval.py b/modelscope/models/nlp/palm_v2/dureader_eval.py new file mode 100644 index 00000000..db54f21d --- /dev/null +++ b/modelscope/models/nlp/palm_v2/dureader_eval.py @@ -0,0 +1,872 @@ +# ============================================================================== +# Copyright 2017 Baidu.com, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +This module computes evaluation metrics for DuReader dataset. +""" + +import argparse +import copy +import math +import re +import sys +import zipfile +from collections import Counter, defaultdict + +import json +import numpy as np +from rouge import Rouge + +EMPTY = '' +YESNO_LABELS = set(['Yes', 'No', 'Depends']) + + +def my_lcs(string, sub): + """ + Calculates longest common subsequence for a pair of tokenized strings + :param string : list of str : tokens from a string split using whitespace + :param sub : list of str : shorter string, also split using whitespace + :returns: length (list of int): length of the longest common subsequence between the two strings + + Note: my_lcs only gives length of the longest common subsequence, not the actual LCS + """ + if (len(string) < len(sub)): + sub, string = string, sub + + lengths = [[0 for i in range(0, + len(sub) + 1)] + for j in range(0, + len(string) + 1)] + + for j in range(1, len(sub) + 1): + for i in range(1, len(string) + 1): + if (string[i - 1] == sub[j - 1]): + lengths[i][j] = lengths[i - 1][j - 1] + 1 + else: + lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) + + return lengths[len(string)][len(sub)] + + +class Bleu: + + def __init__(self, n=4): + # default compute Blue score up to 4 + self._n = n + self._hypo_for_image = {} + self.ref_for_image = {} + + def compute_score(self, gts, res): + assert (list(gts.keys()) == list(res.keys())) + imgIds = list(gts.keys()) + + bleu_scorer = BleuScorer(n=self._n) + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert (type(hypo) is list) + assert (len(hypo) == 1) + assert (type(ref) is list) + assert (len(ref) >= 1) + + bleu_scorer += (hypo[0], ref) + + score, scores = bleu_scorer.compute_score(option='closest', verbose=1) + return score, scores + + def method(self): + return 'Bleu' + + +def precook(s, n=4, out=False): + """Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well.""" + words = s.split() + counts = defaultdict(int) + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i:i + k]) + counts[ngram] += 1 + return (len(words), counts) + + +def cook_refs(refs, eff=None, n=4): # lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them.''' + + reflen = [] + maxcounts = {} + for ref in refs: + rl, counts = precook(ref, n) + reflen.append(rl) + for (ngram, count) in counts.items(): + maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) + + # Calculate effective reference sentence length. + if eff == 'shortest': + reflen = min(reflen) + elif eff == 'average': + reflen = float(sum(reflen)) / len(reflen) + + # lhuang: N.B.: leave reflen computaiton to the very end!! + + # lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) + + return reflen, maxcounts + + +def cook_test(test, xxx_todo_changeme, eff=None, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it.''' + (reflen, refmaxcounts) = xxx_todo_changeme + testlen, counts = precook(test, n, True) + + result = {} + + # Calculate effective reference sentence length. + + if eff == 'closest': + result['reflen'] = min((abs(ref - testlen), ref) for ref in reflen)[1] + else: # i.e., "average" or "shortest" or None + result['reflen'] = reflen + + result['testlen'] = testlen + + result['guess'] = [max(0, testlen - k + 1) for k in range(1, n + 1)] + + result['correct'] = [0] * n + for (ngram, count) in counts.items(): + result['correct'][len(ngram) - 1] += min( + refmaxcounts.get(ngram, 0), count) + + return result + + +class BleuScorer(object): + """Bleu scorer. + """ + + __slots__ = 'n', 'crefs', 'ctest', '_score', '_ratio', '_testlen', '_reflen', 'special_reflen' + + # special_reflen is used in oracle (proportional effective ref len for a node). + + def copy(self): + ''' copy the refs.''' + new = BleuScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + new._score = None + return new + + def __init__(self, test=None, refs=None, n=4, special_reflen=None): + ''' singular instance ''' + + self.n = n + self.crefs = [] + self.ctest = [] + self.cook_append(test, refs) + self.special_reflen = special_reflen + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + cooked_test = cook_test(test, self.crefs[-1]) + self.ctest.append(cooked_test) # N.B.: -1 + else: + self.ctest.append( + None) # lens of crefs and ctest have to match + + self._score = None # need to recompute + + def ratio(self, option=None): + self.compute_score(option=option) + return self._ratio + + def score_ratio(self, option=None): + '''return (bleu, len_ratio) pair''' + return (self.fscore(option=option), self.ratio(option=option)) + + def score_ratio_str(self, option=None): + return '%.4f (%.2f)' % self.score_ratio(option) + + def reflen(self, option=None): + self.compute_score(option=option) + return self._reflen + + def testlen(self, option=None): + self.compute_score(option=option) + return self._testlen + + def retest(self, new_test): + if type(new_test) is str: + new_test = [new_test] + assert len(new_test) == len(self.crefs), new_test + self.ctest = [] + for t, rs in zip(new_test, self.crefs): + self.ctest.append(cook_test(t, rs)) + self._score = None + + return self + + def rescore(self, new_test): + ''' replace test(s) with new test(s), and returns the new score.''' + + return self.retest(new_test).compute_score() + + def size(self): + assert len(self.crefs) == len( + self.ctest), 'refs/test mismatch! %d<>%d' % (len( + self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + # avoid creating new BleuScorer instances + self.cook_append(other[0], other[1]) + else: + assert self.compatible(other), 'incompatible BLEUs.' + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + self._score = None # need to recompute + + return self + + def compatible(self, other): + return isinstance(other, BleuScorer) and self.n == other.n + + def single_reflen(self, option='average'): + return self._single_reflen(self.crefs[0][0], option) + + def _single_reflen(self, reflens, option=None, testlen=None): + + if option == 'shortest': + reflen = min(reflens) + elif option == 'average': + reflen = float(sum(reflens)) / len(reflens) + elif option == 'closest': + reflen = min((abs(ref - testlen), ref) for ref in reflens)[1] + else: + assert False, 'unsupported reflen option %s' % option + + return reflen + + def recompute_score(self, option=None, verbose=0): + self._score = None + return self.compute_score(option, verbose) + + def compute_score(self, option=None, verbose=0): + n = self.n + small = 1e-9 + tiny = 1e-15 # so that if guess is 0 still return 0 + bleu_list = [[] for _ in range(n)] + + if self._score is not None: + return self._score + + if option is None: + option = 'average' if len(self.crefs) == 1 else 'closest' + + self._testlen = 0 + self._reflen = 0 + totalcomps = { + 'testlen': 0, + 'reflen': 0, + 'guess': [0] * n, + 'correct': [0] * n + } + + # for each sentence + for comps in self.ctest: + testlen = comps['testlen'] + self._testlen += testlen + + if self.special_reflen is None: # need computation + reflen = self._single_reflen(comps['reflen'], option, testlen) + else: + reflen = self.special_reflen + + self._reflen += reflen + + for key in ['guess', 'correct']: + for k in range(n): + totalcomps[key][k] += comps[key][k] + + # append per image bleu score + bleu = 1. + for k in range(n): + bleu *= (float(comps['correct'][k]) + tiny) / ( + float(comps['guess'][k]) + small) + bleu_list[k].append(bleu**(1. / (k + 1))) + ratio = (testlen + tiny) / (reflen + small + ) # N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleu_list[k][-1] *= math.exp(1 - 1 / ratio) + + if verbose > 1: + print(comps, reflen) + + totalcomps['reflen'] = self._reflen + totalcomps['testlen'] = self._testlen + + bleus = [] + bleu = 1. + for k in range(n): + bleu *= float(totalcomps['correct'][k] + tiny) / ( + totalcomps['guess'][k] + small) + bleus.append(bleu**(1. / (k + 1))) + ratio = (self._testlen + tiny) / (self._reflen + small + ) # N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleus[k] *= math.exp(1 - 1 / ratio) + + if verbose > 0: + print(totalcomps) + print('ratio:', ratio) + + self._score = bleus + return self._score, bleu_list + + +def normalize(s): + """ + Normalize strings to space joined chars. + + Args: + s: a list of strings. + + Returns: + A list of normalized strings. + """ + if not s: + return s + normalized = [] + for ss in s: + tokens = [c for c in list(ss) if len(c.strip()) != 0] + normalized.append(' '.join(tokens)) + return normalized + + +def data_check(obj, task): + """ + Check data. + + Raises: + Raises AssertionError when data is not legal. + """ + assert 'question_id' in obj, "Missing 'question_id' field." + assert 'question_type' in obj, \ + "Missing 'question_type' field. question_id: {}".format(obj['question_type']) + + assert 'yesno_answers' in obj, \ + "Missing 'yesno_answers' field. question_id: {}".format(obj['question_id']) + assert isinstance(obj['yesno_answers'], list), \ + r"""'yesno_answers' field must be a list, if the 'question_type' is not + 'YES_NO', then this field should be an empty list. + question_id: {}""".format(obj['question_id']) + + assert 'entity_answers' in obj, \ + "Missing 'entity_answers' field. question_id: {}".format(obj['question_id']) + assert isinstance( + obj['entity_answers'], + list) and len(obj['entity_answers']) > 0, r"""'entity_answers' field + must be a list, and has at least one element, which can be a empty list. + question_id: {}""".format(obj['question_id']) + + +def read_file(file_name, task, is_ref=False): + """ + Read predict answers or reference answers from file. + + Args: + file_name: the name of the file containing predict result or reference + result. + + Returns: + A dictionary mapping question_id to the result information. The result + information itself is also a dictionary with has four keys: + - question_type: type of the query. + - yesno_answers: A list of yesno answers corresponding to 'answers'. + - answers: A list of predicted answers. + - entity_answers: A list, each element is also a list containing the entities + tagged out from the corresponding answer string. + """ + + def _open(file_name, mode, zip_obj=None): + if zip_obj is not None: + return zip_obj.open(file_name, mode) + return open(file_name, mode) + + results = {} + keys = ['answers', 'yesno_answers', 'entity_answers', 'question_type'] + if is_ref: + keys += ['source'] + + zf = zipfile.ZipFile(file_name, + 'r') if file_name.endswith('.zip') else None + file_list = [file_name] if zf is None else zf.namelist() + + for fn in file_list: + for line in _open(fn, 'r', zip_obj=zf): + try: + obj = json.loads(line.strip()) + except ValueError: + raise ValueError('Every line of data should be legal json') + data_check(obj, task) + qid = obj['question_id'] + assert qid not in results, 'Duplicate question_id: {}'.format(qid) + results[qid] = {} + for k in keys: + results[qid][k] = obj[k] + return results + + +def compute_bleu_rouge(pred_dict, ref_dict, bleu_order=4): + """ + Compute bleu and rouge scores. + """ + assert set(pred_dict.keys()) == set(ref_dict.keys()), \ + 'missing keys: {}'.format(set(ref_dict.keys()) - set(pred_dict.keys())) + scores = {} + bleu_scores, _ = Bleu(bleu_order).compute_score(ref_dict, pred_dict) + for i, bleu_score in enumerate(bleu_scores): + scores['Bleu-%d' % (i + 1)] = bleu_score + # rouge_score, _ = Rouge().compute_score(ref_dict, pred_dict) + rouge_score = Rouge().get_scores( + list(map(lambda x: x[0], pred_dict.values())), + list(map(lambda x: x[0], ref_dict.values()))) + rouge_score = sum([d['rouge-l']['f'] + for d in rouge_score]) / len(rouge_score) + scores['Rouge-L'] = rouge_score + return scores + + +def local_prf(pred_list, ref_list): + """ + Compute local precision recall and f1-score, + given only one prediction list and one reference list + """ + common = Counter(pred_list) & Counter(ref_list) + num_same = sum(common.values()) + if num_same == 0: + return 0, 0, 0 + p = 1.0 * num_same / len(pred_list) + r = 1.0 * num_same / len(ref_list) + f1 = (2 * p * r) / (p + r) + return p, r, f1 + + +def compute_prf(pred_dict, ref_dict): + """ + Compute precision recall and f1-score. + """ + # pred_question_ids = set(pred_dict.keys()) + ref_question_ids = set(ref_dict.keys()) + correct_preds, total_correct, total_preds = 0, 0, 0 + for question_id in ref_question_ids: + pred_entity_list = pred_dict.get(question_id, [[]]) + assert len(pred_entity_list) == 1, \ + 'the number of entity list for question_id {} is not 1.'.format(question_id) + pred_entity_list = pred_entity_list[0] + all_ref_entity_lists = ref_dict[question_id] + best_local_f1 = 0 + best_ref_entity_list = None + for ref_entity_list in all_ref_entity_lists: + local_f1 = local_prf(pred_entity_list, ref_entity_list)[2] + if local_f1 > best_local_f1: + best_ref_entity_list = ref_entity_list + best_local_f1 = local_f1 + if best_ref_entity_list is None: + if len(all_ref_entity_lists) > 0: + best_ref_entity_list = sorted( + all_ref_entity_lists, key=lambda x: len(x))[0] + else: + best_ref_entity_list = [] + gold_entities = set(best_ref_entity_list) + pred_entities = set(pred_entity_list) + correct_preds += len(gold_entities & pred_entities) + total_preds += len(pred_entities) + total_correct += len(gold_entities) + p = float(correct_preds) / total_preds if correct_preds > 0 else 0 + r = float(correct_preds) / total_correct if correct_preds > 0 else 0 + f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0 + return {'Precision': p, 'Recall': r, 'F1': f1} + + +def prepare_prf(pred_dict, ref_dict): + """ + Prepares data for calculation of prf scores. + """ + preds = {k: v['entity_answers'] for k, v in pred_dict.items()} + refs = {k: v['entity_answers'] for k, v in ref_dict.items()} + return preds, refs + + +def filter_dict(result_dict, key_tag): + """ + Filter a subset of the result_dict, where keys ends with 'key_tag'. + """ + filtered = {} + for k, v in result_dict.items(): + if k.endswith(key_tag): + filtered[k] = v + return filtered + + +def get_metrics(pred_result, ref_result, task, source): + """ + Computes metrics. + """ + metrics = {} + + ref_result_filtered = {} + pred_result_filtered = {} + if source == 'both': + ref_result_filtered = ref_result + pred_result_filtered = pred_result + else: + for question_id, info in ref_result.items(): + if info['source'] == source: + ref_result_filtered[question_id] = info + if question_id in pred_result: + pred_result_filtered[question_id] = pred_result[ + question_id] + + if task == 'main' or task == 'all' \ + or task == 'description': + pred_dict, ref_dict = prepare_bleu(pred_result_filtered, + ref_result_filtered, task) + metrics = compute_bleu_rouge(pred_dict, ref_dict) + elif task == 'yesno': + pred_dict, ref_dict = prepare_bleu(pred_result_filtered, + ref_result_filtered, task) + keys = ['Yes', 'No', 'Depends'] + preds = [filter_dict(pred_dict, k) for k in keys] + refs = [filter_dict(ref_dict, k) for k in keys] + + metrics = compute_bleu_rouge(pred_dict, ref_dict) + + for k, pred, ref in zip(keys, preds, refs): + m = compute_bleu_rouge(pred, ref) + k_metric = [(k + '|' + key, v) for key, v in m.items()] + metrics.update(k_metric) + + elif task == 'entity': + pred_dict, ref_dict = prepare_prf(pred_result_filtered, + ref_result_filtered) + pred_dict_bleu, ref_dict_bleu = prepare_bleu(pred_result_filtered, + ref_result_filtered, task) + metrics = compute_prf(pred_dict, ref_dict) + metrics.update(compute_bleu_rouge(pred_dict_bleu, ref_dict_bleu)) + else: + raise ValueError('Illegal task name: {}'.format(task)) + + return metrics + + +def prepare_bleu(pred_result, ref_result, task): + """ + Prepares data for calculation of bleu and rouge scores. + """ + pred_list, ref_list = [], [] + qids = ref_result.keys() + for qid in qids: + if task == 'main': + pred, ref = get_main_result(qid, pred_result, ref_result) + elif task == 'yesno': + pred, ref = get_yesno_result(qid, pred_result, ref_result) + elif task == 'all': + pred, ref = get_all_result(qid, pred_result, ref_result) + elif task == 'entity': + pred, ref = get_entity_result(qid, pred_result, ref_result) + elif task == 'description': + pred, ref = get_desc_result(qid, pred_result, ref_result) + else: + raise ValueError('Illegal task name: {}'.format(task)) + if pred and ref: + pred_list += pred + ref_list += ref + pred_dict = dict(pred_list) + ref_dict = dict(ref_list) + for qid, ans in ref_dict.items(): + ref_dict[qid] = normalize(ref_dict[qid]) + pred_dict[qid] = normalize(pred_dict.get(qid, [EMPTY])) + if not ans or ans == [EMPTY]: + del ref_dict[qid] + del pred_dict[qid] + + for k, v in pred_dict.items(): + assert len(v) == 1, \ + 'There should be only one predict answer. question_id: {}'.format(k) + return pred_dict, ref_dict + + +def get_main_result(qid, pred_result, ref_result): + """ + Prepare answers for task 'main'. + + Args: + qid: question_id. + pred_result: A dict include all question_id's result information read + from args.pred_file. + ref_result: A dict incluce all question_id's result information read + from args.ref_file. + Returns: + Two lists, the first one contains predict result, the second + one contains reference result of the same question_id. Each list has + elements of tuple (question_id, answers), 'answers' is a list of strings. + """ + ref_ans = ref_result[qid]['answers'] + if not ref_ans: + ref_ans = [EMPTY] + pred_ans = pred_result.get(qid, {}).get('answers', [])[:1] + if not pred_ans: + pred_ans = [EMPTY] + + return [(qid, pred_ans)], [(qid, ref_ans)] + + +def get_entity_result(qid, pred_result, ref_result): + """ + Prepare answers for task 'entity'. + + Args: + qid: question_id. + pred_result: A dict include all question_id's result information read + from args.pred_file. + ref_result: A dict incluce all question_id's result information read + from args.ref_file. + Returns: + Two lists, the first one contains predict result, the second + one contains reference result of the same question_id. Each list has + elements of tuple (question_id, answers), 'answers' is a list of strings. + """ + if ref_result[qid]['question_type'] != 'ENTITY': + return None, None + return get_main_result(qid, pred_result, ref_result) + + +def get_desc_result(qid, pred_result, ref_result): + """ + Prepare answers for task 'description'. + + Args: + qid: question_id. + pred_result: A dict include all question_id's result information read + from args.pred_file. + ref_result: A dict incluce all question_id's result information read + from args.ref_file. + Returns: + Two lists, the first one contains predict result, the second + one contains reference result of the same question_id. Each list has + elements of tuple (question_id, answers), 'answers' is a list of strings. + """ + if ref_result[qid]['question_type'] != 'DESCRIPTION': + return None, None + return get_main_result(qid, pred_result, ref_result) + + +def get_yesno_result(qid, pred_result, ref_result): + """ + Prepare answers for task 'yesno'. + + Args: + qid: question_id. + pred_result: A dict include all question_id's result information read + from args.pred_file. + ref_result: A dict incluce all question_id's result information read + from args.ref_file. + Returns: + Two lists, the first one contains predict result, the second + one contains reference result of the same question_id. Each list has + elements of tuple (question_id, answers), 'answers' is a list of strings. + """ + + def _uniq(li, is_ref): + uniq_li = [] + left = [] + keys = set() + for k, v in li: + if k not in keys: + uniq_li.append((k, v)) + keys.add(k) + else: + left.append((k, v)) + + if is_ref: + dict_li = dict(uniq_li) + for k, v in left: + dict_li[k] += v + uniq_li = [(k, v) for k, v in dict_li.items()] + return uniq_li + + def _expand_result(uniq_li): + expanded = uniq_li[:] + keys = set([x[0] for x in uniq_li]) + for k in YESNO_LABELS - keys: + expanded.append((k, [EMPTY])) + return expanded + + def _get_yesno_ans(qid, result_dict, is_ref=False): + if qid not in result_dict: + return [(str(qid) + '_' + k, v) for k, v in _expand_result([])] + yesno_answers = result_dict[qid]['yesno_answers'] + answers = result_dict[qid]['answers'] + lbl_ans = _uniq([(k, [v]) for k, v in zip(yesno_answers, answers)], + is_ref) + ret = [(str(qid) + '_' + k, v) for k, v in _expand_result(lbl_ans)] + return ret + + if ref_result[qid]['question_type'] != 'YES_NO': + return None, None + + ref_ans = _get_yesno_ans(qid, ref_result, is_ref=True) + pred_ans = _get_yesno_ans(qid, pred_result) + return pred_ans, ref_ans + + +def get_all_result(qid, pred_result, ref_result): + """ + Prepare answers for task 'all'. + + Args: + qid: question_id. + pred_result: A dict include all question_id's result information read + from args.pred_file. + ref_result: A dict incluce all question_id's result information read + from args.ref_file. + Returns: + Two lists, the first one contains predict result, the second + one contains reference result of the same question_id. Each list has + elements of tuple (question_id, answers), 'answers' is a list of strings. + """ + if ref_result[qid]['question_type'] == 'YES_NO': + return get_yesno_result(qid, pred_result, ref_result) + return get_main_result(qid, pred_result, ref_result) + + +def format_metrics(metrics, task, err_msg): + """ + Format metrics. 'err' field returns any error occured during evaluation. + + Args: + metrics: A dict object contains metrics for different tasks. + task: Task name. + err_msg: Exception raised during evaluation. + Returns: + Formatted result. + """ + result = {} + sources = ['both', 'search', 'zhidao'] + if err_msg is not None: + return {'errorMsg': str(err_msg), 'errorCode': 1, 'data': []} + data = [] + if task != 'all' and task != 'main': + sources = ['both'] + + if task == 'entity': + metric_names = ['Bleu-4', 'Rouge-L'] + metric_names_prf = ['F1', 'Precision', 'Recall'] + for name in metric_names + metric_names_prf: + for src in sources: + obj = { + 'name': name, + 'value': round(metrics[src].get(name, 0) * 100, 2), + 'type': src, + } + data.append(obj) + elif task == 'yesno': + metric_names = ['Bleu-4', 'Rouge-L'] + details = ['Yes', 'No', 'Depends'] + src = sources[0] + for name in metric_names: + obj = { + 'name': name, + 'value': round(metrics[src].get(name, 0) * 100, 2), + 'type': 'All', + } + data.append(obj) + for d in details: + obj = { + 'name': name, + 'value': round(metrics[src].get(d + '|' + name, 0) * 100, + 2), + 'type': d + } + data.append(obj) + else: + metric_names = ['Bleu-4', 'Rouge-L'] + for name in metric_names: + for src in sources: + obj = { + 'name': name, + 'value': round(metrics[src].get(name, 0) * 100, 2), + 'type': src + } + data.append(obj) + + result['data'] = data + result['errorCode'] = 0 + result['errorMsg'] = 'success' + + return result + + +def main(args): + """ + Do evaluation. + """ + err = None + metrics = {} + try: + pred_result = read_file(args.pred_file, args.task) + ref_result = read_file(args.ref_file, args.task, is_ref=True) + sources = ['both', 'search', 'zhidao'] + if args.task not in set(['main', 'all']): + sources = sources[:1] + for source in sources: + metrics[source] = get_metrics(pred_result, ref_result, args.task, + source) + except ValueError as ve: + err = ve + except AssertionError as ae: + err = ae + + print( + json.dumps( + format_metrics(metrics, args.task, err), + ensure_ascii=False).encode('utf8')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('pred_file', help='predict file') + parser.add_argument('ref_file', help='reference file') + parser.add_argument( + 'task', help='task name: Main|Yes_No|All|Entity|Description') + + args = parser.parse_args() + args.task = args.task.lower().replace('_', '') + main(args) diff --git a/modelscope/models/nlp/palm_v2/text_generation.py b/modelscope/models/nlp/palm_v2/text_generation.py new file mode 100644 index 00000000..d83860db --- /dev/null +++ b/modelscope/models/nlp/palm_v2/text_generation.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict, List + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['PalmForTextGeneration'] + + +@MODELS.register_module(Tasks.text_generation, module_name=Models.palm) +class PalmForTextGeneration(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. + """ + super().__init__(model_dir, *args, **kwargs) + + from modelscope.models.nlp.palm_v2 import ( + PalmForConditionalGeneration, Translator) + self.model = PalmForConditionalGeneration.from_pretrained(model_dir) + self.tokenizer = self.model.tokenizer + self.generator = Translator(self.model) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'loss': Tensor([12.34]), # loss for backward + } + """ + return self.model(**input) + + def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + outputs = self.generator(**input) + preds = outputs['predictions'] + return {'sequences': [pred[0] for pred in preds]} diff --git a/modelscope/models/nlp/plug/__init__.py b/modelscope/models/nlp/plug/__init__.py new file mode 100644 index 00000000..589a636a --- /dev/null +++ b/modelscope/models/nlp/plug/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import PlugNLGConfig + from .backbone import PlugModel + from .distributed_plug import DistributedPlug +else: + _import_structure = { + 'configuration': ['PlugNLGConfig'], + 'backbone': ['PlugModel'], + 'distributed_plug': ['DistributedPlug'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/plug/backbone.py b/modelscope/models/nlp/plug/backbone.py new file mode 100644 index 00000000..7f3f12de --- /dev/null +++ b/modelscope/models/nlp/plug/backbone.py @@ -0,0 +1,1017 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import logging +import math +import os + +import torch +import torch.nn.functional as F +from deepspeed.utils.timer import SynchronizedWallClockTimer +from megatron import mpu +from torch import nn + +from modelscope.utils.nlp.distributed import (normal_init_method, + scaled_init_method) +from .configuration import PlugNLGConfig, PlugNLUConfig + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class BertLayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range)) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.fp32_layernorm = config.fp32_layernorm + self.fp32_embedding = config.fp32_embedding + self.fp32_tokentypes = config.fp32_tokentypes + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + if not self.fp32_tokentypes: + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + if self.fp32_embedding and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_embedding: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + else: + embeddings = words_embeddings.float() + position_embeddings.float( + ) + token_type_embeddings.float() + if self.fp32_tokentypes and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_tokentypes: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.hidden_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + if not config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states, + input_tensor, + ): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + if self.LayerNorm is not None: + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + else: + hidden_states = ln_input + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config): + super(BertAttention, self).__init__() + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.self = mpu.BertParallelSelfAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout_prob=config.attention_probs_dropout_prob, + output_parallel=True, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range), + separate=config.attn_separate) + self.output = BertSelfOutput(config) + + def forward( + self, + input_tensor, + attention_mask, + ): + if self.LayerNorm is not None: + ln_input = input_tensor + previous_type = input_tensor.type() + if self.fp32_layernorm: + ln_input = input_tensor.float() + ln_output = self.LayerNorm(ln_input) + if self.fp32_layernorm: + ln_output = ln_output.type(previous_type) + self_output = self.self( + ln_output, + attention_mask, + ) + else: + self_output = self.self( + input_tensor, + attention_mask, + ) + + attention_output = self.output( + self_output, + input_tensor, + ) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = mpu.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + bias=True, + gather_output=False, + stride=1, + init_method=normal_init_method( + mean=0.0, std=config.initializer_range)) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward( + self, + hidden_states, + ): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + if not config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states, + input_tensor, + ): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + if self.LayerNorm is not None: + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + else: + hidden_states = ln_input + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + if self.LayerNorm is not None: + ln_input = attention_output + previous_type = attention_output.type() + if self.fp32_layernorm: + ln_input = attention_output.float() + ln_output = self.LayerNorm(ln_input) + if self.fp32_layernorm: + ln_output = ln_output.type(previous_type) + intermediate_output = self.intermediate(ln_output) + else: + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super(BertEncoder, self).__init__() + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.fp32_layernorm = config.fp32_layernorm + if config.pre_ln: + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + else: + self.LayerNorm = None + + def forward( + self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + checkpoint_activations=False, + detach_index=-1, + ): + all_encoder_layers = [] + + def custom(start, end): + + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer(x_, inputs[1]) + return x_ + + return custom_forward + + if checkpoint_activations: + layer_idx = 0 + num_layers = len(self.layer) + chunk_length = 1 + while layer_idx < num_layers: + hidden_states = mpu.checkpoint( + custom(layer_idx, layer_idx + chunk_length), hidden_states, + attention_mask * 1) + if detach_index == layer_idx: + hidden_states.detach_() + layer_idx += chunk_length + # decoder layers + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attention_mask) + if detach_index == i: + hidden_states.detach_() + if i == len(self.layer) - 1 and self.LayerNorm is not None: + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if not output_all_encoded_layers or checkpoint_activations: + if self.LayerNorm is not None: + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder_weight = bert_model_embedding_weights + self.bias = nn.Parameter( + torch.zeros(bert_model_embedding_weights.size(0))) + self.bias.model_parallel = True + self.fp32_embedding = config.fp32_embedding + self.fp32_layernorm = config.fp32_layernorm + + def convert_to_type(tensor): + if self.fp32_embedding: + return tensor.half() + else: + return tensor + + self.type_converter = convert_to_type + self.converted = False + self.timers = SynchronizedWallClockTimer() + + def forward(self, hidden_states): + if not self.converted: + self.converted = True + if self.fp32_embedding: + self.transform.half() + if self.fp32_layernorm: + self.transform.LayerNorm.float() + hidden_states = self.transform(self.type_converter(hidden_states)) + self.timers('final linear gather').start() + hidden_states = mpu.copy_to_model_parallel_region(hidden_states) + self.timers('final linear gather').stop() + hidden_states = F.linear( + self.type_converter(hidden_states), + self.type_converter(self.decoder_weight), + self.type_converter(self.bias)) + return hidden_states + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 3) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + for p in self.seq_relationship.parameters(): + if p is None: + continue + pooled_output = pooled_output.type_as(p) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, PlugNLUConfig) and not isinstance( + config, PlugNLGConfig): + raise ValueError( + 'Parameter config in `{}(config)` should be an instance of class `BertConfig`. ' + 'To create a model from a Google pretrained model use ' + '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( + self.__class__.__name__, self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(PreTrainedBertModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as + described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + checkpoint_activations=False, + detach_index=-1, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations, + detach_index=detach_index) + sequence_output = encoded_layers[-1] + for p in self.pooler.parameters(): + if p is None: + continue + sequence_output = sequence_output.type_as(p) + break + + pooled_output = sequence_output[:, 0] + if not output_all_encoded_layers or checkpoint_activations: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class DecodeLayer(nn.Module): + + def __init__(self, config): + super(DecodeLayer, self).__init__() + init_method = normal_init_method( + mean=0.0, std=config.initializer_range) + output_layer_init_method = scaled_init_method( + mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + + self.attention = mpu.GPT2ParallelSelfAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + attention_dropout_prob=config.attention_probs_dropout_prob, + output_dropout_prob=config.hidden_dropout_prob, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + ) + + self.cross_attention = mpu.PalmParallelCrossAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + attention_dropout_prob=config.attention_probs_dropout_prob, + output_dropout_prob=config.hidden_dropout_prob, + init_method=init_method, + attn_separate=False, + output_layer_init_method=output_layer_init_method, + ) + + self.input_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.post_attention_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.post_cross_attention_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + self.intermediate = mpu.ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + gather_output=False, + init_method=init_method, + ) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.output = mpu.RowParallelLinear( + config.intermediate_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + ) + + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.fp32_layernorm = config.fp32_layernorm + + def convert_to_type(tensor): + if self.fp32_layernorm: + return tensor.float() + else: + return tensor + + self.type_converter = convert_to_type + + # def forward(self, hidden_states, enc_attn_mask, dec_attn_mask): + def forward(self, + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + is_infer=False): + residual = hidden_states + previous_type = hidden_states.type() + hidden_states = self.input_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.attention( + hidden_states, dec_attn_mask, is_infer=is_infer) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.cross_attention(hidden_states, enc_hidden_states, + enc_attn_mask) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_cross_attention_layernorm( + self.type_converter(hidden_states)) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + hidden_states = self.intermediate(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + hidden_states = self.output(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BertDecoder(nn.Module): + + def __init__(self, config): + super(BertDecoder, self).__init__() + self.layer = nn.ModuleList( + [DecodeLayer(config) for _ in range(config.dec_hidden_layers)]) + + self.final_layernorm = BertLayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + checkpoint_activations=False, + output_all_encoded_layers=False, + is_infer=False): + + def custom(start, end): + + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer( + x_, + inputs[1], + inputs[2], + dec_attn_mask * 1, + is_infer=is_infer) + return x_ + + return custom_forward + + pre_enc_hidden = enc_hidden_states.data + if checkpoint_activations: + layer_idx = 0 + num_layers = len(self.layer) + chunk_length = 1 + while layer_idx < num_layers: + hidden_states = mpu.checkpoint( + custom(layer_idx, layer_idx + chunk_length), hidden_states, + enc_hidden_states, enc_attn_mask * 1) + enc_hidden_states.data = pre_enc_hidden + layer_idx += chunk_length + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + enc_hidden_states, + enc_attn_mask, + dec_attn_mask, + is_infer=is_infer) + + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.final_layernorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + + return [hidden_states] + + +class DecodeModel(PreTrainedBertModel): + + def __init__(self, config): + super(DecodeModel, self).__init__(config) + self.decoder = BertDecoder(config) + self.apply(self.init_bert_weights) + + def forward(self, + embeddings, + sequence_output, + decode_input_ids, + position_ids=None, + enc_attn_mask=None, + dec_attn_mask=None, + checkpoint_activations=False, + is_infer=False): + extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = embeddings(decode_input_ids) + sequence_output = self.decoder( + embedding_output, + sequence_output, + extended_attention_mask, + dec_attn_mask, + checkpoint_activations=False, + is_infer=is_infer) + return sequence_output[-1] + + +class PalmForPreTraining(PreTrainedBertModel): + + def __init__(self, config): + super(PalmForPreTraining, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads( + config, self.bert.embeddings.word_embeddings.weight) + self.decoder = DecodeModel(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + decode_input_ids=None, + position_ids=None, + decode_attention_mask=None, + lm_labels=None, + checkpoint_activations=False, + is_infer=False, + sequence_output=None, + parallel_output=True): + if sequence_output is None: + sequence_output, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + checkpoint_activations=checkpoint_activations) + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output) + else: + prediction_scores = None + sequence_output = sequence_output.to( + dtype=next(self.decoder.parameters()).dtype) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + decode_output = self.decoder( + self.bert.embeddings, + sequence_output, + decode_input_ids, + position_ids, + attention_mask, + decode_attention_mask, + checkpoint_activations=checkpoint_activations, + is_infer=is_infer) + + transformer_output_parallel = mpu.copy_to_model_parallel_region( + decode_output) + + logits_parallel = F.linear(transformer_output_parallel, + self.bert.embeddings.word_embeddings.weight) + + if parallel_output: + return prediction_scores, logits_parallel + if is_infer: + return prediction_scores, mpu.gather_from_model_parallel_region( + logits_parallel), sequence_output + return prediction_scores, mpu.gather_from_model_parallel_region( + logits_parallel) + + +class PlugModel(torch.nn.Module): + """ + The bare Plug Model transformer outputting raw hidden-states without any specific head on top. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights. + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + """ + + def __init__(self, config): + super(PlugModel, self).__init__() + self.config = config + self.model = PalmForPreTraining(self.config) + + def forward(self, + input_tokens, + token_type_ids=None, + attention_mask=None, + target_tokens=None, + position_ids=None, + decode_attention_mask=None, + checkpoint_activations=False, + is_infer=False, + sequence_output=None, + parallel_output=True): + """ + Parameters: + input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`): + `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + Indices can be obtained using transformers [`BertTokenizer`]. See + [`TextGenerationPreprocessor.__call__`] for details. + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to + None): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Target token ids(labels) for language modeling. Note that the labels **are shifted** inside the model, + i.e. you can set `target_tokens = input_tokens` Indices are selected in + `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only + computed for labels in `[0, ..., config.vocab_size]` + + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.max_position_embeddings - 1]`. + + decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults + to None): + Mask to avoid performing attention on padding token indices of target tokens. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + checkpoint_activations (`boolean`, *optional*, defaults to `False`): + Whether gradient checkpointing is activated for this model or not. + is_infer (`boolean`, *optional*, defaults to `False`): + Whether or not to perform single inference. + sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, + defaults to None): + Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the + model. A single forward() call can produce one single token. To generate the current token, the + sequence_output generated by the `forward()` of the previous token is required. + parallel_output (`boolean`, *optional*, defaults to `True`): + To parallel return output, or gather it before return. + + + """ + return self.model( + input_tokens, + token_type_ids, + attention_mask, + target_tokens, + position_ids, + decode_attention_mask, + checkpoint_activations=checkpoint_activations, + is_infer=is_infer, + sequence_output=sequence_output, + parallel_output=parallel_output) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.model.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + return self.model.load_state_dict(state_dict, strict=strict) diff --git a/modelscope/models/nlp/plug/configuration.py b/modelscope/models/nlp/plug/configuration.py new file mode 100644 index 00000000..c3a526a9 --- /dev/null +++ b/modelscope/models/nlp/plug/configuration.py @@ -0,0 +1,255 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import json +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class PlugNLUConfig(PretrainedConfig): + model_type = 'plugNLU' + + def __init__(self, + vocab_size=21504, + original_vocab_size=21128, + hidden_size=8192, + num_hidden_layers=24, + num_attention_heads=128, + intermediate_size=32768, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=3, + initializer_range=0.00707, + lr_decay_style='linear', + weight_decay=1e-2, + clip_grad=1.0, + warmup=0.0333, + pre_ln=True, + fp16=True, + fp32_layernorm=True, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-5, + dec_hidden_layers=6, + attn_separate=False, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.original_vocab_size = original_vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.lr_decay_style = lr_decay_style + self.weight_decay = weight_decay + self.clip_grad = clip_grad + self.warmup = warmup + self.pre_ln = pre_ln + self.fp16 = fp16 + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + self.dec_hidden_layers = dec_hidden_layers + self.attn_separate = attn_separate + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = PlugNLUConfig() + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r', encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def merge_args(self, args): + """merge values a `BertConfig` from a json file of parameters.""" + local_keys = self.__dict__.keys() + for key, value in args.__dict__.items(): + if key in local_keys: + continue + self.__dict__[key] = value + return self + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + +class PlugNLGConfig(PlugNLUConfig): + """ + This is the configuration class to store the configuration of a [`PlugModel`]. It is used to instantiate a + PLUG understanding model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PLUG + [PLUG](https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary) architecture. + + Configuration objects inherit from [`PlugNLUConfig`] and can be used to control the model outputs. Read the + documentation from [`PlugNLUConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 21504): + Padded vocabulary size of the PLUG model for vocab tensor parallel. Defines the number of different tokens + that can be represented by the `inputs_ids` passed when calling [`PlugModel`]. + original_vocab_size (`int`, *optional*, defaults to 21128): + True vocabulary size of the PLUG model. Defines the number of different tokens that can be represented. + hidden_size (`int`, *optional*, defaults to 8192): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + dec_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 32768): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the Transformer Attention. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 3): + The vocabulary size of the `token_type_ids` passed when calling [`PlugModel`]. + initializer_range (`float`, *optional*, defaults to 0.00707): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + lr_decay_style (`str`, *optional*, defaults to 'linear'): + The decay style of learning rate during fine-tunining. If string, `"linear"`, `"cosine"`, `"exponential"`, + `"constant"`, `"None"` are supported. + weight_decay (`float`, *optional*, defaults to 1e-2): + Decoupled weight decay to apply. + clip_grad (`float`, *optional*, defaults to 1.0): + Maximum gradient norm for gradient clipping. + warmup (`float`, *optional*, defaults to 0.01): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + pre_ln (`boolean`, *optional*, defaults to `True`): + Whether or not to apply LayerNorm to the input instead of the output in the blocks. + fp16 (`boolean`, *optional*, defaults to `True`): + Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training. + fp32_layernorm (`boolean`, *optional*, defaults to `True`): + Whether to use fp32 32-bit precision LayerNorm training while the argument `fp16` set to `True`. + fp32_embedding (`boolean`, *optional*, defaults to `False`): + Whether to use fp32 32-bit precision Embedding training while the argument `fp16` set to `True`. + fp32_tokentypes (`boolean`, *optional*, defaults to `False`): + Whether to use fp32 32-bit precision token types training while the argument `fp16` set to `True`. + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + attn_separate (`boolean`, *optional*, defaults to `False`): + Whether or not to separate query-key-value to query, key, value in the Attention. + + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = 'plugNLG' + + def __init__(self, + vocab_size=21504, + original_vocab_size=21128, + hidden_size=8192, + num_hidden_layers=24, + dec_hidden_layers=6, + num_attention_heads=128, + intermediate_size=32768, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=3, + initializer_range=0.00707, + lr_decay_style='linear', + weight_decay=1e-2, + clip_grad=1.0, + warmup=0.01, + pre_ln=True, + fp16=True, + fp32_layernorm=True, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-5, + attn_separate=False, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.lr_decay_style = lr_decay_style + self.weight_decay = weight_decay + self.clip_grad = clip_grad + self.warmup = warmup + self.pre_ln = pre_ln + self.fp16 = fp16 + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + self.dec_hidden_layers = dec_hidden_layers + self.attn_separate = attn_separate diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py new file mode 100644 index 00000000..c72e92ba --- /dev/null +++ b/modelscope/models/nlp/plug/distributed_plug.py @@ -0,0 +1,234 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Dict + +import torch +import torch.nn.functional as F +from megatron import mpu +from megatron.fp16 import FP16_Module +from megatron.utils import print_rank_0 + +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.utils.logger import get_logger +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.nlp.load_checkpoint import pre_load +from modelscope.utils.torch_utils import set_random_seed_mpu +from . import PlugModel +from .configuration import PlugNLGConfig + +logger = get_logger(__name__) + + +class DistributedPlug(TorchModel): + """ + The wapper class of PLUG Model to initialize parallel environment, load model weights, generate sentences. + Parameters: + model_dir (`str`, *required*): + Path to model damo/nlp_plug_text-generation_27B. + The model structure in model_dir should be like this: + model_dir + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + rank (`int`, *required*): + Used to identify different GPUs in a tensor parallel environment. eg. The rank of GPU #0 is 0, and the + model file `mp_rank_00_model_states.pt` will be loaded on this GPU. + world_size (`int`, *required*, defaults to 8): + The parallel size in total. + model_parallel_size (`int`, *required*, defaults to 8): + The parallel size of model(tensor parallel). + master_ip (`str`, *required*): + The master IP, can usually be set to `"127.0.0.1"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + master_port (`str`, *required*): + The master port, can usually be set to `"29500"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + seed (`int`, *optional*, defaults to 42): + Random seed to control sampling. + """ + + def __init__(self, model_dir, rank, **kwargs): + super().__init__(model_dir, **kwargs) + self.rank = rank + self.model_cfg = kwargs + self.config = PlugNLGConfig.from_pretrained(model_dir) + initialize_distributed(rank, mpu, kwargs['world_size'], + kwargs['model_parallel_size'], + kwargs['master_ip'], kwargs['master_port']) + seed = 42 if 'seed' not in kwargs else kwargs['seed'] + set_random_seed_mpu(seed) + self.iteration = 0 + self.dist_model = self.initialize_model(path_load_tag='model') + + def initialize_model(self, path_load_tag='model'): + """Build the model.""" + print_rank_0('Building Plug model. It will take a few minutes ...') + model = PlugModel(self.config) + + if mpu.get_data_parallel_rank() == 0: + logger.info( + ' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()]))) + + if self.config.deepspeed and self.config.fp16: + model.half() + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if self.config.fp16: + model = FP16_Module(model) + if self.config.fp32_embedding: + model.module.model.bert.embeddings.word_embeddings.float() + model.module.model.bert.embeddings.position_embeddings.float() + model.module.model.bert.embeddings.token_type_embeddings.float( + ) + if self.config.fp32_tokentypes: + model.module.model.bert.embeddings.token_type_embeddings.float( + ) + if self.config.fp32_layernorm: + for name, _module in model.named_modules(): + if 'LayerNorm' in name: + _module.float() + + load_model = pre_load(mpu, self.model_dir, tag=path_load_tag) + model_dict = model.module.model.state_dict() + for key in load_model: + if key not in model_dict.keys(): + print_rank_0('Skip key: ' + key) + else: + print_rank_0('Loading key: ' + key) + model.module.model.load_state_dict(load_model, strict=False) + return model + + @staticmethod + def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art- + # conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, + None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + return logits + + def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs): + device = torch.cuda.current_device() + batch_size = input['input_ids'].shape[0] + tokens = input['input_ids'].view(1, -1).contiguous().to(device) + dec_input_ids = input['dec_input_ids'].to(device) + attention_mask = input['attention_mask'].to(device) + self.dist_model.eval() + with torch.no_grad(): + # Only supports batch_size=1 + all_generate_tokens = [] + generate_tokens = [] + counter = 0 + sequence_output = None + vocab_size = self.config.original_vocab_size + sep_token_idx = 102 # index of [SEP] token in BertTokenizer + while counter < out_length: + if counter % 128 == 0 and counter != 0: + # Sliding window + generate_tokens.append(sep_token_idx) + start = (tokens == sep_token_idx).nonzero( + as_tuple=True)[-1] + if start + len(generate_tokens) >= 512: + tokens = torch.cat([ + tokens[:start], + torch.cuda.LongTensor(generate_tokens) + ], -1)[-512:] + else: + tokens[0][start:start + len(generate_tokens + )] = torch.cuda.LongTensor( + generate_tokens) + + attention_mask = (tokens != 0) + dec_input_ids = input['dec_input_ids'].to(device) + generate_tokens = [] + sequence_output = None + + position_ids = torch.full([batch_size, 1], + len(generate_tokens), + dtype=torch.long, + device=device) + _, logits, sequence_output = self.dist_model( + tokens, + None, + attention_mask, + dec_input_ids, + attention_mask, + position_ids, + is_infer=True, + sequence_output=sequence_output, + parallel_output=False) + logits = logits[:, -1, :] + logits = logits / self.model_cfg['temperature'] + logits = self.top_k_logits( + logits, + top_k=self.model_cfg['top_k'], + top_p=self.model_cfg['top_p']) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1) + prev_token = prev[0].item() + if prev_token >= vocab_size: + prev_token = 100 + prev[0] = 100 + if prev_token == 102 and len(all_generate_tokens) > int( + max(1, out_length) * 0.8): + break + if prev_token == 102: + counter += 1 + continue + dec_input_ids = torch.cat([dec_input_ids, prev], dim=1) + generate_tokens.append(prev_token) + all_generate_tokens.append(prev_token) + counter += 1 + + generate_context = [] + for token in all_generate_tokens: + if generate_context and generate_context[ + -1] == 100 and token == 100: + continue + else: + generate_context.append(token) + return {'generate_context': generate_context} diff --git a/modelscope/models/nlp/ponet/__init__.py b/modelscope/models/nlp/ponet/__init__.py new file mode 100644 index 00000000..df996167 --- /dev/null +++ b/modelscope/models/nlp/ponet/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import PoNetConfig + from .backbone import (PoNetModel, PoNetPreTrainedModel) + from .tokenization import PoNetTokenizer + from .fill_mask import PoNetForMaskedLM +else: + _import_structure = { + 'configuration': ['PoNetConfig'], + 'backbone': ['PoNetModel', 'PoNetPreTrainedModel'], + 'fill_mask': ['PoNetForMaskedLM'], + 'tokenization': ['PoNetTokenizer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/ponet/backbone.py b/modelscope/models/nlp/ponet/backbone.py new file mode 100644 index 00000000..fcc62fa2 --- /dev/null +++ b/modelscope/models/nlp/ponet/backbone.py @@ -0,0 +1,900 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PoNet model. """ + +import math +from distutils.version import LooseVersion + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .configuration import PoNetConfig + +logger = get_logger(__name__) + +is_pytorch_12plus = LooseVersion(torch.__version__) >= LooseVersion('1.12.0') + +CLS_ID = 101 +EOS_ID = 102 + + +def segment_max(src, index, dim=1): + if is_pytorch_12plus: + out = torch.zeros_like(src).scatter_reduce( + dim, + index[:, :, None].expand_as(src), + src, + reduce='amax', + include_self=False) + else: + dummy_scatter_index = index[:, :, None].expand_as(src) + min_value = src.min() - 1 + dummpy_scatter_shape = (*src.shape[:-1], index.max() + 1, + src.shape[-1]) + dummy_scatter_index_expand = dummy_scatter_index.unsqueeze(-2).expand( + *dummpy_scatter_shape) + index_reconstruct_expand = torch.arange( + index.max() + 1, + device=src.device)[None, None, :, + None].expand(*dummpy_scatter_shape) + src_expand = src.unsqueeze(-2).expand(*dummpy_scatter_shape) + out, _ = src_expand.masked_scatter( + dummy_scatter_index_expand != index_reconstruct_expand, + torch.full_like(src_expand, min_value.item())).max(dim=1) + + dummy = index.unsqueeze(-1).expand(*index.shape[:2], out.size(-1)) + return torch.gather(out, dim, dummy).to(dtype=src.dtype) + + +def get_segment_index(input_ids, cls_id=CLS_ID, eos_id=EOS_ID): + mask = (input_ids == cls_id).to( + dtype=torch.long) + (input_ids == eos_id).to(dtype=torch.long) + mask = mask + torch.cat([torch.zeros_like(mask[:, 0:1]), mask[:, :-1]], + dim=1) + return mask.cumsum(dim=1) - 1 + + +def get_token_type_mask(input_ids, cls_id=CLS_ID, eos_id=EOS_ID): + mask = (input_ids == cls_id) | (input_ids == eos_id) + return mask + + +def get_win_max(hidden_states, kernel_size=3): + m = nn.MaxPool1d(kernel_size, stride=1, padding=kernel_size // 2) + out = m(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) + return out + + +class PoNetEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse('1.6.0'): + self.register_buffer( + 'token_type_ids', + torch.zeros( + self.position_ids.size(), + dtype=torch.long, + device=self.position_ids.device), + persistent=False, + ) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + if token_type_ids is None: + if hasattr(self, 'token_type_ids'): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class PoNetSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + + self.dense_local = nn.Linear(config.hidden_size, config.hidden_size) + self.dense_segment = nn.Linear(config.hidden_size, config.hidden_size) + + self.num_attention_heads = config.num_attention_heads + self.clsgsepg = getattr(config, 'clsgsepg', True) + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.dense_q = nn.Linear(config.hidden_size, self.all_head_size) + self.dense_k = nn.Linear(config.hidden_size, self.all_head_size) + self.dense_o = nn.Linear(config.hidden_size, self.all_head_size) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) # bz, head, len, head_size + + def forward( + self, + hidden_states, + segment_index, + token_type_mask, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + context_layer_q = self.transpose_for_scores( + self.dense_q(hidden_states)) + context_layer_k = self.transpose_for_scores( + self.dense_k(hidden_states)) + context_layer_v = context_layer_k + context_layer_o = self.transpose_for_scores( + self.dense_o(hidden_states)) + + if attention_mask is not None: + _attention_mask = (attention_mask.squeeze(1).unsqueeze(-1) < -1) + + if attention_mask is not None: + context_layer_q.masked_fill_(_attention_mask, 0.0) + q = context_layer_q.sum(dim=-2) / torch.ones_like( + _attention_mask).to(dtype=context_layer_q.dtype).masked_fill( + _attention_mask, 0.0).sum(dim=-2) + else: + q = context_layer_q.mean(dim=-2) + att = torch.einsum('bdh,bdlh -> bdl', q, context_layer_k) / math.sqrt( + context_layer_q.shape[-1]) + if attention_mask is not None: + att = att + attention_mask.squeeze(1) + att_prob = att.softmax(dim=-1) + v = torch.einsum('bdlh,bdl->bdh', context_layer_v, att_prob) + + context_layer_segment = self.dense_segment(hidden_states) + context_layer_local = self.dense_local(hidden_states) + if attention_mask is not None: + context_layer_local.masked_fill_( + _attention_mask.squeeze(1), -10000) + context_layer_segment.masked_fill_( + _attention_mask.squeeze(1), -10000) + + if self.clsgsepg: + # XXX: a trick to make sure the segment and local information will not leak + context_layer_local = get_win_max( + context_layer_local.masked_fill( + token_type_mask.unsqueeze(dim=-1), -10000)) + context_layer_segment = segment_max( + context_layer_segment, index=segment_index) + + context_layer_segment.masked_fill_( + token_type_mask.unsqueeze(dim=-1), 0.0) + context_layer_local.masked_fill_( + token_type_mask.unsqueeze(dim=-1), 0.0) + else: + context_layer_local = get_win_max(context_layer_local) + context_layer_segment = segment_max( + context_layer_segment, index=segment_index) + + context_layer_local = self.transpose_for_scores(context_layer_local) + context_layer_segment = self.transpose_for_scores( + context_layer_segment) + + context_layer = (v.unsqueeze(dim=-2) + context_layer_segment + ) * context_layer_o + context_layer_local + context_layer = context_layer.permute(0, 2, 1, 3).reshape( + *hidden_states.shape[:2], -1) + + if attention_mask is not None: + context_layer.masked_fill_(_attention_mask.squeeze(1), 0.0) + + outputs = (context_layer, + att_prob) if output_attentions else (context_layer, ) + return outputs + + +class PoNetSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class PoNetIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class PoNetOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class PoNetAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = PoNetSelfAttention(config) + self.output = PoNetSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + segment_index, + token_type_mask, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + segment_index, + token_type_mask, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PoNetLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PoNetAttention(config) + + config.is_decoder = False # XXX: Decoder is not yet impletemented. + self.is_decoder = config.is_decoder + + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + assert self.is_decoder, f'{self} should be used as a decoder model if cross attention is added' + self.crossattention = PoNetAttention(config) + self.intermediate = PoNetIntermediate(config) + self.output = PoNetOutput(config) + + def forward( + self, + hidden_states, + segment_index, + token_type_mask, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + segment_index, + token_type_mask, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + assert hasattr( + self, 'crossattention' + ), f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`' # noqa * + + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class PoNetEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [PoNetLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + segment_index, + token_type_mask, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + segment_index, + token_type_mask, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + segment_index, + token_type_mask, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class PoNetPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class PoNetPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PoNetConfig + base_model_prefix = 'ponet' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = PoNetConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model + + +@MODELS.register_module(Tasks.backbone, module_name=Models.ponet) +class PoNetModel(PoNetPreTrainedModel): + """The bare PoNet Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration + set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(config, **kwargs) + self.config = config + + self.embeddings = PoNetEmbeddings(config) + self.encoder = PoNetEncoder(config) + + self.pooler = PoNetPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + segment_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + encoder_hidden_states + (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` + with each tuple having 4 tensors of shape :obj: + `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> print(model(**preprocessor('这是个测试'))) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + segment_index = get_segment_index( + input_ids) if segment_ids is None else segment_ids + token_type_mask = get_token_type_mask(input_ids) + encoder_outputs = self.encoder( + embedding_output, + segment_index, + token_type_mask, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return AttentionBackboneModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/modelscope/models/nlp/ponet/configuration.py b/modelscope/models/nlp/ponet/configuration.py new file mode 100644 index 00000000..7dfaba48 --- /dev/null +++ b/modelscope/models/nlp/ponet/configuration.py @@ -0,0 +1,115 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PoNet model configuration, mainly copied from :class:`~transformers.BertConfig` """ +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class PoNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration + of a :class:`~modelscope.models.nlp.ponet.PoNetModel`. + It is used to instantiate a PoNet model according to the specified arguments. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): + Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, + :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on + :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) + `__. For more information on :obj:`"relative_key_query"`, please refer to + `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) + `__. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. + clsgsepg (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not use a trick to make sure the segment and local information will not leak. + """ + model_type = 'ponet' + + def __init__(self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type='absolute', + use_cache=True, + classifier_dropout=None, + clsgsepg=True, + **kwargs): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.clsgsepg = clsgsepg diff --git a/modelscope/models/nlp/ponet/fill_mask.py b/modelscope/models/nlp/ponet/fill_mask.py new file mode 100644 index 00000000..fb09efc0 --- /dev/null +++ b/modelscope/models/nlp/ponet/fill_mask.py @@ -0,0 +1,252 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .backbone import PoNetModel, PoNetPreTrainedModel + +logger = get_logger(__name__) + + +class PoNetPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class PoNetLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = PoNetPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class PoNetOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = PoNetLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) +class PoNetForMaskedLM(PoNetPreTrainedModel): + r"""PoNet Model with a `language modeling` head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of PoNet, the preprocessor of this model + is `modelscope.preprocessors.FillMaskPoNetPreprocessor`. + + Parameters: + config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.ponet = PoNetModel(config, add_pooling_layer=False) + self.cls = PoNetOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + segment_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length', hidden_size)`, + `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ponet( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + segment_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) diff --git a/modelscope/models/nlp/ponet/tokenization.py b/modelscope/models/nlp/ponet/tokenization.py new file mode 100644 index 00000000..2da91545 --- /dev/null +++ b/modelscope/models/nlp/ponet/tokenization.py @@ -0,0 +1,156 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for PoNet """ + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from transformers.file_utils import PaddingStrategy +from transformers.models.bert.tokenization_bert import BertTokenizer +from transformers.tokenization_utils import BatchEncoding, EncodedInput + +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': ModelFile.VOCAB_FILE} + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'nlp_ponet_fill-mask_chinese-base': 512, + 'nlp_ponet_fill-mask_english-base': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'nlp_ponet_fill-mask_chinese-base': { + 'do_lower_case': True + }, + 'nlp_ponet_fill-mask_english-base': { + 'do_lower_case': True + }, +} + + +class PoNetTokenizer(BertTokenizer): + r""" + Construct an PoNet tokenizer. Based on BertTokenizer. + + This tokenizer inherits from :class:`~transformers.BertTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning + parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or + batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_attention_mask: (optional) Set to False to avoid returning + attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = 'attention_mask' in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and ( + max_length % pad_to_multiple_of != 0): + max_length = ( + (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( + required_input) != max_length + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == 'right': + if return_attention_mask: + encoded_inputs['attention_mask'] = [1] * len( + required_input) + [0] * difference + if 'token_type_ids' in encoded_inputs: + encoded_inputs['token_type_ids'] = ( + encoded_inputs['token_type_ids'] + + [self.pad_token_type_id] * difference) + if 'special_tokens_mask' in encoded_inputs: + encoded_inputs['special_tokens_mask'] = encoded_inputs[ + 'special_tokens_mask'] + [1] * difference + if 'segment_ids' in encoded_inputs: + encoded_inputs[ + 'segment_ids'] = encoded_inputs['segment_ids'] + [ + encoded_inputs['segment_ids'][-1] + 1 + ] * difference # noqa * + encoded_inputs[self.model_input_names[ + 0]] = required_input + [self.pad_token_id] * difference + elif self.padding_side == 'left': + if return_attention_mask: + encoded_inputs['attention_mask'] = [0] * difference + [ + 1 + ] * len(required_input) + if 'token_type_ids' in encoded_inputs: + encoded_inputs['token_type_ids'] = [ + self.pad_token_type_id + ] * difference + encoded_inputs['token_type_ids'] + if 'segment_ids' in encoded_inputs: + encoded_inputs['segment_ids'] = [encoded_inputs['segment_ids'][-1] + 1] * difference + \ + encoded_inputs['segment_ids'] # noqa * + if 'special_tokens_mask' in encoded_inputs: + encoded_inputs['special_tokens_mask'] = [ + 1 + ] * difference + encoded_inputs['special_tokens_mask'] + encoded_inputs[self.model_input_names[ + 0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError('Invalid padding strategy:' + + str(self.padding_side)) + elif return_attention_mask and 'attention_mask' not in encoded_inputs: + encoded_inputs['attention_mask'] = [1] * len(required_input) + + return encoded_inputs diff --git a/modelscope/models/nlp/space/__init__.py b/modelscope/models/nlp/space/__init__.py new file mode 100644 index 00000000..32713c34 --- /dev/null +++ b/modelscope/models/nlp/space/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import SpaceGenerator + from .model import SpaceModelBase, SpaceTokenizer + from .dialog_intent_prediction import SpaceForDialogIntent + from .dialog_modeling import SpaceForDialogModeling + from .dialog_state_tracking import SpaceForDST + from .configuration import SpaceConfig +else: + _import_structure = { + 'model': ['SpaceGenerator', 'SpaceModelBase', 'SpaceTokenizer'], + 'dialog_intent_prediction': ['SpaceForDialogIntent'], + 'dialog_modeling': ['SpaceForDialogModeling'], + 'dialog_state_tracking': ['SpaceForDST'], + 'configuration': ['SpaceConfig'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/space/configuration.py b/modelscope/models/nlp/space/configuration.py new file mode 100644 index 00000000..0da2d629 --- /dev/null +++ b/modelscope/models/nlp/space/configuration.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors. +# Copyright 2020 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Space configuration, mainly copied from :class:`~transformers.configuration_xlm_roberta` """ + +from modelscope.models.nlp.structbert import SbertConfig +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class SpaceConfig(SbertConfig): + """ + This class overrides [`SbertConfig`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + model_type = 'space' diff --git a/modelscope/models/nlp/space/dialog_intent_prediction.py b/modelscope/models/nlp/space/dialog_intent_prediction.py new file mode 100644 index 00000000..79ff01cd --- /dev/null +++ b/modelscope/models/nlp/space/dialog_intent_prediction.py @@ -0,0 +1,99 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase +from modelscope.preprocessors.nlp import IntentBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['SpaceForDialogIntent'] + + +@MODELS.register_module( + Tasks.task_oriented_conversation, module_name=Models.space_intent) +class SpaceForDialogIntent(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the test generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + text_field (`BPETextField`, *optional*, defaults to `IntentBPETextField`): + The text field. + config (`Config`, *optional*, defaults to config in model hub): + The config. + """ + + super().__init__(model_dir, *args, **kwargs) + from modelscope.trainers.nlp.space.trainer.intent_trainer import \ + IntentTrainer + self.model_dir = model_dir + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + self.text_field = kwargs.pop( + 'text_field', + IntentBPETextField(self.model_dir, config=self.config)) + + self.generator = SpaceGenerator.create( + self.config, reader=self.text_field) + self.model = SpaceModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = IntentTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 + 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 + 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 + 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32), + } + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDialogIntent + >>> from modelscope.preprocessors import DialogIntentPredictionPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-intent-prediction') + >>> preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) + >>> model = SpaceForDialogIntent( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + >>> print(model(preprocessor("What do I need to do for the card activation?"))) + """ + import numpy as np + pred = self.trainer.forward(input) + pred = np.squeeze(pred[0], 0) + + return {'pred': pred} diff --git a/modelscope/models/nlp/space/dialog_modeling.py b/modelscope/models/nlp/space/dialog_modeling.py new file mode 100644 index 00000000..16e9dc53 --- /dev/null +++ b/modelscope/models/nlp/space/dialog_modeling.py @@ -0,0 +1,118 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase +from modelscope.preprocessors.nlp import MultiWOZBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['SpaceForDialogModeling'] + + +@MODELS.register_module( + Tasks.task_oriented_conversation, module_name=Models.space_modeling) +class SpaceForDialogModeling(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the test generation model from the `model_dir` path. + + Args: + model_dir (`str`): + The model path. + text_field (`BPETextField`, *optional*, defaults to `MultiWOZBPETextField`): + The text field. + config (`Config`, *optional*, defaults to config in model hub): + The config. + """ + + super().__init__(model_dir, *args, **kwargs) + from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer + self.model_dir = model_dir + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + + import torch + self.config.use_gpu = True if ( + 'device' not in kwargs or kwargs['device'] + == 'gpu') and torch.cuda.is_available() else False + + self.text_field = kwargs.pop( + 'text_field', + MultiWOZBPETextField(config=self.config, model_dir=self.model_dir)) + self.generator = SpaceGenerator.create( + self.config, reader=self.text_field) + self.model = SpaceModelBase.create( + model_dir=model_dir, + config=self.config, + reader=self.text_field, + generator=self.generator) + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.config.use_gpu else array + + self.trainer = MultiWOZTrainer( + model=self.model, + to_tensor=to_tensor, + config=self.config, + reader=self.text_field, + evaluator=None) + self.trainer.load() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'labels': array([1,192,321,12]), # lable + 'resp': array([293,1023,123,1123]), #vocab label for response + 'bspn': array([123,321,2,24,1 ]), + 'aspn': array([47,8345,32,29,1983]), + 'db': array([19, 24, 20]), + } + Examples: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDialogModeling + >>> from modelscope.preprocessors import DialogModelingPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-modeling') + >>> preprocessor = DialogModelingPreprocessor(model_dir=cache_path) + >>> model = SpaceForDialogModeling(model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + >>> print(model(preprocessor({ + 'user_input': 'i would like a taxi from saint john \'s college to pizza hut fen ditton .', + 'history': {} + }))) + """ + + first_turn = input['first_turn'] + batch = input['batch'] + prompt_id = input['prompt_id'] + labels = input['labels'] + old_pv_turn = input['history'] + + pv_turn = self.trainer.forward( + first_turn=first_turn, + batch=batch, + prompt_id=prompt_id, + labels=labels, + old_pv_turn=old_pv_turn) + + return pv_turn diff --git a/modelscope/models/nlp/space/dialog_state_tracking.py b/modelscope/models/nlp/space/dialog_state_tracking.py new file mode 100644 index 00000000..9a713a59 --- /dev/null +++ b/modelscope/models/nlp/space/dialog_state_tracking.py @@ -0,0 +1,392 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Space model. mainly copied from :module:`~transformers.modeling_xlm_roberta`""" + +from typing import Dict + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.file_utils import add_start_docstrings +from transformers.modeling_utils import PreTrainedModel + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.models.nlp.structbert import (SbertForMaskedLM, SbertModel, + SbertPreTrainedModel) +from modelscope.utils.constant import Tasks +from .configuration import SpaceConfig + +SPACE_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`SpaceConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. +""" + + +@add_start_docstrings( + 'The bare Space Model transformer outputting raw hidden-states without any specific head on top. ' + 'It is identical with the Bert Model from Transformers', + SPACE_START_DOCSTRING, +) +class SpaceModel(SbertModel): + """ + This class overrides [`SbertModel`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = SpaceConfig + + +class SpacePreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SpaceConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + @param kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = SpaceConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model + + +@add_start_docstrings( + """ + Space Model transformer with Dialog state tracking heads on top (a inform projection + layer with a dialog state layer and a set of slots including history infromation from + previous dialog) e.g. for multiwoz2.2 tasks. + """, + SPACE_START_DOCSTRING, +) +@MODELS.register_module( + Tasks.task_oriented_conversation, module_name=Models.space_dst) +class SpaceForDST(SpacePreTrainedModel): + + def __init__(self, config): + super(SpaceForDST, self).__init__(config) + self.slot_list = config.dst_slot_list + self.class_types = config.dst_class_types + self.class_labels = config.dst_class_labels + self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable + self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable + self.class_aux_feats_inform = config.dst_class_aux_feats_inform + self.class_aux_feats_ds = config.dst_class_aux_feats_ds + self.class_loss_ratio = config.dst_class_loss_ratio + + # Only use refer loss if refer class is present in dataset. + if 'refer' in self.class_types: + self.refer_index = self.class_types.index('refer') + else: + self.refer_index = -1 + + self.bert = SpaceModel(config) + self.dropout = nn.Dropout(config.dst_dropout_rate) + self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) + + if self.class_aux_feats_inform: + self.add_module( + 'inform_projection', + nn.Linear(len(self.slot_list), len(self.slot_list))) + if self.class_aux_feats_ds: + self.add_module( + 'ds_projection', + nn.Linear(len(self.slot_list), len(self.slot_list))) + + aux_dims = len(self.slot_list) * ( + self.class_aux_feats_inform + self.class_aux_feats_ds + ) # second term is 0, 1 or 2 + + for slot in self.slot_list: + self.add_module( + 'class_' + slot, + nn.Linear(config.hidden_size + aux_dims, self.class_labels)) + self.add_module('token_' + slot, nn.Linear(config.hidden_size, 2)) + self.add_module( + 'refer_' + slot, + nn.Linear(config.hidden_size + aux_dims, + len(self.slot_list) + 1)) + + self.init_weights() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'inputs': dict(input_ids, input_masks,start_pos), # tracking states + 'outputs': dict(slots_logits), + 'unique_ids': str(test-example.json-0), # default value + 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) + 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'prefix': str('final'), #default value + 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) + } + + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import SpaceForDST + >>> from modelscope.preprocessors import DialogStateTrackingPreprocessor + >>> cache_path = snapshot_download('damo/nlp_space_dialog-state-tracking') + >>> model = SpaceForDST.from_pretrained(cache_path) + >>> preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) + >>> print(model(preprocessor({ + 'utter': { + 'User-1': "Hi, I'm looking for a train that is going" + "to cambridge and arriving there by 20:45, is there anything like that?" + }, + 'history_states': [{}] + }))) + """ + import numpy as np + import torch + + # self.model.eval() ???? + batch = input['batch'] + + features = input['features'] + diag_state = input['diag_state'] + turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in self.config.dst_slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + inputs = { + 'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': diag_state, + 'class_label_id': batch[8] + } + unique_ids = [features[i.item()].guid for i in batch[9]] + values = [features[i.item()].values for i in batch[9]] + input_ids_unmasked = [ + features[i.item()].input_ids_unmasked for i in batch[9] + ] + inform = [features[i.item()].inform for i in batch[9]] + outputs = self._forward(**inputs) + + # Update dialog state for next turn. + for slot in self.config.dst_slot_list: + updates = outputs[2][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + return { + 'inputs': inputs, + 'outputs': outputs, + 'unique_ids': unique_ids, + 'input_ids_unmasked': input_ids_unmasked, + 'values': values, + 'inform': inform, + 'prefix': 'final', + 'ds': input['ds'] + } + + def _forward(self, + input_ids, + input_mask=None, + segment_ids=None, + position_ids=None, + head_mask=None, + start_pos=None, + end_pos=None, + inform_slot_id=None, + refer_id=None, + class_label_id=None, + diag_state=None): + outputs = self.bert( + input_ids, + attention_mask=input_mask, + token_type_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask) + + sequence_output = outputs.last_hidden_state + pooled_output = outputs.pooler_output + + sequence_output = self.dropout(sequence_output) + pooled_output = self.dropout(pooled_output) + + # TODO: establish proper format in labels already? + if inform_slot_id is not None: + inform_labels = torch.stack(list(inform_slot_id.values()), + 1).float() + if diag_state is not None: + diag_state_labels = torch.clamp( + torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) + + total_loss = 0 + per_slot_per_example_loss = {} + per_slot_class_logits = {} + per_slot_start_logits = {} + per_slot_end_logits = {} + per_slot_refer_logits = {} + for slot in self.slot_list: + if self.class_aux_feats_inform and self.class_aux_feats_ds: + pooled_output_aux = torch.cat( + (pooled_output, self.inform_projection(inform_labels), + self.ds_projection(diag_state_labels)), 1) + elif self.class_aux_feats_inform: + pooled_output_aux = torch.cat( + (pooled_output, self.inform_projection(inform_labels)), 1) + elif self.class_aux_feats_ds: + pooled_output_aux = torch.cat( + (pooled_output, self.ds_projection(diag_state_labels)), 1) + else: + pooled_output_aux = pooled_output + class_logits = self.dropout_heads( + getattr(self, 'class_' + slot)(pooled_output_aux)) + + token_logits = self.dropout_heads( + getattr(self, 'token_' + slot)(sequence_output)) + start_logits, end_logits = token_logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + refer_logits = self.dropout_heads( + getattr(self, 'refer_' + slot)(pooled_output_aux)) + + per_slot_class_logits[slot] = class_logits + per_slot_start_logits[slot] = start_logits + per_slot_end_logits[slot] = end_logits + per_slot_refer_logits[slot] = refer_logits + + # If there are no labels, don't compute loss + if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: + # If we are on multi-GPU, split add a dimension + if len(start_pos[slot].size()) > 1: + start_pos[slot] = start_pos[slot].squeeze(-1) + if len(end_pos[slot].size()) > 1: + end_pos[slot] = end_pos[slot].squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) # This is a single index + start_pos[slot].clamp_(0, ignored_index) + end_pos[slot].clamp_(0, ignored_index) + + class_loss_fct = CrossEntropyLoss(reduction='none') + token_loss_fct = CrossEntropyLoss( + reduction='none', ignore_index=ignored_index) + refer_loss_fct = CrossEntropyLoss(reduction='none') + + start_loss = token_loss_fct(start_logits, start_pos[slot]) + end_loss = token_loss_fct(end_logits, end_pos[slot]) + token_loss = (start_loss + end_loss) / 2.0 + + token_is_pointable = (start_pos[slot] > 0).float() + if not self.token_loss_for_nonpointable: + token_loss *= token_is_pointable + + refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) + token_is_referrable = torch.eq(class_label_id[slot], + self.refer_index).float() + if not self.refer_loss_for_nonpointable: + refer_loss *= token_is_referrable + + class_loss = class_loss_fct(class_logits, class_label_id[slot]) + + if self.refer_index > -1: + per_example_loss = (self.class_loss_ratio) * class_loss + ( + (1 - self.class_loss_ratio) / 2) * token_loss + ( + (1 - self.class_loss_ratio) / 2) * refer_loss + else: + per_example_loss = self.class_loss_ratio * class_loss + ( + 1 - self.class_loss_ratio) * token_loss + + total_loss += per_example_loss.sum() + per_slot_per_example_loss[slot] = per_example_loss + + # add hidden states and attention if they are here + outputs = (total_loss, ) + ( + per_slot_per_example_loss, + per_slot_class_logits, + per_slot_start_logits, + per_slot_end_logits, + per_slot_refer_logits, + ) + (outputs.embedding_output, ) + + return outputs diff --git a/modelscope/models/nlp/space/model/__init__.py b/modelscope/models/nlp/space/model/__init__.py new file mode 100644 index 00000000..cfff335d --- /dev/null +++ b/modelscope/models/nlp/space/model/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .gen_unified_transformer import GenUnifiedTransformer +from .generator import SpaceGenerator +from .intent_unified_transformer import IntentUnifiedTransformer +from .model_base import SpaceModelBase +from .tokenization_space import (BasicTokenizer, SpaceTokenizer, + WordpieceTokenizer) +from .unified_transformer import UnifiedTransformer diff --git a/modelscope/models/nlp/space/model/gen_unified_transformer.py b/modelscope/models/nlp/space/model/gen_unified_transformer.py new file mode 100644 index 00000000..c5d50cd9 --- /dev/null +++ b/modelscope/models/nlp/space/model/gen_unified_transformer.py @@ -0,0 +1,283 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch + +from .unified_transformer import UnifiedTransformer + + +class GenUnifiedTransformer(UnifiedTransformer): + """ + Implement generation unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator): + super(GenUnifiedTransformer, self).__init__(model_dir, config, reader, + generator) + self.understand = config.BPETextField.understand + + if self.use_gpu: + self.cuda() + return + + def _forward(self, inputs, is_training, with_label): + """ Real forward process of model in different mode(train/test). """ + + def cat(x, y, dim=1): + return torch.cat([x, y], dim=dim) + + outputs = {} + + if self.understand or self.policy: + if self.understand: + prompt_token = inputs['understand_token'] + prompt_mask = inputs['understand_mask'] + if self.policy: + prompt_token = cat(prompt_token, inputs['policy_token']) + prompt_mask = cat(prompt_mask, inputs['policy_mask']) + else: + prompt_token = inputs['policy_token'] + prompt_mask = inputs['policy_mask'] + + enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'][:, :-1], + tgt_mask=inputs['tgt_mask'][:, :-1], + prompt_token=prompt_token, + prompt_mask=prompt_mask, + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn'], + tgt_pos=inputs['tgt_pos'][:, :-1], + tgt_type=inputs['tgt_type'][:, :-1], + tgt_turn=inputs['tgt_turn'][:, :-1]) + else: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'][:, :-1], + tgt_mask=inputs['tgt_mask'][:, :-1], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn'], + tgt_pos=inputs['tgt_pos'][:, :-1], + tgt_type=inputs['tgt_type'][:, :-1], + tgt_turn=inputs['tgt_turn'][:, :-1]) + + outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed) + return outputs + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + + metrics = {} + loss = 0. + + label = inputs['tgt_token'][:, 1:] + token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1) + nll = self.nll_loss( + torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label) + nll = torch.sum(nll, dim=1) + token_nll = torch.sum(nll) / token_num + nll = torch.mean(nll) + metrics['nll'] = nll + metrics['token_nll'] = token_nll + metrics['token_num'] = token_num + loss = loss + (token_nll if self.token_loss else nll) + + metrics['loss'] = loss + if self.gpu > 1: + return nll, token_nll, token_num + else: + return metrics + + def _optimize(self, loss, do_update=False, optimizer=None): + """ Optimize loss function and update model. """ + assert optimizer is not None + + if self.gradient_accumulation_steps > 1: + loss = loss / self.gradient_accumulation_steps + + loss.backward() + + if self.grad_clip is not None and self.grad_clip > 0: + torch.nn.utils.clip_grad_norm_( + parameters=self.parameters(), max_norm=self.grad_clip) + + if do_update: + optimizer.step() + optimizer.zero_grad() + + return + + def _init_state(self, + src_token, + src_mask, + src_pos=None, + src_type=None, + src_turn=None): + """ Initialize decode state. """ + state = {} + batch_size = src_token.shape[0] + + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + src_embed = self.embed_layer_norm(src_embed) + + mask = self._create_mask(src_mask, append_head=False) + + enc_out = src_embed + + cache = {} + for _l, layer in enumerate(self.layers): + cache[f'layer_{_l}'] = {} + enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) + + state['cache'] = cache + state['mask'] = mask[:, :1] + state['batch_size'] = batch_size + shape = [batch_size, 1, 1] + state['pred_mask'] = torch.ones(shape, dtype=torch.float32) + state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) + state['pred_type'] = torch.zeros(shape, dtype=torch.int64) + state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) + if self.use_gpu: + state['pred_mask'] = state['pred_mask'].cuda() + state['pred_pos'] = state['pred_pos'].cuda() + state['pred_type'] = state['pred_type'].cuda() + state['pred_turn'] = state['pred_turn'].cuda() + + return state + + def _init_prompt_state(self, + src_token, + src_mask, + prompt_token, + prompt_mask, + src_pos=None, + src_type=None, + src_turn=None, + prompt_pos=None, + prompt_type=None, + prompt_turn=None): + """ Initialize decode state. """ + state = {} + batch_size = src_token.shape[0] + + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, + prompt_turn) + embed = torch.cat([src_embed, prompt_embed], dim=1) + embed = self.embed_layer_norm(embed) + enc_out = embed + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask(prompt_mask, auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + cache = {} + for _l, layer in enumerate(self.layers): + cache[f'layer_{_l}'] = {} + enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) + + state['cache'] = cache + state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] + state['batch_size'] = batch_size + shape = [batch_size, 1, 1] + state['pred_mask'] = torch.ones(shape, dtype=torch.float32) + state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) + state['pred_type'] = torch.zeros(shape, dtype=torch.int64) + state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) + if self.use_gpu: + state['pred_mask'] = state['pred_mask'].cuda() + state['pred_pos'] = state['pred_pos'].cuda() + state['pred_type'] = state['pred_type'].cuda() + state['pred_turn'] = state['pred_turn'].cuda() + + return state + + def _decode(self, state): + """ Decoding one time stamp. """ + + # shape: [batch_size, 1, seq_len] + mask = state['mask'] + + # shape: [batch_size, 1, 1] + pred_token = state['pred_token'] + pred_mask = state['pred_mask'] + pred_pos = state['pred_pos'] + pred_type = state['pred_type'] + pred_turn = state['pred_turn'] + + # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] + cache = state['cache'] + + pred_embed = self.embedder(pred_token, pred_pos, pred_type, + pred_turn).squeeze(-2) + pred_embed = self.embed_layer_norm(pred_embed) + + # shape: [batch_size, 1, seq_len + 1] + mask = torch.cat([mask, 1 - pred_mask], dim=2) + + # shape: [batch_size, 1, hidden_dim] + for _l, layer in enumerate(self.layers): + pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}']) + + # shape: [batch_size, vocab_size] + pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) + pred_logits = torch.log(pred_probs) + + state['mask'] = mask + return pred_logits, state + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + + def cat(x, y, dim=1): + return torch.cat([x, y], dim=dim) + + # Initial decode state. + if self.understand or self.policy: + if self.understand: + prompt_token = inputs['understand_token'] + prompt_mask = inputs['understand_mask'] + if self.policy: + prompt_token = cat(prompt_token, inputs['policy_token']) + prompt_mask = cat(prompt_mask, inputs['policy_mask']) + else: + prompt_token = inputs['policy_token'] + prompt_mask = inputs['policy_mask'] + + state = self._init_prompt_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + prompt_token=prompt_token, + prompt_mask=prompt_mask, + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + else: + state = self._init_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + + # Generation process. + gen_results = self.generator( + step_fn=self._decode, + state=state, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + + outputs = gen_results['preds'] + return outputs + + +GenUnifiedTransformer.register('GenUnifiedTransformer') diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py new file mode 100644 index 00000000..2e05b545 --- /dev/null +++ b/modelscope/models/nlp/space/model/generator.py @@ -0,0 +1,282 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import numpy as np +import torch + + +def repeat(var, times): + if isinstance(var, list): + return [repeat(x, times) for x in var] + elif isinstance(var, dict): + return {k: repeat(v, times) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + var = var.unsqueeze(1) + expand_times = [1] * len(var.shape) + expand_times[1] = times + dtype = var.dtype + var = var.float() + var = var.repeat(*expand_times) + shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) + var = var.reshape(*shape) + var = torch.tensor(var, dtype=dtype) + return var + else: + return var + + +def gather(var, idx): + if isinstance(var, list): + return [gather(x, idx) for x in var] + elif isinstance(var, dict): + return {k: gather(v, idx) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + out = var.index_select(dim=0, index=idx) + return out + else: + return var + + +class SpaceGenerator(object): + """ Genrator class. """ + + _registry = dict() + + @classmethod + def register(cls, name): + SpaceGenerator._registry[name] = cls + return + + @staticmethod + def by_name(name): + return SpaceGenerator._registry[name] + + @staticmethod + def create(config, *args, **kwargs): + """ Create generator. """ + generator_cls = SpaceGenerator.by_name(config.Generator.generator) + return generator_cls(config, *args, **kwargs) + + def __init__(self, config, reader): + self.vocab_size = reader.vocab_size + self.bos_id = reader.bos_id + self.eos_id = reader.eos_id + self.unk_id = reader.unk_id + self.pad_id = reader.pad_id + self.min_gen_len = config.Generator.min_gen_len + self.max_gen_len = config.Generator.max_gen_len + self.use_gpu = config.use_gpu + assert 1 <= self.min_gen_len <= self.max_gen_len + return + + def __call__(self, step_fn, state): + """Running generation. + + Args: + step_fn (`function`) : decoding one step + state(`dict`) : initial state + """ + raise NotImplementedError + + +class BeamSearch(SpaceGenerator): + """ BeamSearch generator. """ + + def __init__(self, config, reader): + super().__init__(config, reader) + self.beam_size = config.Generator.beam_size + self.length_average = config.Generator.length_average + self.length_penalty = config.Generator.length_penalty + self.ignore_unk = config.Generator.ignore_unk + return + + def __call__(self, + step_fn, + state, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ + Running beam search. + + Args: + step_fn(`function`) : decoding one step + state(`dict`) : initial state + """ + if prev_input is not None: + + if isinstance(prev_input, list): + length = max(list(map(lambda x: len(x), prev_input))) + prev_input_numpy = np.full((len(prev_input), length), + self.pad_id) + for i, x in enumerate(prev_input): + prev_input_numpy[i, :len(x)] = x + prev_input_tensor = torch.from_numpy(prev_input_numpy) + if self.use_gpu: + prev_input_tensor = prev_input_tensor.cuda() + + for i in range(length): + state['pred_token'] = prev_input_tensor[:, i].unsqueeze( + -1).unsqueeze(-1) + if i != 0: + state['pred_mask'] = torch.not_equal( + state['pred_token'], self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + state[ + 'pred_mask'].int() + _, state = step_fn(state) + else: + assert isinstance(prev_input, torch.Tensor) + for i, input in enumerate(prev_input): + state['pred_token'] = input.expand(1, 1, 1) + if i != 0: + state['pred_mask'] = torch.not_equal( + state['pred_token'], self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + _, state = step_fn(state) + + batch_size = state['batch_size'] + beam_size = self.beam_size + + # shape: [batch_size, 1] + pos_index = torch.arange( + 0, batch_size, 1, dtype=torch.int64) * beam_size + pos_index = pos_index.unsqueeze(1) + + # shape: [batch_size, beam_size, 1] + if start_id is None: + start_id = self.bos_id + if eos_id is None: + eos_id = self.eos_id + predictions = torch.ones([batch_size, beam_size, 1], + dtype=torch.int64) * start_id + + if self.use_gpu: + pos_index = pos_index.cuda() + predictions = predictions.cuda() + + # initial input (start_id) + state['pred_token'] = predictions[:, :1] + if prev_input is not None: + state['pred_mask'] = torch.not_equal(state['pred_token'], + self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + + # shape: [batch_size, vocab_size] + scores, state = step_fn(state) + + unk_penalty = np.zeros(self.vocab_size, dtype='float32') + unk_penalty[self.unk_id] = -1e10 + unk_penalty = torch.from_numpy(unk_penalty) + + eos_penalty = np.zeros(self.vocab_size, dtype='float32') + eos_penalty[eos_id] = -1e10 + eos_penalty = torch.from_numpy(eos_penalty) + + scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') + scores_after_end[ + self. + pad_id] = 0 # we want is generated after ,so maximum log(p()) is (0) + scores_after_end = torch.from_numpy(scores_after_end) + + if self.use_gpu: + unk_penalty = unk_penalty.cuda() + eos_penalty = eos_penalty.cuda() + scores_after_end = scores_after_end.cuda() + + if self.ignore_unk: + scores = scores + unk_penalty + scores = scores + eos_penalty + + # shape: [batch_size, beam_size] + sequence_scores, preds = torch.topk(scores, self.beam_size) + + predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) + state = repeat(state, beam_size) + + if max_gen_len is None: + max_gen_len = self.max_gen_len + for step in range(2, max_gen_len + 1): + pre_ids = predictions[:, :, -1:] + state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1) + state['pred_mask'] = torch.not_equal(state['pred_token'], + self.pad_id).float() + state['pred_pos'] = state['pred_pos'] + 1 + scores, state = step_fn(state) + + # Generate next + # scores shape: [batch_size * beam_size, vocab_size] + if self.ignore_unk: + scores = scores + unk_penalty + + if step <= self.min_gen_len: + scores = scores + eos_penalty + + # scores shape: [batch_size, beam_size, vocab_size] + scores = scores.reshape(batch_size, beam_size, self.vocab_size) + + # previous token is [PAD] or [EOS] + pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ + (1 - torch.not_equal(pre_ids, self.pad_id).float()) + + scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat( + 1, 1, self.vocab_size) * scores_after_end + if self.length_average: + scaled_value = \ + pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step) + sequence_scores = sequence_scores.unsqueeze(2) * scaled_value + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step) + scores = scores * scaled_value + elif self.length_penalty >= 0.0: + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ + (math.pow((4 + step) / (5 + step), self.length_penalty)) + sequence_scores = scaled_value * sequence_scores + scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ + (math.pow(1 / (5 + step), self.length_penalty)) + scores = scores * scaled_value + scores = scores + sequence_scores.unsqueeze(-1) + scores = scores.reshape(batch_size, beam_size * self.vocab_size) + + topk_scores, topk_indices = torch.topk(scores, beam_size) + # topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped) + parent_idx = topk_indices.floor_divide(self.vocab_size) + preds = topk_indices % self.vocab_size + + # Gather state / sequence_scores + parent_idx = parent_idx + pos_index + parent_idx = parent_idx.reshape(batch_size * beam_size) + state = gather(state, parent_idx) + sequence_scores = topk_scores + + predictions = predictions.reshape(batch_size * beam_size, step) + predictions = gather(predictions, parent_idx) + predictions = predictions.reshape(batch_size, beam_size, step) + predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) + + # The last token should be or + pre_ids = predictions[:, :, -1] + pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ + (1 - torch.not_equal(pre_ids, self.pad_id).float()) + sequence_scores = sequence_scores * pre_eos_mask + ( + 1 - pre_eos_mask) * (-1e10) + + # first get ascending ordered index,then sort "predictions" and "sequence_scores" + indices = torch.argsort(sequence_scores, dim=1) + indices = indices + pos_index + indices = indices.reshape(-1) + sequence_scores = sequence_scores.reshape(batch_size * beam_size) + predictions = predictions.reshape(batch_size * beam_size, -1) + sequence_scores = gather(sequence_scores, indices) + predictions = gather(predictions, indices) + sequence_scores = sequence_scores.reshape(batch_size, beam_size) + predictions = predictions.reshape(batch_size, beam_size, -1) + + results = { + 'preds': predictions[:, -1], + 'scores': sequence_scores[:, -1] + } + return results + + +BeamSearch.register('BeamSearch') diff --git a/modelscope/models/nlp/space/model/intent_unified_transformer.py b/modelscope/models/nlp/space/model/intent_unified_transformer.py new file mode 100644 index 00000000..11385a6f --- /dev/null +++ b/modelscope/models/nlp/space/model/intent_unified_transformer.py @@ -0,0 +1,197 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.nlp.space.criterions import compute_kl_loss +from .unified_transformer import UnifiedTransformer + + +class IntentUnifiedTransformer(UnifiedTransformer): + """ + Implement intent unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator): + super(IntentUnifiedTransformer, self).__init__(model_dir, config, + reader, generator) + self.example = config.Model.example + self.num_intent = config.Model.num_intent + self.with_rdrop = config.Model.with_rdrop + self.kl_ratio = config.Model.kl_ratio + self.loss_fct = nn.CrossEntropyLoss() + if self.example: + self.loss_fct = nn.NLLLoss() + else: + self.intent_classifier = nn.Linear(self.hidden_dim, + self.num_intent) + self.loss_fct = nn.CrossEntropyLoss() + + if self.use_gpu: + self.cuda() + return + + def _forward(self, inputs, is_training, with_label): + """ Real forward process of model in different mode(train/test). """ + + def aug(v): + assert isinstance(v, torch.Tensor) + return torch.cat([v, v], dim=0) + + outputs = {} + + if self.with_mlm: + mlm_embed = self._encoder_network( + input_token=inputs['mlm_token'], + input_mask=inputs['src_mask'], + input_pos=inputs['src_pos'], + input_type=inputs['src_type'], + input_turn=inputs['src_turn']) + outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed) + + if self.with_rdrop or self.with_contrastive: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=aug(inputs['src_token']), + src_mask=aug(inputs['src_mask']), + tgt_token=aug(inputs['tgt_token']), + tgt_mask=aug(inputs['tgt_mask']), + src_pos=aug(inputs['src_pos']), + src_type=aug(inputs['src_type']), + src_turn=aug(inputs['src_turn'])) + else: + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'], + tgt_mask=inputs['tgt_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + features = dec_embed[:, -1] + features = self.pooler(features) if self.with_pool else features + + if self.example: + assert not self.with_rdrop + ex_enc_embed, ex_dec_embed = self._encoder_decoder_network( + src_token=inputs['example_src_token'], + src_mask=inputs['example_src_mask'], + tgt_token=inputs['example_tgt_token'], + tgt_mask=inputs['example_tgt_mask'], + src_pos=inputs['example_src_pos'], + src_type=inputs['example_src_type'], + src_turn=inputs['example_src_turn']) + ex_features = ex_dec_embed[:, -1] + ex_features = self.pooler( + ex_features) if self.with_pool else ex_features + + probs = self.softmax(features.mm(ex_features.t())) + example_intent = inputs['example_intent'].unsqueeze(0) + intent_probs = torch.zeros(probs.size(0), self.num_intent) + intent_probs = intent_probs.cuda( + ) if self.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, example_intent.repeat(probs.size(0), 1), probs) + outputs['intent_probs'] = intent_probs + else: + intent_logits = self.intent_classifier(features) + outputs['intent_logits'] = intent_logits + + if self.with_contrastive: + features = features if self.with_pool else self.pooler(features) + batch_size = features.size(0) // 2 + features = \ + torch.cat( + [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], + dim=1 + ) + features = F.normalize(features, dim=-1, p=2) + outputs['features'] = features + + return outputs + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + + metrics = {} + batch_size = inputs['src_token'].size(0) + + intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \ + if self.with_rdrop or self.with_contrastive else inputs['intent_label'] + + if self.example: + intent_loss = self.loss_fct( + torch.log(outputs['intent_probs'] + 1e-12).view( + -1, self.num_intent), intent_label.type(torch.long)) + else: + intent_loss = self.loss_fct( + outputs['intent_logits'].view(-1, self.num_intent), + intent_label.type(torch.long)) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if self.with_mlm: + mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1)) + mlm = self.nll_loss( + torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1), + inputs['mlm_label']) + mlm = torch.sum(mlm, dim=1) + token_mlm = torch.sum(mlm) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm'] = mlm + metrics['token_mlm'] = token_mlm + metrics['mlm_num'] = mlm_num + loss = loss + (token_mlm + if self.token_loss else mlm) * self.mlm_ratio + else: + mlm, token_mlm, mlm_num = None, None, None + + if self.with_rdrop: + kl = compute_kl_loss( + p=outputs['intent_logits'][:batch_size], + q=outputs['intent_logits'][batch_size:]) + metrics['kl'] = kl + loss = loss + kl * self.kl_ratio + else: + kl = None + + if self.with_contrastive: + pass + con = None + else: + con = None + + metrics['loss'] = loss + + if self.gpu > 1: + return intent_loss, mlm, token_mlm, mlm_num, kl, con + else: + return metrics + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + results = {} + enc_embed, dec_embed = self._encoder_decoder_network( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + tgt_token=inputs['tgt_token'], + tgt_mask=inputs['tgt_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + features = dec_embed[:, -1] + features = self.pooler(features) if self.with_pool else features + if self.example: + results['features'] = features + else: + intent_logits = self.intent_classifier(features) + intent_probs = self.softmax(intent_logits) + results['intent_probs'] = intent_probs + return results + + +IntentUnifiedTransformer.register('IntentUnifiedTransformer') diff --git a/modelscope/models/nlp/space/model/model_base.py b/modelscope/models/nlp/space/model/model_base.py new file mode 100644 index 00000000..b7812182 --- /dev/null +++ b/modelscope/models/nlp/space/model/model_base.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os + +import torch.nn as nn + +from modelscope.utils.constant import ModelFile + + +class SpaceModelBase(nn.Module): + """ + Basic model wrapper for static graph and dygrpah. + """ + _registry = dict() + + @classmethod + def register(cls, name): + SpaceModelBase._registry[name] = cls + return + + @staticmethod + def by_name(name): + return SpaceModelBase._registry[name] + + @staticmethod + def create(model_dir, config, *args, **kwargs): + model_cls = SpaceModelBase.by_name(config.Model.model) + return model_cls(model_dir, config, *args, **kwargs) + + def __init__(self, model_dir, config): + super(SpaceModelBase, self).__init__() + self.init_checkpoint = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + self.abandon_label = config.Dataset.abandon_label + self.use_gpu = config.use_gpu + self.gpu = config.Trainer.gpu + return + + def _create_parameters(self): + """ Create model's paramters. """ + raise NotImplementedError + + def _forward(self, inputs, is_training, with_label): + """ NO LABEL: Real forward process of model in different mode(train/test). """ + raise NotImplementedError + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + """ NO LABEL: Calculate loss function by using inputs and outputs. """ + raise NotImplementedError + + def _optimize(self, loss, optimizer, lr_scheduler): + """ Optimize loss function and update model. """ + raise NotImplementedError + + def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): + """ Real inference process of model. """ + raise NotImplementedError + + def forward(self, + inputs, + is_training=False, + with_label=False, + data_file=None): + """ + Forward process, include real forward, collect metrices and optimize(optional) + + Args: + inputs(`dict` of numpy.ndarray/int/float/...) : input data + """ + if is_training: + self.train() + else: + self.eval() + + with_label = False if self.abandon_label else with_label + outputs = self._forward(inputs, is_training, with_label=with_label) + metrics = self._collect_metrics( + inputs, outputs, with_label=with_label, data_file=data_file) + + return metrics + + def infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """Inference process. + + Args: + inputs(`dict` of numpy.ndarray/int/float/...) : input data + """ + self.eval() + results = self._infer( + inputs, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + return results diff --git a/modelscope/models/nlp/space/model/tokenization_space.py b/modelscope/models/nlp/space/model/tokenization_space.py new file mode 100644 index 00000000..e3b358d4 --- /dev/null +++ b/modelscope/models/nlp/space/model/tokenization_space.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for Space. mainly copied from :module:`~transformers.tokenization_xlm_roberta`""" + +from modelscope.models.nlp.structbert import (BasicTokenizer, SbertTokenizer, + WordpieceTokenizer) +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class SpaceTokenizer(SbertTokenizer): + """ + This class overrides [`SpaceTokenizer`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py new file mode 100644 index 00000000..19069971 --- /dev/null +++ b/modelscope/models/nlp/space/model/unified_transformer.py @@ -0,0 +1,308 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.nlp.space.model.model_base import SpaceModelBase +from modelscope.models.nlp.space.modules.embedder import Embedder +from modelscope.models.nlp.space.modules.transformer_block import \ + TransformerBlock + + +class UnifiedTransformer(SpaceModelBase): + """ + Implement unified transformer. + """ + + def __init__(self, model_dir, config, reader, generator, dtype='float32'): + super(UnifiedTransformer, self).__init__(model_dir, config) + self.reader = reader + self.generator = generator + self.policy = config.BPETextField.policy + self.generation = config.BPETextField.generation + self.num_token_embeddings = config.Model.num_token_embeddings + self.num_pos_embeddings = config.Model.num_pos_embeddings + self.num_type_embeddings = config.Model.num_type_embeddings + self.num_turn_embeddings = config.Model.num_turn_embeddings + self.temperature = config.Model.temperature + self.hidden_dim = config.Model.hidden_dim + self.num_heads = config.Model.num_heads + self.num_layers = config.Model.num_layers + self.padding_idx = config.Model.padding_idx + self.dropout = config.Model.dropout + self.embed_dropout = config.Model.embed_dropout + self.attn_dropout = config.Model.attn_dropout + self.ff_dropout = config.Model.ff_dropout + self.mlm_ratio = config.Model.mlm_ratio + self.mmd_ratio = config.Model.mmd_ratio + self.pos_trainable = config.Model.pos_trainable + self.label_smooth = config.Model.label_smooth + self.initializer_range = config.Model.initializer_range + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.token_loss = config.Trainer.token_loss + self.learning_method = config.Dataset.learning_method + self.with_contrastive = config.Dataset.with_contrastive + self.with_query_bow = config.BPETextField.with_query_bow + self.with_resp_bow = config.BPETextField.with_resp_bow + self.with_pool = config.Model.with_pool + self.with_mlm = config.Dataset.with_mlm + self._dtype = dtype + + self.embedder = Embedder( + self.hidden_dim, + self.num_token_embeddings, + self.num_pos_embeddings, + self.num_type_embeddings, + self.num_turn_embeddings, + padding_idx=self.padding_idx, + dropout=self.embed_dropout, + pos_trainable=self.pos_trainable) + self.embed_layer_norm = nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True) + + self.layers = nn.ModuleList([ + TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, + self.attn_dropout, self.ff_dropout) + for _ in range(config.Model.num_layers) + ]) + + if self.with_mlm: + self.mlm_transform = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), + nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True)) + self.mlm_bias = nn.Parameter( + torch.zeros(self.num_token_embeddings)) + + self.pooler = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) + + if self.with_query_bow or self.with_resp_bow: + self.bow_predictor = nn.Linear( + self.hidden_dim, self.num_token_embeddings, bias=False) + + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(dim=-1) + self.bce_loss = nn.BCELoss(reduction='none') + self.nll_loss = nn.NLLLoss( + ignore_index=self.padding_idx, reduction='none') + self._create_parameters() + + self.max_grad_norm = config.Model.max_grad_norm + if self.max_grad_norm is not None: + self.grad_clip = self.max_grad_norm + else: + self.grad_clip = None + self.weight_decay = config.Model.weight_decay + + if self.use_gpu: + self.cuda() + + return + + def _create_parameters(self): + """ Create model's paramters. """ + sequence_mask = np.tri( + self.num_pos_embeddings, + self.num_pos_embeddings, + dtype=self._dtype) + self.sequence_mask = torch.tensor(sequence_mask) + return + + def _create_mask(self, + input_mask, + append_head=False, + auto_regressive=False): + """Create attention mask. + from sequence to matrix:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] + + Args: + input_mask (Variable(shape: [batch_size, max_seq_len])) + auto_regressive(bool) + """ + seq_len = input_mask.shape[1] + + input_mask = input_mask.float() + mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) + mask2 = mask1.permute(0, 2, 1) + mask = mask1 * mask2 + + if append_head: + mask = torch.cat([mask[:, :1, :], mask], dim=1) + mask = torch.cat([mask[:, :, :1], mask], dim=2) + seq_len += 1 + + if auto_regressive: + seq_mask = self.sequence_mask[:seq_len, :seq_len] + seq_mask = seq_mask.to(mask.device) + mask = mask * seq_mask + + mask = 1 - mask + return mask + + def _join_mask(self, mask1, mask2): + """Merge source attention mask and target attention mask. + There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb) + + Args: + mask1(Variable(shape: [batch_size, max_src_len, max_src_len])) : source attention mask + mask2(Variable(shape: [batch_size, max_tgt_len, max_tgt_len])) : target attention mask + """ + batch_size = mask1.shape[0] + seq_len1 = mask1.shape[1] + seq_len2 = mask2.shape[1] + # seq_len = seq_len1 + seq_len2 + + mask_lu = mask1 + mask_ru = torch.ones(batch_size, seq_len1, seq_len2).to(mask_lu.device) + if self.use_gpu: + mask_ru = mask_ru.cuda() + mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) + mask4 = mask1[:, :1].repeat(1, seq_len2, 1) + mask_lb = mask3 + mask4 - mask3 * mask4 + mask_rb = mask2 + mask_u = torch.cat([mask_lu, mask_ru], dim=2) + mask_b = torch.cat([mask_lb, mask_rb], dim=2) + mask = torch.cat([mask_u, mask_b], dim=1) + return mask + + def _mlm_head(self, mlm_embed): + mlm_embed = self.mlm_transform(mlm_embed) + mlm_logits = torch.matmul( + mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias + mlm_probs = self.softmax(mlm_logits) + return mlm_probs + + def _dec_head(self, dec_embed): + dec_logits = torch.matmul(dec_embed, + self.embedder.token_embedding.weight.T) + dec_probs = self.softmax(dec_logits) + return dec_probs + + def _refactor_feature(self, features): + features = self.pooler(features) if self.with_pool else features + batch_size = features.size(0) // 2 + features = \ + torch.cat( + [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], + dim=1 + ) + features = F.normalize(features, dim=-1, p=2) + return features + + def _encoder_network(self, + input_token, + input_mask, + input_pos=None, + input_type=None, + input_turn=None): + embed = self.embedder(input_token, input_pos, input_type, input_turn) + embed = self.embed_layer_norm(embed) + mask = self._create_mask(input_mask, auto_regressive=False) + + for layer in self.layers: + embed = layer(embed, mask, None) + + return embed + + def _encoder_decoder_network(self, + src_token, + src_mask, + tgt_token, + tgt_mask, + src_pos=None, + src_type=None, + src_turn=None, + tgt_pos=None, + tgt_type=None, + tgt_turn=None): + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) + embed = torch.cat([src_embed, tgt_embed], dim=1) + embed = self.embed_layer_norm(embed) + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask(tgt_mask, auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + for layer in self.layers: + embed = layer(embed, mask, None) + + tgt_len = tgt_token.shape[1] + enc_embed = embed[:, :-tgt_len] + dec_embed = embed[:, -tgt_len:] + + return enc_embed, dec_embed + + def _encoder_prompt_decoder_network(self, + src_token, + src_mask, + tgt_token, + tgt_mask, + prompt_token, + prompt_mask, + src_pos=None, + src_type=None, + src_turn=None, + tgt_pos=None, + tgt_type=None, + tgt_turn=None, + prompt_pos=None, + prompt_type=None, + prompt_turn=None): + src_embed = self.embedder(src_token, src_pos, src_type, src_turn) + tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) + prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, + prompt_turn) + + embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) + embed = self.embed_layer_norm(embed) + + enc_mask = self._create_mask(src_mask, auto_regressive=False) + dec_mask = self._create_mask( + torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) + mask = self._join_mask(enc_mask, dec_mask) + + for layer in self.layers: + embed = layer(embed, mask, None) + + src_len = src_token.shape[1] + tgt_len = tgt_token.shape[1] + enc_embed = embed[:, :src_len] + dec_embed = embed[:, -tgt_len:] + prompt_embed = embed[:, src_len:-tgt_len] + + return enc_embed, dec_embed, prompt_embed + + def _optimize(self, loss, optimizer=None, lr_scheduler=None): + """ Optimize loss function and update model. """ + assert optimizer is not None + optimizer.zero_grad() + loss.backward() + + if self.grad_clip is not None and self.grad_clip > 0: + torch.nn.utils.clip_grad_norm_( + parameters=self.parameters(), max_norm=self.grad_clip) + optimizer.step() + if lr_scheduler is not None: + lr_scheduler.step() + return + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + results = {} + return results + + +UnifiedTransformer.register('UnifiedTransformer') diff --git a/modelscope/models/nlp/space/modules/__init__.py b/modelscope/models/nlp/space/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/space/modules/embedder.py b/modelscope/models/nlp/space/modules/embedder.py new file mode 100644 index 00000000..e68ac7d3 --- /dev/null +++ b/modelscope/models/nlp/space/modules/embedder.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + + +class Embedder(nn.Module): + """ + Composite embedding layer. + """ + + def __init__(self, + hidden_dim, + num_token_embeddings, + num_pos_embeddings, + num_type_embeddings, + num_turn_embeddings, + padding_idx=None, + dropout=0.1, + pos_trainable=False): + super(Embedder, self).__init__() + + self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) + self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) + self.pos_embedding.weight.requires_grad = pos_trainable + self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) + self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + + # follow the default xavier_uniform initializer in paddle version + # otherwise, there are bugs for dec_probs computation in weight typing setting + # default norm initializer in nn.Embedding in pytorch, which samples larger values + nn.init.xavier_uniform_(self.token_embedding.weight) + nn.init.xavier_uniform_(self.pos_embedding.weight) + nn.init.xavier_uniform_(self.type_embedding.weight) + nn.init.xavier_uniform_(self.turn_embedding.weight) + return + + def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): + embed = self.token_embedding(token_inp) + if pos_inp is not None: + embed += self.pos_embedding(pos_inp) + if type_inp is not None: + embed += self.type_embedding(type_inp) + if turn_inp is not None: + embed += self.turn_embedding(turn_inp) + embed = self.dropout_layer(embed) + return embed + + +def main(): + import numpy as np + + model = Embedder(10, 20, 20, 20, 20) + token_inp = torch.tensor( + np.random.randint(0, 19, [10, 10]).astype('int64')) + pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + out = model(token_inp, pos_inp, type_inp, turn_inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/feedforward.py b/modelscope/models/nlp/space/modules/feedforward.py new file mode 100644 index 00000000..43318eb6 --- /dev/null +++ b/modelscope/models/nlp/space/modules/feedforward.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + + +class FeedForward(nn.Module): + """ + Positional feed forward layer. + """ + + def __init__(self, hidden_dim, inner_dim, dropout): + super(FeedForward, self).__init__() + + self.hidden_dim = hidden_dim + self.inner_dim = inner_dim + self.linear_hidden = nn.Sequential( + nn.Linear(hidden_dim, inner_dim), nn.GELU()) + self.linear_out = nn.Linear(inner_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + return + + def forward(self, x): + out = self.linear_hidden(x) + out = self.dropout_layer(out) + out = self.linear_out(out) + return out + + +def main(): + import numpy as np + + model = FeedForward(10, 20, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + out = model(inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/functions.py b/modelscope/models/nlp/space/modules/functions.py new file mode 100644 index 00000000..daa62bb4 --- /dev/null +++ b/modelscope/models/nlp/space/modules/functions.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn.functional as F + + +def unsqueeze(input, dims): + """ Implement multi-dimension unsqueeze function. """ + if isinstance(dims, (list, tuple)): + dims = [ + dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims + ] + dims = sorted(dims, reverse=True) + shape = list(input.shape) + for dim in dims: + shape.insert(dim, 1) + return torch.reshape(input, shape) + elif isinstance(dims, int): + return input.unsqueeze(dims) + else: + raise ValueError('Warning: type(dims) must in (list, tuple, int)!') + + +def gumbel_softmax(input, tau=1, eps=1e-10): + """ Basic implement of gumbel_softmax. """ + U = torch.tensor(np.random.rand(*input.shape)) + gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) + y = input + gumbel + return F.softmax(y / tau) + + +def equal(x, y, dtype=None): + """ Implement equal in dygraph mode. (paddle) """ + if dtype is None: + dtype = 'float32' + if isinstance(x, torch.Tensor): + x = x.numpy() + if isinstance(y, torch.Tensor): + y = y.numpy() + out = np.equal(x, y).astype(dtype) + return torch.tensor(out) + + +def not_equal(x, y, dtype=None): + """ Implement not_equal in dygraph mode. (paddle) """ + return 1 - equal(x, y, dtype) + + +if __name__ == '__main__': + a = torch.tensor([[1, 1], [3, 4]]) + b = torch.tensor([[1, 1], [3, 4]]) + c = torch.equal(a, a) + c1 = equal(a, 3) + d = 1 - torch.not_equal(a, 3).float() + print(c) + print(c1) + print(d) + e = F.gumbel_softmax(a) + f = a.unsqueeze(a) + g = unsqueeze(a, dims=[0, 0, 1]) + print(g, g.shape) diff --git a/modelscope/models/nlp/space/modules/multihead_attention.py b/modelscope/models/nlp/space/modules/multihead_attention.py new file mode 100644 index 00000000..d075e9c5 --- /dev/null +++ b/modelscope/models/nlp/space/modules/multihead_attention.py @@ -0,0 +1,105 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + + +class MultiheadAttention(nn.Module): + """ + Multi head attention layer. + """ + + def __init__(self, hidden_dim, num_heads, dropout): + assert hidden_dim % num_heads == 0 + super(MultiheadAttention, self).__init__() + + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.scale = self.head_dim**-0.5 + self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) + self.linear_out = nn.Linear(hidden_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + self.softmax = nn.Softmax(dim=-1) + return + + def _split_heads(self, x, is_key=False): + x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) + x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) + return x + + def _merge_heads(self, x): + x = x.permute(0, 2, 1, 3) + x = x.reshape(x.size(0), x.size(1), self.hidden_dim) + return x + + def _attn(self, query, key, value, mask): + # shape: [batch_size, num_head, seq_len, seq_len] + scores = torch.matmul(query, key) + scores = scores * self.scale + + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.repeat(1, self.num_heads, 1, 1) + scores.masked_fill_( + mask.bool(), + float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) + + attn = self.softmax(scores) + attn = self.dropout_layer(attn) + + if mask is not None: + ''' + mask: [batch size, num_heads, seq_len, seq_len] + + >>> F.softmax([-1e10, -100, -100]) + >>> [0.00, 0.50, 0.50] + >>> F.softmax([-1e10, -1e10, -1e10]) + >>> [0.33, 0.33, 0.33] + ==> [0.00, 0.00, 0.00] + ''' + attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn + + out = torch.matmul(attn, value) + return out + + def forward(self, inp, mask=None, cache=None): + """ Forward process of self attention. """ + # shape: [batch_size, seq_len, 3 * hidden_dim] + qkv = self.linear_qkv(inp) + query, key, value = torch.split(qkv, self.hidden_dim, dim=2) + + # shape: [batch_size, num_head, seq_len, head_dim] + query = self._split_heads(query) + # shape: [batch_size, num_head, head_dim, seq_len] + key = self._split_heads(key, is_key=True) + # shape: [batch_size, num_head, seq_len, head_dim] + value = self._split_heads(value) + + if cache is not None: + if 'key' in cache and 'value' in cache: + key = torch.cat([cache['key'], key], dim=3) + value = torch.cat([cache['value'], value], dim=2) + cache['key'] = key + cache['value'] = value + + out = self._attn(query, key, value, mask) + out = self._merge_heads(out) + out = self.linear_out(out) + return out + + +def main(): + import numpy as np + + model = MultiheadAttention(10, 2, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') + mask = torch.tensor(mask) + out = model(inp, mask=mask, cache=None) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space/modules/transformer_block.py b/modelscope/models/nlp/space/modules/transformer_block.py new file mode 100644 index 00000000..3044963a --- /dev/null +++ b/modelscope/models/nlp/space/modules/transformer_block.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + +from .feedforward import FeedForward +from .multihead_attention import MultiheadAttention + + +class TransformerBlock(nn.Module): + """ + Transformer block module. + """ + + def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, + ff_dropout): + super(TransformerBlock, self).__init__() + + self.attn = MultiheadAttention( + hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) + self.attn_norm = nn.LayerNorm( + normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) + self.ff = FeedForward( + hidden_dim=hidden_dim, + inner_dim=4 * hidden_dim, + dropout=ff_dropout) + self.ff_norm = nn.LayerNorm( + normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) + self.dropout_layer = nn.Dropout(p=dropout) + return + + def forward(self, inp, mask=None, cache=None): + """Forward process on one transformer layer. + + Args: + x(Variable(shape: [batch_size, seq_len, hidden_size])) + memory(Variable(shape: [batch_size, seq_len, hidden_size])) + mask + cache + """ + attn_out = self.attn(inp, mask, cache) + attn_out = self.dropout_layer(attn_out) + attn_out = self.attn_norm(attn_out + inp) + + ff_out = self.ff(attn_out) + ff_out = self.dropout_layer(ff_out) + ff_out = self.ff_norm(ff_out + attn_out) + + return ff_out + + +def main(): + import numpy as np + + model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') + mask = torch.tensor(mask) + out = model(inp, mask=mask, cache=None) + print(out) + + +if __name__ == '__main__': + main() diff --git a/modelscope/models/nlp/space_T_cn/__init__.py b/modelscope/models/nlp/space_T_cn/__init__.py new file mode 100644 index 00000000..b9deb700 --- /dev/null +++ b/modelscope/models/nlp/space_T_cn/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .table_question_answering import TableQuestionAnswering +else: + _import_structure = { + 'table_question_answering': ['TableQuestionAnswering'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/space_T_cn/backbone.py b/modelscope/models/nlp/space_T_cn/backbone.py new file mode 100644 index 00000000..5afde06e --- /dev/null +++ b/modelscope/models/nlp/space_T_cn/backbone.py @@ -0,0 +1,1001 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function +import copy +import math +import os +import shutil +import tarfile +import tempfile + +import numpy as np +import torch +from torch import nn + +from modelscope.models.nlp.space_T_cn.configuration import SpaceTCnConfig +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + +CONFIG_NAME = ModelFile.CONFIGURATION +WEIGHTS_NAME = ModelFile.TORCH_MODEL_BIN_FILE + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class BertLayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.match_type_embeddings = nn.Embedding(11, config.hidden_size) + self.type_embeddings = nn.Embedding(6, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, + input_ids, + header_ids, + token_type_ids=None, + match_type_ids=None, + l_hs=None, + header_len=None, + type_idx=None, + col_dict_list=None, + ids=None, + header_flatten_tokens=None, + header_flatten_index=None, + header_flatten_output=None, + token_column_id=None, + token_column_mask=None, + column_start_index=None, + headers_length=None): + seq_length = input_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + words_embeddings = self.word_embeddings(input_ids) + header_embeddings = self.word_embeddings(header_ids) + + if col_dict_list is not None and l_hs is not None: + col_dict_list = np.array(col_dict_list)[ids.cpu().numpy()].tolist() + header_len = np.array( + header_len, dtype=object)[ids.cpu().numpy()].tolist() + for bi, col_dict in enumerate(col_dict_list): + for ki, vi in col_dict.items(): + length = header_len[bi][vi] + if length == 0: + continue + words_embeddings[bi, ki, :] = torch.mean( + header_embeddings[bi, vi, :length, :], dim=0) + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = words_embeddings + position_embeddings + token_type_embeddings + + if match_type_ids is not None: + match_type_embeddings = self.match_type_embeddings(match_type_ids) + embeddings += match_type_embeddings + + if type_idx is not None: + type_embeddings = self.type_embeddings(type_idx) + embeddings += type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, schema_link_matrix=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfAttentionWithRelationsRAT(nn.Module): + ''' + Adapted from https://github.com/microsoft/rat-sql/blob/master/ratsql/models/transformer.py + ''' + + def __init__(self, config): + super(BertSelfAttentionWithRelationsRAT, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.relation_k_emb = nn.Embedding( + 7, config.hidden_size // config.num_attention_heads) + self.relation_v_emb = nn.Embedding( + 7, config.hidden_size // config.num_attention_heads) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, relation): + ''' + relation is [batch, seq len, seq len] + ''' + mixed_query_layer = self.query( + hidden_states) # [batch, seq len, hidden dim] + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + relation_k = self.relation_k_emb( + relation) # [batch, seq len, seq len, head dim] + relation_v = self.relation_v_emb( + relation) # [batch, seq len, seq len, head dim] + + query_layer = self.transpose_for_scores( + mixed_query_layer) # [batch, num attn heads, seq len, head dim] + key_layer = self.transpose_for_scores( + mixed_key_layer) # [batch, num attn heads, seq len, head dim] + value_layer = self.transpose_for_scores( + mixed_value_layer) # [batch, num attn heads, seq len, head dim] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose( + -1, -2)) # [batch, num attn heads, seq len, seq len] + + # relation_k_t is [batch, seq len, head dim, seq len] + relation_k_t = relation_k.transpose(-2, -1) + # query_layer_t is [batch, seq len, num attn heads, head dim] + query_layer_t = query_layer.permute(0, 2, 1, 3) + # relation_attention_scores is [batch, seq len, num attn heads, seq len] + relation_attention_scores = torch.matmul(query_layer_t, relation_k_t) + # relation_attention_scores_t is [batch, num attn heads, seq len, seq len] + relation_attention_scores_t = relation_attention_scores.permute( + 0, 2, 1, 3) + + merged_attention_scores = (attention_scores + + relation_attention_scores_t) / math.sqrt( + self.attention_head_size) + + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + merged_attention_scores = merged_attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(merged_attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + # attention_probs is [batch, num attn heads, seq len, seq len] + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + # attention_probs_t is [batch, seq len, num attn heads, seq len] + attention_probs_t = attention_probs.permute(0, 2, 1, 3) + + # [batch, seq len, num attn heads, seq len] + # * [batch, seq len, seq len, head dim] + # = [batch, seq len, num attn heads, head dim] + context_relation = torch.matmul(attention_probs_t, relation_v) + + # context_relation_t is [batch, num attn heads, seq len, head dim] + context_relation_t = context_relation.permute(0, 2, 1, 3) + + merged_context_layer = context_layer + context_relation_t + merged_context_layer = merged_context_layer.permute(0, 2, 1, + 3).contiguous() + new_context_layer_shape = merged_context_layer.size()[:-2] + ( + self.all_head_size, ) + merged_context_layer = merged_context_layer.view( + *new_context_layer_shape) + return merged_context_layer + + +class BertSelfAttentionWithRelationsTableformer(nn.Module): + + def __init__(self, config): + super(BertSelfAttentionWithRelationsTableformer, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.schema_link_embeddings = nn.Embedding(7, self.num_attention_heads) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, relation): + ''' + relation is [batch, seq len, seq len] + ''' + mixed_query_layer = self.query( + hidden_states) # [batch, seq len, hidden dim] + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + schema_link_embeddings = self.schema_link_embeddings( + relation) # [batch, seq len, seq len, 1] + schema_link_embeddings = schema_link_embeddings.permute(0, 3, 1, 2) + + query_layer = self.transpose_for_scores( + mixed_query_layer) # [batch, num attn heads, seq len, head dim] + key_layer = self.transpose_for_scores( + mixed_key_layer) # [batch, num attn heads, seq len, head dim] + value_layer = self.transpose_for_scores( + mixed_value_layer) # [batch, num attn heads, seq len, head dim] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose( + -1, -2)) # [batch, num attn heads, seq len, seq len] + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + + merged_attention_scores = attention_scores + schema_link_embeddings + + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + merged_attention_scores = merged_attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(merged_attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + # attention_probs is [batch, num attn heads, seq len, seq len] + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertAttention, self).__init__() + if schema_link_module == 'none': + self.self = BertSelfAttention(config) + if schema_link_module == 'rat': + self.self = BertSelfAttentionWithRelationsRAT(config) + if schema_link_module == 'add': + self.self = BertSelfAttentionWithRelationsTableformer(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask, schema_link_matrix=None): + self_output = self.self(input_tensor, attention_mask, + schema_link_matrix) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertLayer, self).__init__() + self.attention = BertAttention( + config, schema_link_module=schema_link_module) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask, schema_link_matrix=None): + attention_output = self.attention(hidden_states, attention_mask, + schema_link_matrix) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class SqlBertEncoder(nn.Module): + + def __init__(self, layers, config): + super(SqlBertEncoder, self).__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(layers)]) + + def forward(self, + hidden_states, + attention_mask, + output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertEncoder(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertEncoder, self).__init__() + layer = BertLayer(config, schema_link_module=schema_link_module) + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + + def forward(self, + hidden_states, + attention_mask, + all_schema_link_matrix=None, + all_schema_link_mask=None, + output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask, + all_schema_link_matrix) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter( + torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, SpaceTCnConfig): + raise ValueError( + 'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. ' + 'To create a model from a Google pretrained model use ' + '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( + self.__class__.__name__, self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, + pretrained_model_name, + state_dict=None, + cache_dir=None, + *inputs, + **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) + to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + resolved_archive_file = pretrained_model_name + # redirect to the cache, if necessary + tempdir = None + if os.path.isdir(resolved_archive_file): + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info('extracting archive file {} to temp dir {}'.format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = SpaceTCnConfig.from_json_file(config_file) + logger.info('Model config {}'.format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, + True, missing_keys, unexpected_keys, + error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + logger.info( + 'Weights of {} not initialized from pretrained model: {}'. + format(model.__class__.__name__, missing_keys)) + print() + print('*' * 10, 'WARNING missing weights', '*' * 10) + print('Weights of {} not initialized from pretrained model: {}'. + format(model.__class__.__name__, missing_keys)) + print() + if len(unexpected_keys) > 0: + logger.info( + 'Weights from pretrained model not used in {}: {}'.format( + model.__class__.__name__, unexpected_keys)) + print() + print('*' * 10, 'WARNING unexpected weights', '*' * 10) + print('Weights from pretrained model not used in {}: {}'.format( + model.__class__.__name__, unexpected_keys)) + print() + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + return model + + +class SpaceTCnModel(PreTrainedBertModel): + """SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN"). + + Params: + config: a SpaceTCnConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output + as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.SpaceTCnModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, schema_link_module='none'): + super(SpaceTCnModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder( + config, schema_link_module=schema_link_module) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + header_ids, + token_order_ids=None, + token_type_ids=None, + attention_mask=None, + match_type_ids=None, + l_hs=None, + header_len=None, + type_ids=None, + col_dict_list=None, + ids=None, + header_flatten_tokens=None, + header_flatten_index=None, + header_flatten_output=None, + token_column_id=None, + token_column_mask=None, + column_start_index=None, + headers_length=None, + all_schema_link_matrix=None, + all_schema_link_mask=None, + output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + # Bowen: comment out the following line for Pytorch >= 1.5 + # https://github.com/huggingface/transformers/issues/3936#issuecomment-793764416 + # extended_attention_mask = extended_attention_mask.to(self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings( + input_ids, header_ids, token_type_ids, match_type_ids, l_hs, + header_len, type_ids, col_dict_list, ids, header_flatten_tokens, + header_flatten_index, header_flatten_output, token_column_id, + token_column_mask, column_start_index, headers_length) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + all_schema_link_matrix=all_schema_link_matrix, + all_schema_link_mask=all_schema_link_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class Seq2SQL(nn.Module): + + def __init__(self, iS, hS, lS, dr, n_cond_ops, n_agg_ops, n_action_ops, + max_select_num, max_where_num, device): + super(Seq2SQL, self).__init__() + self.iS = iS + self.hS = hS + self.ls = lS + self.dr = dr + self.device = device + + self.n_agg_ops = n_agg_ops + self.n_cond_ops = n_cond_ops + self.n_action_ops = n_action_ops + self.max_select_num = max_select_num + self.max_where_num = max_where_num + + self.w_sss_model = nn.Linear(iS, max_where_num) + self.w_sse_model = nn.Linear(iS, max_where_num) + self.s_ht_model = nn.Linear(iS, max_select_num) + self.wc_ht_model = nn.Linear(iS, max_where_num) + + self.select_agg_model = nn.Linear(iS * max_select_num, + n_agg_ops * max_select_num) + self.w_op_model = nn.Linear(iS * max_where_num, + n_cond_ops * max_where_num) + + self.conn_model = nn.Linear(iS, 3) + self.action_model = nn.Linear(iS, n_action_ops + 1) + self.slen_model = nn.Linear(iS, max_select_num + 1) + self.wlen_model = nn.Linear(iS, max_where_num + 1) + + def forward(self, wemb_layer, l_n, l_hs, start_index, column_index, tokens, + ids): + # chunk input lists for multi-gpu + max_l_n = max(l_n) + max_l_hs = max(l_hs) + l_n = np.array(l_n)[ids.cpu().numpy()].tolist() + l_hs = np.array(l_hs)[ids.cpu().numpy()].tolist() + start_index = np.array(start_index)[ids.cpu().numpy()].tolist() + column_index = np.array(column_index)[ids.cpu().numpy()].tolist() + # tokens = np.array(tokens)[ids.cpu().numpy()].tolist() + + conn_index = [] + slen_index = [] + wlen_index = [] + action_index = [] + where_op_index = [] + select_agg_index = [] + header_pos_index = [] + query_index = [] + for ib, elem in enumerate(start_index): + # [SEP] conn [SEP] wlen [SEP] (wop [SEP])*wn slen [SEP] (agg [SEP])*sn + action_index.append(elem + 1) + conn_index.append(elem + 2) + wlen_index.append(elem + 3) + woi = [elem + 4 + i for i in range(self.max_where_num)] + + slen_index.append(elem + 4 + self.max_where_num) + sai = [ + elem + 5 + self.max_where_num + i + for i in range(self.max_select_num) + ] + where_op_index.append(woi) + select_agg_index.append(sai) + + qilist = [i for i in range(l_n[ib] + 2)] + [l_n[ib] + 1] * ( + max_l_n - l_n[ib]) + query_index.append(qilist) + + index = [column_index[ib] + i for i in range(0, l_hs[ib], 1)] + index += [index[0] for _ in range(max_l_hs - len(index))] + header_pos_index.append(index) + + # print("tokens: ", tokens) + # print("conn_index: ", conn_index, "start_index: ", start_index) + conn_index = torch.tensor(conn_index, dtype=torch.long).to(self.device) + slen_index = torch.tensor(slen_index, dtype=torch.long).to(self.device) + wlen_index = torch.tensor(wlen_index, dtype=torch.long).to(self.device) + action_index = torch.tensor( + action_index, dtype=torch.long).to(self.device) + where_op_index = torch.tensor( + where_op_index, dtype=torch.long).to(self.device) + select_agg_index = torch.tensor( + select_agg_index, dtype=torch.long).to(self.device) + query_index = torch.tensor( + query_index, dtype=torch.long).to(self.device) + header_index = torch.tensor( + header_pos_index, dtype=torch.long).to(self.device) + + bS = len(l_n) + conn_emb = torch.zeros([bS, self.iS]).to(self.device) + slen_emb = torch.zeros([bS, self.iS]).to(self.device) + wlen_emb = torch.zeros([bS, self.iS]).to(self.device) + action_emb = torch.zeros([bS, self.iS]).to(self.device) + wo_emb = torch.zeros([bS, self.max_where_num, self.iS]).to(self.device) + sa_emb = torch.zeros([bS, self.max_select_num, + self.iS]).to(self.device) + qv_emb = torch.zeros([bS, max_l_n + 2, self.iS]).to(self.device) + ht_emb = torch.zeros([bS, max_l_hs, self.iS]).to(self.device) + for i in range(bS): + conn_emb[i, :] = wemb_layer[i].index_select(0, conn_index[i]) + slen_emb[i, :] = wemb_layer[i].index_select(0, slen_index[i]) + wlen_emb[i, :] = wemb_layer[i].index_select(0, wlen_index[i]) + action_emb[i, :] = wemb_layer[i].index_select(0, action_index[i]) + + wo_emb[i, :, :] = wemb_layer[i].index_select( + 0, where_op_index[i, :]) + sa_emb[i, :, :] = wemb_layer[i].index_select( + 0, select_agg_index[i, :]) + qv_emb[i, :, :] = wemb_layer[i].index_select(0, query_index[i, :]) + ht_emb[i, :, :] = wemb_layer[i].index_select(0, header_index[i, :]) + + s_cco = self.conn_model(conn_emb.reshape(-1, self.iS)).reshape(bS, 3) + s_slen = self.slen_model(slen_emb.reshape(-1, self.iS)).reshape( + bS, self.max_select_num + 1) + s_wlen = self.wlen_model(wlen_emb.reshape(-1, self.iS)).reshape( + bS, self.max_where_num + 1) + s_action = self.action_model(action_emb.reshape(-1, self.iS)).reshape( + bS, self.n_action_ops + 1) + wo_output = self.w_op_model( + wo_emb.reshape(-1, self.iS * self.max_where_num)).reshape( + bS, -1, self.n_cond_ops) + + wc_output = self.wc_ht_model(ht_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + + wv_ss = self.w_sss_model(qv_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + wv_se = self.w_sse_model(qv_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + + sc_output = self.s_ht_model(ht_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_select_num).transpose(1, 2) + sa_output = self.select_agg_model( + sa_emb.reshape(-1, self.iS * self.max_select_num)).reshape( + bS, -1, self.n_agg_ops) + + return s_action, sc_output, sa_output, s_cco, wc_output, wo_output, ( + wv_ss, wv_se), (s_slen, s_wlen) diff --git a/modelscope/models/nlp/space_T_cn/configuration.py b/modelscope/models/nlp/space_T_cn/configuration.py new file mode 100644 index 00000000..e698b310 --- /dev/null +++ b/modelscope/models/nlp/space_T_cn/configuration.py @@ -0,0 +1,115 @@ +# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT configuration.""" + +from __future__ import absolute_import, division, print_function +import copy +import logging + +import json + +logger = logging.getLogger(__name__) + + +class SpaceTCnConfig(object): + """Configuration class to store the configuration of a `SpaceTCnModel`. + """ + + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs SpaceTCnConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnConfig`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnConfig`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open( + vocab_size_or_config_json_file, 'r', + encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError( + 'First argument must be either a vocabulary size (int)' + 'or the path to a pretrained model config file (str)') + + @classmethod + def from_dict(cls, json_object): + """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters.""" + config = SpaceTCnConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `SpaceTCnConfig` from a json file of parameters.""" + with open(json_file, 'r', encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' diff --git a/modelscope/models/nlp/space_T_cn/table_question_answering.py b/modelscope/models/nlp/space_T_cn/table_question_answering.py new file mode 100644 index 00000000..a3f504b7 --- /dev/null +++ b/modelscope/models/nlp/space_T_cn/table_question_answering.py @@ -0,0 +1,776 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict + +import numpy +import torch +import torch.nn.functional as F +from transformers import BertTokenizer + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor +from modelscope.models.builder import MODELS +from modelscope.preprocessors.nlp.space_T_cn.fields.struct import Constant +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import verify_device +from .backbone import Seq2SQL, SpaceTCnModel +from .configuration import SpaceTCnConfig + +__all__ = ['TableQuestionAnswering'] + + +@MODELS.register_module( + Tasks.table_question_answering, module_name=Models.space_T_cn) +class TableQuestionAnswering(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the table-question-answering model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.tokenizer = BertTokenizer( + os.path.join(model_dir, ModelFile.VOCAB_FILE)) + device_name = kwargs.get('device', 'gpu') + verify_device(device_name) + self._device_name = device_name + + state_dict = torch.load( + os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location='cpu') + + self.backbone_config = SpaceTCnConfig.from_json_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + self.backbone_model = SpaceTCnModel( + config=self.backbone_config, schema_link_module='rat') + self.backbone_model.load_state_dict(state_dict['backbone_model']) + + constant = Constant() + self.agg_ops = constant.agg_ops + self.cond_ops = constant.cond_ops + self.cond_conn_ops = constant.cond_conn_ops + self.action_ops = constant.action_ops + self.max_select_num = constant.max_select_num + self.max_where_num = constant.max_where_num + self.col_type_dict = constant.col_type_dict + self.schema_link_dict = constant.schema_link_dict + n_cond_ops = len(self.cond_ops) + n_agg_ops = len(self.agg_ops) + n_action_ops = len(self.action_ops) + iS = self.backbone_config.hidden_size + self.head_model = Seq2SQL(iS, 100, 2, 0.0, n_cond_ops, n_agg_ops, + n_action_ops, self.max_select_num, + self.max_where_num, self._device_name) + self.head_model.load_state_dict(state_dict['head_model'], strict=False) + + self.backbone_model.to(self._device_name) + self.head_model.to(self._device_name) + + def convert_string(self, pr_wvi, nlu, nlu_tt): + convs = [] + for b, nlu1 in enumerate(nlu): + conv_dict = {} + nlu_tt1 = nlu_tt[b] + idx = 0 + convflag = True + for i, ntok in enumerate(nlu_tt1): + if idx >= len(nlu1): + convflag = False + break + + if ntok.startswith('##'): + ntok = ntok.replace('##', '') + tok = nlu1[idx:idx + 1].lower() + if ntok == tok: + conv_dict[i] = [idx, idx + 1] + idx += 1 + elif ntok == '#': + conv_dict[i] = [idx, idx] + elif ntok == '[UNK]': + conv_dict[i] = [idx, idx + 1] + j = i + 1 + idx += 1 + if idx < len(nlu1) and j < len( + nlu_tt1) and nlu_tt1[j] != '[UNK]': + while idx < len(nlu1): + val = nlu1[idx:idx + 1].lower() + if nlu_tt1[j].startswith(val): + break + idx += 1 + conv_dict[i][1] = idx + elif tok in ntok: + startid = idx + idx += 1 + while idx < len(nlu1): + tok += nlu1[idx:idx + 1].lower() + if ntok == tok: + conv_dict[i] = [startid, idx + 1] + break + idx += 1 + idx += 1 + else: + convflag = False + + conv = [] + if convflag: + for pr_wvi1 in pr_wvi[b]: + s1, e1 = conv_dict[pr_wvi1[0]] + s2, e2 = conv_dict[pr_wvi1[1]] + newidx = pr_wvi1[1] + while newidx + 1 < len( + nlu_tt1) and s2 == e2 and nlu_tt1[newidx] == '#': + newidx += 1 + s2, e2 = conv_dict[newidx] + if newidx + 1 < len(nlu_tt1) and nlu_tt1[ + newidx + 1].startswith('##'): + s2, e2 = conv_dict[newidx + 1] + phrase = nlu1[s1:e2] + conv.append(phrase) + else: + for pr_wvi1 in pr_wvi[b]: + phrase = ''.join(nlu_tt1[pr_wvi1[0]:pr_wvi1[1] + + 1]).replace('##', '') + conv.append(phrase) + convs.append(conv) + + return convs + + def get_fields_info(self, t1s, tables, train=True): + nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link = \ + [], [], [], [], [], [], [], [], [], [], [] + for t1 in t1s: + nlu.append(t1['question']) + nlu_t.append(t1['question_tok']) + hs_t.append(t1['header_tok']) + q_know.append(t1['bertindex_knowledge']) + t_know.append(t1['header_knowledge']) + types.append(t1['types']) + units.append(t1['units']) + his_sql.append(t1.get('history_sql', None)) + schema_link.append(t1.get('schema_link', [])) + if train: + action.append(t1.get('action', [0])) + sql_i.append(t1['sql']) + + return nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link + + def get_history_select_where(self, his_sql, header_len): + if his_sql is None: + return [0], [0] + + sel = [] + for seli in his_sql['sel']: + if seli + 1 < header_len and seli + 1 not in sel: + sel.append(seli + 1) + + whe = [] + for condi in his_sql['conds']: + if condi[0] + 1 < header_len and condi[0] + 1 not in whe: + whe.append(condi[0] + 1) + + if len(sel) == 0: + sel.append(0) + if len(whe) == 0: + whe.append(0) + + sel.sort() + whe.sort() + + return sel, whe + + def get_types_ids(self, col_type): + for key, type_ids in self.col_type_dict.items(): + if key in col_type.lower(): + return type_ids + return self.col_type_dict['null'] + + def generate_inputs(self, nlu1_tok, hs_t_1, type_t, unit_t, his_sql, + q_know, t_know, s_link): + tokens = [] + orders = [] + types = [] + segment_ids = [] + matchs = [] + col_dict = {} + schema_tok = [] + + tokens.append('[CLS]') + orders.append(0) + types.append(0) + i_st_nlu = len(tokens) + + matchs.append(0) + segment_ids.append(0) + for idx, token in enumerate(nlu1_tok): + if q_know[idx] == 100: + break + elif q_know[idx] >= 5: + matchs.append(1) + else: + matchs.append(q_know[idx] + 1) + tokens.append(token) + orders.append(0) + types.append(0) + segment_ids.append(0) + + i_ed_nlu = len(tokens) + + tokens.append('[SEP]') + orders.append(0) + types.append(0) + matchs.append(0) + segment_ids.append(0) + + sel, whe = self.get_history_select_where(his_sql, len(hs_t_1)) + + if len(sel) == 1 and sel[0] == 0 \ + and len(whe) == 1 and whe[0] == 0: + pass + else: + tokens.append('select') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + + for seli in sel: + tokens.append('[PAD]') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + col_dict[len(tokens) - 1] = seli + + tokens.append('where') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + + for whei in whe: + tokens.append('[PAD]') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + col_dict[len(tokens) - 1] = whei + + tokens.append('[SEP]') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + + column_start = len(tokens) + i_hds_f = [] + header_flatten_tokens, header_flatten_index = [], [] + for i, hds11 in enumerate(hs_t_1): + if len(unit_t[i]) == 1 and unit_t[i][0] == 'null': + temp_header_tokens = hds11 + else: + temp_header_tokens = hds11 + unit_t[i] + schema_tok.append(temp_header_tokens) + header_flatten_tokens.extend(temp_header_tokens) + header_flatten_index.extend([i + 1] * len(temp_header_tokens)) + i_st_hd_f = len(tokens) + tokens += ['[PAD]'] + orders.append(0) + types.append(self.get_types_ids(type_t[i])) + i_ed_hd_f = len(tokens) + col_dict[len(tokens) - 1] = i + i_hds_f.append((i_st_hd_f, i_ed_hd_f)) + if i == 0: + matchs.append(6) + else: + matchs.append(t_know[i - 1] + 6) + segment_ids.append(1) + + tokens.append('[SEP]') + orders.append(0) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + # position where + # [SEP] + start_ids = len(tokens) - 1 + + tokens.append('action') # action + orders.append(1) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + tokens.append('connect') # column + orders.append(1) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + tokens.append('allen') # select len + orders.append(1) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + for x in range(self.max_where_num): + tokens.append('act') # op + orders.append(2 + x) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + tokens.append('size') # where len + orders.append(1) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + for x in range(self.max_select_num): + tokens.append('focus') # agg + orders.append(2 + x) + types.append(0) + matchs.append(0) + segment_ids.append(1) + + i_nlu = (i_st_nlu, i_ed_nlu) + + schema_link_matrix = numpy.zeros((len(tokens), len(tokens)), + dtype='int32') + schema_link_mask = numpy.zeros((len(tokens), len(tokens)), + dtype='float32') + for relation in s_link: + if relation['label'] in ['col', 'val']: + [q_st, q_ed] = relation['question_index'] + cid = max(0, relation['column_index']) + schema_link_matrix[ + i_st_nlu + q_st: i_st_nlu + q_ed + 1, + column_start + cid + 1: column_start + cid + 1 + 1] = \ + self.schema_link_dict[relation['label'] + '_middle'] + schema_link_matrix[ + i_st_nlu + q_st, + column_start + cid + 1: column_start + cid + 1 + 1] = \ + self.schema_link_dict[relation['label'] + '_start'] + schema_link_matrix[ + i_st_nlu + q_ed, + column_start + cid + 1: column_start + cid + 1 + 1] = \ + self.schema_link_dict[relation['label'] + '_end'] + schema_link_mask[i_st_nlu + q_st:i_st_nlu + q_ed + 1, + column_start + cid + 1:column_start + cid + 1 + + 1] = 1.0 + + return tokens, orders, types, segment_ids, matchs, \ + i_nlu, i_hds_f, start_ids, column_start, col_dict, schema_tok, \ + header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask + + def gen_l_hpu(self, i_hds): + """ + Treat columns as if it is a batch of natural language utterance + with batch-size = # of columns * # of batch_size + i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]) + """ + l_hpu = [] + for i_hds1 in i_hds: + for i_hds11 in i_hds1: + l_hpu.append(i_hds11[1] - i_hds11[0]) + + return l_hpu + + def get_bert_output(self, model_bert, tokenizer, nlu_t, hs_t, col_types, + units, his_sql, q_know, t_know, schema_link): + """ + Here, input is toknized further by WordPiece (WP) tokenizer and fed into BERT. + + INPUT + :param model_bert: + :param tokenizer: WordPiece toknizer + :param nlu: Question + :param nlu_t: CoreNLP tokenized nlu. + :param hds: Headers + :param hs_t: None or 1st-level tokenized headers + :param max_seq_length: max input token length + + OUTPUT + tokens: BERT input tokens + nlu_tt: WP-tokenized input natural language questions + orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token + tok_to_orig_index: inverse map. + + """ + + l_n = [] + l_hs = [] # The length of columns for each batch + + input_ids = [] + order_ids = [] + type_ids = [] + segment_ids = [] + match_ids = [] + input_mask = [] + + i_nlu = [ + ] # index to retreive the position of contextual vector later. + i_hds = [] + tokens = [] + orders = [] + types = [] + matchs = [] + segments = [] + schema_link_matrix_list = [] + schema_link_mask_list = [] + start_index = [] + column_index = [] + col_dict_list = [] + header_list = [] + header_flatten_token_list = [] + header_flatten_tokenid_list = [] + header_flatten_index_list = [] + + header_tok_max_len = 0 + cur_max_length = 0 + + for b, nlu_t1 in enumerate(nlu_t): + hs_t1 = [hs_t[b][-1]] + hs_t[b][:-1] + type_t1 = [col_types[b][-1]] + col_types[b][:-1] + unit_t1 = [units[b][-1]] + units[b][:-1] + l_hs.append(len(hs_t1)) + + # [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP] + # 2. Generate BERT inputs & indices. + tokens1, orders1, types1, segment1, match1, i_nlu1, i_hds_1, \ + start_idx, column_start, col_dict, schema_tok, \ + header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask = \ + self.generate_inputs( + nlu_t1, hs_t1, type_t1, unit_t1, his_sql[b], + q_know[b], t_know[b], schema_link[b]) + + l_n.append(i_nlu1[1] - i_nlu1[0]) + start_index.append(start_idx) + column_index.append(column_start) + col_dict_list.append(col_dict) + tokens.append(tokens1) + orders.append(orders1) + types.append(types1) + segments.append(segment1) + matchs.append(match1) + i_nlu.append(i_nlu1) + i_hds.append(i_hds_1) + schema_link_matrix_list.append(schema_link_matrix) + schema_link_mask_list.append(schema_link_mask) + header_flatten_token_list.append(header_flatten_tokens) + header_flatten_index_list.append(header_flatten_index) + header_list.append(schema_tok) + header_max = max([len(schema_tok1) for schema_tok1 in schema_tok]) + if header_max > header_tok_max_len: + header_tok_max_len = header_max + + if len(tokens1) > cur_max_length: + cur_max_length = len(tokens1) + + if len(tokens1) > 512: + print('input too long!!! total_num:%d\t question:%s' % + (len(tokens1), ''.join(nlu_t1))) + + assert cur_max_length <= 512 + + for i, tokens1 in enumerate(tokens): + segment_ids1 = segments[i] + order_ids1 = orders[i] + type_ids1 = types[i] + match_ids1 = matchs[i] + input_ids1 = tokenizer.convert_tokens_to_ids(tokens1) + input_mask1 = [1] * len(input_ids1) + + while len(input_ids1) < cur_max_length: + input_ids1.append(0) + input_mask1.append(0) + segment_ids1.append(0) + order_ids1.append(0) + type_ids1.append(0) + match_ids1.append(0) + + if len(input_ids1) != cur_max_length: + print('Error: ', nlu_t1, tokens1, len(input_ids1), + cur_max_length) + + assert len(input_ids1) == cur_max_length + assert len(input_mask1) == cur_max_length + assert len(order_ids1) == cur_max_length + assert len(segment_ids1) == cur_max_length + assert len(match_ids1) == cur_max_length + assert len(type_ids1) == cur_max_length + + input_ids.append(input_ids1) + order_ids.append(order_ids1) + type_ids.append(type_ids1) + segment_ids.append(segment_ids1) + input_mask.append(input_mask1) + match_ids.append(match_ids1) + + header_len = [] + header_ids = [] + header_max_len = max( + [len(header_list1) for header_list1 in header_list]) + for header1 in header_list: + header_len1 = [] + header_ids1 = [] + for header_tok in header1: + header_len1.append(len(header_tok)) + header_tok_ids1 = tokenizer.convert_tokens_to_ids(header_tok) + while len(header_tok_ids1) < header_tok_max_len: + header_tok_ids1.append(0) + header_ids1.append(header_tok_ids1) + while len(header_ids1) < header_max_len: + header_ids1.append([0] * header_tok_max_len) + header_len.append(header_len1) + header_ids.append(header_ids1) + + for i, header_flatten_token in enumerate(header_flatten_token_list): + header_flatten_tokenid = tokenizer.convert_tokens_to_ids( + header_flatten_token) + header_flatten_tokenid_list.append(header_flatten_tokenid) + + # Convert to tensor + all_input_ids = torch.tensor( + input_ids, dtype=torch.long).to(self._device_name) + all_order_ids = torch.tensor( + order_ids, dtype=torch.long).to(self._device_name) + all_type_ids = torch.tensor( + type_ids, dtype=torch.long).to(self._device_name) + all_input_mask = torch.tensor( + input_mask, dtype=torch.long).to(self._device_name) + all_segment_ids = torch.tensor( + segment_ids, dtype=torch.long).to(self._device_name) + all_match_ids = torch.tensor( + match_ids, dtype=torch.long).to(self._device_name) + all_header_ids = torch.tensor( + header_ids, dtype=torch.long).to(self._device_name) + all_ids = torch.arange( + all_input_ids.shape[0], dtype=torch.long).to(self._device_name) + + bS = len(header_flatten_tokenid_list) + max_header_flatten_token_length = max( + [len(x) for x in header_flatten_tokenid_list]) + all_header_flatten_tokens = numpy.zeros( + (bS, max_header_flatten_token_length), dtype='int32') + all_header_flatten_index = numpy.zeros( + (bS, max_header_flatten_token_length), dtype='int32') + for i, header_flatten_tokenid in enumerate( + header_flatten_tokenid_list): + for j, tokenid in enumerate(header_flatten_tokenid): + all_header_flatten_tokens[i, j] = tokenid + for j, hdindex in enumerate(header_flatten_index_list[i]): + all_header_flatten_index[i, j] = hdindex + all_header_flatten_output = numpy.zeros((bS, header_max_len + 1), + dtype='int32') + all_header_flatten_tokens = torch.tensor( + all_header_flatten_tokens, dtype=torch.long).to(self._device_name) + all_header_flatten_index = torch.tensor( + all_header_flatten_index, dtype=torch.long).to(self._device_name) + all_header_flatten_output = torch.tensor( + all_header_flatten_output, + dtype=torch.float32).to(self._device_name) + + all_token_column_id = numpy.zeros((bS, cur_max_length), dtype='int32') + all_token_column_mask = numpy.zeros((bS, cur_max_length), + dtype='float32') + for bi, col_dict in enumerate(col_dict_list): + for ki, vi in col_dict.items(): + all_token_column_id[bi, ki] = vi + 1 + all_token_column_mask[bi, ki] = 1.0 + all_token_column_id = torch.tensor( + all_token_column_id, dtype=torch.long).to(self._device_name) + all_token_column_mask = torch.tensor( + all_token_column_mask, dtype=torch.float32).to(self._device_name) + + all_schema_link_matrix = numpy.zeros( + (bS, cur_max_length, cur_max_length), dtype='int32') + all_schema_link_mask = numpy.zeros( + (bS, cur_max_length, cur_max_length), dtype='float32') + for i, schema_link_matrix in enumerate(schema_link_matrix_list): + temp_len = schema_link_matrix.shape[0] + all_schema_link_matrix[i, 0:temp_len, + 0:temp_len] = schema_link_matrix + all_schema_link_mask[i, 0:temp_len, + 0:temp_len] = schema_link_mask_list[i] + all_schema_link_matrix = torch.tensor( + all_schema_link_matrix, dtype=torch.long).to(self._device_name) + all_schema_link_mask = torch.tensor( + all_schema_link_mask, dtype=torch.long).to(self._device_name) + + # 5. generate l_hpu from i_hds + l_hpu = self.gen_l_hpu(i_hds) + + # 4. Generate BERT output. + all_encoder_layer, pooled_output = model_bert( + all_input_ids, + all_header_ids, + token_order_ids=all_order_ids, + token_type_ids=all_segment_ids, + attention_mask=all_input_mask, + match_type_ids=all_match_ids, + l_hs=l_hs, + header_len=header_len, + type_ids=all_type_ids, + col_dict_list=col_dict_list, + ids=all_ids, + header_flatten_tokens=all_header_flatten_tokens, + header_flatten_index=all_header_flatten_index, + header_flatten_output=all_header_flatten_output, + token_column_id=all_token_column_id, + token_column_mask=all_token_column_mask, + column_start_index=column_index, + headers_length=l_hs, + all_schema_link_matrix=all_schema_link_matrix, + all_schema_link_mask=all_schema_link_mask, + output_all_encoded_layers=False) + + return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \ + l_n, l_hpu, l_hs, start_index, column_index, all_ids + + def predict(self, querys): + self.head_model.eval() + self.backbone_model.eval() + + nlu, nlu_t, sql_i, q_know, t_know, tb, hs_t, types, units, his_sql, schema_link = \ + self.get_fields_info(querys, None, train=False) + + with torch.no_grad(): + all_encoder_layer, _, tokens, i_nlu, i_hds, l_n, l_hpu, l_hs, start_index, column_index, ids = \ + self.get_bert_output( + self.backbone_model, self.tokenizer, + nlu_t, hs_t, types, units, his_sql, q_know, t_know, schema_link) + + s_action, s_sc, s_sa, s_cco, s_wc, s_wo, s_wvs, s_len = self.head_model( + all_encoder_layer, l_n, l_hs, start_index, column_index, + tokens, ids) + + action_batch = torch.argmax(F.softmax(s_action, -1), -1).cpu().tolist() + scco_batch = torch.argmax(F.softmax(s_cco, -1), -1).cpu().tolist() + sc_batch = torch.argmax(F.softmax(s_sc, -1), -1).cpu().tolist() + sa_batch = torch.argmax(F.softmax(s_sa, -1), -1).cpu().tolist() + wc_batch = torch.argmax(F.softmax(s_wc, -1), -1).cpu().tolist() + wo_batch = torch.argmax(F.softmax(s_wo, -1), -1).cpu().tolist() + s_wvs_s, s_wvs_e = s_wvs + wvss_batch = torch.argmax(F.softmax(s_wvs_s, -1), -1).cpu().tolist() + wvse_batch = torch.argmax(F.softmax(s_wvs_e, -1), -1).cpu().tolist() + s_slen, s_wlen = s_len + slen_batch = torch.argmax(F.softmax(s_slen, -1), -1).cpu().tolist() + wlen_batch = torch.argmax(F.softmax(s_wlen, -1), -1).cpu().tolist() + + pr_wvi = [] + for i in range(len(querys)): + wvi = [] + for j in range(wlen_batch[i]): + wvi.append([ + max(0, wvss_batch[i][j] - 1), + max(0, wvse_batch[i][j] - 1) + ]) + pr_wvi.append(wvi) + pr_wvi_str = self.convert_string(pr_wvi, nlu, nlu_t) + + pre_results = [] + for ib in range(len(querys)): + res_one = {} + sql = {} + sql['cond_conn_op'] = scco_batch[ib] + sl = slen_batch[ib] + sql['sel'] = list( + numpy.array(sc_batch[ib][:sl]).astype(numpy.int32) - 1) + sql['agg'] = list( + numpy.array(sa_batch[ib][:sl]).astype(numpy.int32)) + sels = [] + aggs = [] + for ia, sel in enumerate(sql['sel']): + if sel == -1: + if sql['agg'][ia] > 0: + sels.append(l_hs[ib] - 1) + aggs.append(sql['agg'][ia]) + continue + sels.append(int(sel)) + if sql['agg'][ia] == -1: + aggs.append(0) + else: + aggs.append(int(sql['agg'][ia])) + if len(sels) == 0: + sels.append(l_hs[ib] - 1) + aggs.append(0) + assert len(sels) == len(aggs) + sql['sel'] = sels + sql['agg'] = aggs + + conds = [] + wl = wlen_batch[ib] + wc_os = list( + numpy.array(wc_batch[ib][:wl]).astype(numpy.int32) - 1) + wo_os = list(numpy.array(wo_batch[ib][:wl]).astype(numpy.int32)) + res_one['question_tok'] = querys[ib]['question_tok'] + for i in range(wl): + if wc_os[i] == -1: + continue + conds.append([int(wc_os[i]), int(wo_os[i]), pr_wvi_str[ib][i]]) + if len(conds) == 0: + conds.append([l_hs[ib] - 1, 2, 'Nulll']) + sql['conds'] = conds + res_one['question'] = querys[ib]['question'] + res_one['table_id'] = querys[ib]['table_id'] + res_one['sql'] = sql + res_one['action'] = action_batch[ib] + res_one['model_out'] = [ + sc_batch[ib], sa_batch[ib], wc_batch[ib], wo_batch[ib], + wvss_batch[ib], wvse_batch[ib] + ] + pre_results.append(res_one) + + return pre_results + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + + Returns: + Dict[str, Tensor]: results + Example: + { + 'result': + { + 'question_tok': ['有', '哪', '些', '风', '险', '类', '型', '?'], + 'question': '有哪些风险类型?', + 'table_id': 'fund', + 'sql': { + 'cond_conn_op': 0, + 'sel': [5], + 'agg': [0], + 'conds': [[10, 2, 'Nulll']] + }, + 'action': 10, + 'model_out': [ + [6, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0] + ] + }, + 'history_sql': None + } + + Example: + >>> from modelscope.models.nlp import TableQuestionAnswering + >>> from modelscope.preprocessors import TableQuestionAnsweringPreprocessor + >>> model = TableQuestionAnswering.from_pretrained('damo/nlp_convai_text2sql_pretrain_cn') + >>> preprocessor = TableQuestionAnsweringPreprocessor(model_dir=model.model_dir) + >>> print(model(preprocessor({'question': '有哪些风险类型?'}))) + """ + result = self.predict(input['datas'])[0] + + return { + 'result': result, + 'history_sql': input['datas'][0]['history_sql'] + } diff --git a/modelscope/models/nlp/space_T_en/__init__.py b/modelscope/models/nlp/space_T_en/__init__.py new file mode 100644 index 00000000..46c8b38c --- /dev/null +++ b/modelscope/models/nlp/space_T_en/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .text_to_sql import StarForTextToSql +else: + _import_structure = { + 'text_to_sql': ['StarForTextToSql'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/space_T_en/text_to_sql.py b/modelscope/models/nlp/space_T_en/text_to_sql.py new file mode 100644 index 00000000..ca2d2596 --- /dev/null +++ b/modelscope/models/nlp/space_T_en/text_to_sql.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict, Optional + +import torch +from text2sql_lgesql.asdl.asdl import ASDLGrammar +from text2sql_lgesql.asdl.transition_system import TransitionSystem +from text2sql_lgesql.model.model_constructor import Text2SQL + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['StarForTextToSql'] + + +@MODELS.register_module( + Tasks.table_question_answering, module_name=Models.space_T_en) +class StarForTextToSql(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the star model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.beam_size = 5 + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + self.config.model.model_dir = model_dir + self.grammar = ASDLGrammar.from_filepath( + os.path.join(model_dir, 'sql_asdl_v2.txt')) + self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar) + self.arg = self.config.model + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + self.model = Text2SQL(self.arg, self.trans) + check_point = torch.load( + open( + os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'), + map_location=self.device) + self.model.load_state_dict(check_point['model']) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + + Example: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.models.nlp import StarForTextToSql + >>> from modelscope.preprocessors import ConversationalTextToSqlPreprocessor + >>> test_case = { + 'database_id': 'employee_hire_evaluation', + 'local_db_path': None, + 'utterance': [ + "I'd like to see Shop names.", 'Which of these are hiring?', + 'Which shop is hiring the highest number of employees?' + ' | do you want the name of the shop ? | Yes' + ] + } + >>> cache_path = snapshot_download('damo/nlp_star_conversational-text-to-sql') + >>> preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=cache_path, + database_id=test_case['database_id'], + db_content=True) + >>> model = StarForTextToSql(cache_path, config=preprocessor.config) + >>> print(model(preprocessor({ + 'utterance': "I'd like to see Shop names.", + 'history': [], + 'last_sql': '', + 'database_id': 'employee_hire_evaluation', + 'local_db_path': None + }))) + """ + self.model.eval() + hyps = self.model.parse(input['batch'], self.beam_size) # + db = input['batch'].examples[0].db + + predict = {'predict': hyps, 'db': db} + return predict diff --git a/modelscope/models/nlp/structbert/__init__.py b/modelscope/models/nlp/structbert/__init__.py new file mode 100644 index 00000000..60d369e0 --- /dev/null +++ b/modelscope/models/nlp/structbert/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .backbone import (SbertModel, SbertPreTrainedModel) + from .configuration import SbertConfig + from .faq_question_answering import SbertForFaqQuestionAnswering + from .fill_mask import SbertForMaskedLM + from .text_classification import SbertForSequenceClassification + from .token_classification import SbertForTokenClassification + from .tokenization import (BasicTokenizer, SbertTokenizer, + WordpieceTokenizer) + from .tokenization_fast import SbertTokenizerFast +else: + _import_structure = { + 'backbone': ['SbertModel', 'SbertPreTrainedModel'], + 'configuration': ['SbertConfig'], + 'fill_mask': ['SbertForMaskedLM'], + 'faq_question_answering': ['SbertForFaqQuestionAnswering'], + 'text_classification': ['SbertForSequenceClassification'], + 'token_classification': ['SbertForTokenClassification'], + 'tokenization': + ['BasicTokenizer', 'SbertTokenizer', 'WordpieceTokenizer'], + 'tokenization_fast': ['SbertTokenizerFast'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/structbert/adv_utils.py b/modelscope/models/nlp/structbert/adv_utils.py new file mode 100644 index 00000000..91a4cb82 --- /dev/null +++ b/modelscope/models/nlp/structbert/adv_utils.py @@ -0,0 +1,168 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + + +def _symmetric_kl_div(logits1, logits2, attention_mask=None): + """ + Calclate two logits' the KL div value symmetrically. + :param logits1: The first logit. + :param logits2: The second logit. + :param attention_mask: An optional attention_mask which is used to mask some element out. + This is usually useful in token_classification tasks. + If the shape of logits is [N1, N2, ... Nn, D], the shape of attention_mask should be [N1, N2, ... Nn] + :return: The mean loss. + """ + labels_num = logits1.shape[-1] + KLDiv = nn.KLDivLoss(reduction='none') + loss = torch.sum( + KLDiv(nn.LogSoftmax(dim=-1)(logits1), + nn.Softmax(dim=-1)(logits2)), + dim=-1) + torch.sum( + KLDiv(nn.LogSoftmax(dim=-1)(logits2), + nn.Softmax(dim=-1)(logits1)), + dim=-1) + if attention_mask is not None: + loss = torch.sum( + loss * attention_mask) / torch.sum(attention_mask) / labels_num + else: + loss = torch.mean(loss) / labels_num + return loss + + +def compute_adv_loss(embedding, + model, + ori_logits, + ori_loss, + adv_grad_factor, + adv_bound=None, + sigma=5e-6, + **kwargs): + """ + Calculate the adv loss of the model. + :param embedding: Original sentense embedding + :param model: The model, or the forward function(including decoder/classifier), + accept kwargs as input, output logits + :param ori_logits: The original logits outputed from the model function + :param ori_loss: The original loss + :param adv_grad_factor: This factor will be multipled by the KL loss grad and then the result will be added to + the original embedding. + More details please check:https://arxiv.org/abs/1908.04577 + The range of this value always be 1e-3~1e-7 + :param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding. + If not proveded, 2 * sigma will be used as the adv_bound factor + :param sigma: The std factor used to produce a 0 mean normal distribution. + If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor + :param kwargs: the input param used in model function + :return: The original loss adds the adv loss + """ + adv_bound = adv_bound if adv_bound is not None else 2 * sigma + embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( + 0, sigma) # 95% in +- 1e-5 + kwargs.pop('input_ids') + if 'inputs_embeds' in kwargs: + kwargs.pop('inputs_embeds') + with_attention_mask = False if 'with_attention_mask' not in kwargs else kwargs[ + 'with_attention_mask'] + attention_mask = kwargs['attention_mask'] + if not with_attention_mask: + attention_mask = None + if 'with_attention_mask' in kwargs: + kwargs.pop('with_attention_mask') + outputs = model(**kwargs, inputs_embeds=embedding_1) + v1_logits = outputs.logits + loss = _symmetric_kl_div(ori_logits, v1_logits, attention_mask) + emb_grad = torch.autograd.grad(loss, embedding_1)[0].data + emb_grad_norm = emb_grad.norm( + dim=2, keepdim=True, p=float('inf')).max( + 1, keepdim=True)[0] + is_nan = torch.any(torch.isnan(emb_grad_norm)) + if is_nan: + logger.warning('Nan occured when calculating adv loss.') + return ori_loss + emb_grad = emb_grad / (emb_grad_norm + 1e-6) + embedding_2 = embedding_1 + adv_grad_factor * emb_grad + embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) + embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) + outputs = model(**kwargs, inputs_embeds=embedding_2) + adv_logits = outputs.logits + adv_loss = _symmetric_kl_div(ori_logits, adv_logits, attention_mask) + return ori_loss + adv_loss + + +def compute_adv_loss_pair(embedding, + model, + start_logits, + end_logits, + ori_loss, + adv_grad_factor, + adv_bound=None, + sigma=5e-6, + **kwargs): + """ + Calculate the adv loss of the model. This function is used in the pair logits scenerio. + :param embedding: Original sentense embedding + :param model: The model, or the forward function(including decoder/classifier), + accept kwargs as input, output logits + :param start_logits: The original start logits outputed from the model function + :param end_logits: The original end logits outputed from the model function + :param ori_loss: The original loss + :param adv_grad_factor: This factor will be multipled by the KL loss grad and then the result will be added to + the original embedding. + More details please check:https://arxiv.org/abs/1908.04577 + The range of this value always be 1e-3~1e-7 + :param adv_bound: adv_bound is used to cut the top and the bottom bound of the produced embedding. + If not proveded, 2 * sigma will be used as the adv_bound factor + :param sigma: The std factor used to produce a 0 mean normal distribution. + If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor + :param kwargs: the input param used in model function + :return: The original loss adds the adv loss + """ + adv_bound = adv_bound if adv_bound is not None else 2 * sigma + embedding_1 = embedding + embedding.data.new(embedding.size()).normal_( + 0, sigma) # 95% in +- 1e-5 + kwargs.pop('input_ids') + if 'inputs_embeds' in kwargs: + kwargs.pop('inputs_embeds') + outputs = model(**kwargs, inputs_embeds=embedding_1) + v1_logits_start, v1_logits_end = outputs.logits + loss = _symmetric_kl_div(start_logits, + v1_logits_start) + _symmetric_kl_div( + end_logits, v1_logits_end) + loss = loss / 2 + emb_grad = torch.autograd.grad(loss, embedding_1)[0].data + emb_grad_norm = emb_grad.norm( + dim=2, keepdim=True, p=float('inf')).max( + 1, keepdim=True)[0] + is_nan = torch.any(torch.isnan(emb_grad_norm)) + if is_nan: + logger.warning('Nan occured when calculating pair adv loss.') + return ori_loss + emb_grad = emb_grad / emb_grad_norm + embedding_2 = embedding_1 + adv_grad_factor * emb_grad + embedding_2 = torch.max(embedding_1 - adv_bound, embedding_2) + embedding_2 = torch.min(embedding_1 + adv_bound, embedding_2) + outputs = model(**kwargs, inputs_embeds=embedding_2) + adv_logits_start, adv_logits_end = outputs.logits + adv_loss = _symmetric_kl_div(start_logits, + adv_logits_start) + _symmetric_kl_div( + end_logits, adv_logits_end) + return ori_loss + adv_loss diff --git a/modelscope/models/nlp/structbert/backbone.py b/modelscope/models/nlp/structbert/backbone.py new file mode 100755 index 00000000..039db3ce --- /dev/null +++ b/modelscope/models/nlp/structbert/backbone.py @@ -0,0 +1,932 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch StructBERT model. mainly copied from :module:`~transformers.modeling_bert`""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from packaging import version +from transformers.activations import ACT2FN +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.logger import get_logger +from .configuration import SbertConfig + +logger = get_logger(__name__) + + +class SbertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse('1.6.0'): + self.register_buffer( + 'token_type_ids', + torch.zeros( + self.position_ids.size(), + dtype=torch.long, + device=self.position_ids.device), + persistent=False, + ) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + return_inputs_embeds=False): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users + # when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, 'token_type_ids'): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, + dtype=torch.long, + device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + if not return_inputs_embeds: + return embeddings + else: + return embeddings, inputs_embeds + + +class SbertSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query': + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in SbertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class SbertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SbertAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = SbertSelfAttention(config) + self.output = SbertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class SbertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SbertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class SbertLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = SbertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError( + f'{self} should be used as a decoder model if cross attention is added' + ) + self.crossattention = SbertAttention(config) + self.intermediate = SbertIntermediate(config) + self.output = SbertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, 'crossattention'): + raise ValueError( + f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention ' + f'layers by setting `config.add_cross_attention=True`') + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class SbertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [SbertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SbertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class SbertPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SbertConfig + base_model_prefix = 'bert' + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, SbertEncoder): + module.gradient_checkpointing = value + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = SbertConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model + + +@dataclass +class AttentionBackboneModelOutputWithEmbedding(AttentionBackboneModelOutput): + embedding_output: torch.FloatTensor = None + logits: Optional[Union[tuple, torch.FloatTensor]] = None + kwargs: dict = None + + +@MODELS.register_module(Tasks.backbone, module_name=Models.structbert) +class SbertModel(SbertPreTrainedModel): + """The StructBERT Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration + set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config: SbertConfig, add_pooling_layer=True, **kwargs): + super().__init__(config) + self.config = config + + self.embeddings = SbertEmbeddings(config) + self.encoder = SbertEncoder(config) + + self.pooler = SbertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple + having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_backbone_base_std') + >>> print(model(**preprocessor('这是个测试'))) + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[ + 2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, 'token_type_ids'): + buffered_token_type_ids = self.embeddings.token_type_ids[:, : + seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + embedding_output, orignal_embeds = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + return_inputs_embeds=True, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, + pooled_output) + encoder_outputs[1:] + (orignal_embeds, ) + + return AttentionBackboneModelOutputWithEmbedding( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + embedding_output=orignal_embeds) diff --git a/modelscope/models/nlp/structbert/configuration.py b/modelscope/models/nlp/structbert/configuration.py new file mode 100644 index 00000000..8f095f9d --- /dev/null +++ b/modelscope/models/nlp/structbert/configuration.py @@ -0,0 +1,134 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" StructBERT model configuration, mainly copied from :class:`~transformers.BertConfig` """ +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class SbertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration + of a :class:`~modelscope.models.nlp.structbert.SbertModel`. + It is used to instantiate a StructBERT model according to the specified arguments. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or + :class:`~transformers.TFBertModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): + Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, + :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on + :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) + `__. For more information on :obj:`"relative_key_query"`, please refer to + `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) + `__. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. + adv_grad_factor (:obj:`float`, `optional`): This factor will be multiplied by the KL loss grad and then + the result will be added to the original embedding. + More details please check:https://arxiv.org/abs/1908.04577 + The range of this value should between 1e-3~1e-7 + adv_bound (:obj:`float`, `optional`): adv_bound is used to cut the top and the bottom bound of + the produced embedding. + If not provided, 2 * sigma will be used as the adv_bound factor + sigma (:obj:`float`, `optional`): The std factor used to produce a 0 mean normal distribution. + If adv_bound not provided, 2 * sigma will be used as the adv_bound factor + """ + + model_type = 'structbert' + + def __init__(self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type='absolute', + use_cache=True, + classifier_dropout=None, + **kwargs): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + # adv_grad_factor, used in adv loss. + # Users can check adv_utils.py for details. + # if adv_grad_factor set to None, no adv loss will not applied to the model. + self.adv_grad_factor = 5e-5 if 'adv_grad_factor' not in kwargs else kwargs[ + 'adv_grad_factor'] + # sigma value, used in adv loss. + self.sigma = 5e-6 if 'sigma' not in kwargs else kwargs['sigma'] + # adv_bound value, used in adv loss. + self.adv_bound = 2 * self.sigma if 'adv_bound' not in kwargs else kwargs[ + 'adv_bound'] diff --git a/modelscope/models/nlp/structbert/faq_question_answering.py b/modelscope/models/nlp/structbert/faq_question_answering.py new file mode 100644 index 00000000..c8dbf302 --- /dev/null +++ b/modelscope/models/nlp/structbert/faq_question_answering.py @@ -0,0 +1,293 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +from collections import namedtuple +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.nlp.structbert import SbertConfig, SbertModel +from modelscope.models.nlp.task_models.task_model import BaseTaskModel +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import ModelFile, Tasks + +activations = { + 'relu': F.relu, + 'tanh': torch.tanh, + 'linear': lambda x: x, +} + +activation_coeffs = { + 'relu': math.sqrt(2), + 'tanh': 5 / 3, + 'linear': 1., +} + + +class LinearProjection(nn.Module): + + def __init__(self, + in_features, + out_features, + activation='linear', + bias=True): + super().__init__() + self.activation = activations[activation] + activation_coeff = activation_coeffs[activation] + linear = nn.Linear(in_features, out_features, bias=bias) + nn.init.normal_( + linear.weight, std=math.sqrt(1. / in_features) * activation_coeff) + if bias: + nn.init.zeros_(linear.bias) + self.model = nn.utils.weight_norm(linear) + + def forward(self, x): + return self.activation(self.model(x)) + + +class RelationModule(nn.Module): + + def __init__(self, args): + super(RelationModule, self).__init__() + input_size = args.proj_hidden_size * 4 + self.prediction = torch.nn.Sequential( + LinearProjection( + input_size, args.proj_hidden_size * 4, activation='relu'), + nn.Dropout(args.dropout), + LinearProjection(args.proj_hidden_size * 4, 1)) + + def forward(self, query, protos): + n_cls = protos.shape[0] + n_query = query.shape[0] + protos = protos.unsqueeze(0).repeat(n_query, 1, 1) + query = query.unsqueeze(1).repeat(1, n_cls, 1) + input_feat = torch.cat( + [query, protos, (protos - query).abs(), query * protos], dim=-1) + dists = self.prediction(input_feat) # [bsz,n_query,n_cls,1] + return dists.squeeze(-1) + + +class MetricsLayer(nn.Module): + + def __init__(self, args): + super(MetricsLayer, self).__init__() + self.args = args + assert args.metrics in ('relation', 'cosine') + if args.metrics == 'relation': + self.relation_net = RelationModule(args) + + @property + def name(self): + return self.args.metrics + + def forward(self, query, protos): + """ query : [bsz, n_query, dim] + support : [bsz, n_query, n_cls, dim] | [bsz, n_cls, dim] + """ + if self.args.metrics == 'cosine': + supervised_dists = self.cosine_similarity(query, protos) + if self.training: + supervised_dists *= 5 + elif self.args.metrics in ('relation', ): + supervised_dists = self.relation_net(query, protos) + else: + raise NotImplementedError + return supervised_dists + + def cosine_similarity(self, x, y): + # x=[bsz, n_query, dim] + # y=[bsz, n_cls, dim] + n_query = x.shape[0] + n_cls = y.shape[0] + dim = x.shape[-1] + x = x.unsqueeze(1).expand([n_query, n_cls, dim]) + y = y.unsqueeze(0).expand([n_query, n_cls, dim]) + return F.cosine_similarity(x, y, -1) + + +class AveragePooling(nn.Module): + + def forward(self, x, mask, dim=1): + return torch.sum( + x * mask.float(), dim=dim) / torch.sum( + mask.float(), dim=dim) + + +class AttnPooling(nn.Module): + + def __init__(self, input_size, hidden_size=None, output_size=None): + super().__init__() + self.input_proj = nn.Sequential( + LinearProjection(input_size, hidden_size), nn.Tanh(), + LinearProjection(hidden_size, 1, bias=False)) + self.output_proj = LinearProjection( + input_size, output_size) if output_size else lambda x: x + + def forward(self, x, mask): + score = self.input_proj(x) + score = score * mask.float() + -1e4 * (1. - mask.float()) + score = F.softmax(score, dim=1) + features = self.output_proj(x) + return torch.matmul(score.transpose(1, 2), features).squeeze(1) + + +class PoolingLayer(nn.Module): + + def __init__(self, args): + super(PoolingLayer, self).__init__() + if args.pooling == 'attn': + self.pooling = AttnPooling(args.proj_hidden_size, + args.proj_hidden_size, + args.proj_hidden_size) + elif args.pooling == 'avg': + self.pooling = AveragePooling() + else: + raise NotImplementedError(args.pooling) + + def forward(self, x, mask): + return self.pooling(x, mask) + + +@MODELS.register_module( + Tasks.faq_question_answering, module_name=Models.structbert) +class SbertForFaqQuestionAnswering(BaseTaskModel): + _backbone_prefix = '' + + @classmethod + def _instantiate(cls, **kwargs): + model = cls(kwargs.get('model_dir')) + model.load_checkpoint(kwargs.get('model_dir')) + return model + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + + backbone_cfg = SbertConfig.from_pretrained(model_dir) + self.bert = SbertModel(backbone_cfg) + + model_config = Config.from_file( + os.path.join(model_dir, + ModelFile.CONFIGURATION)).get(ConfigFields.model, {}) + + metric = model_config.get('metric', 'cosine') + pooling_method = model_config.get('pooling', 'avg') + + Arg = namedtuple('args', [ + 'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling' + ]) + args = Arg( + metrics=metric, + proj_hidden_size=self.bert.config.hidden_size, + hidden_size=self.bert.config.hidden_size, + dropout=0.0, + pooling=pooling_method) + + self.metrics_layer = MetricsLayer(args) + self.pooling = PoolingLayer(args) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Args: + input (Dict[str, Tensor]): the preprocessed data, it contains the following keys: + query(:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + The query to be predicted. + support(:obj:`torch.LongTensor` of shape :obj:`(support_size, sequence_length)`): + The support set. + support_label(:obj:`torch.LongTensor` of shape :obj:`(support_size, )`): + The labels of support set. + + Returns: + Dict[str, Tensor]: result, it contains the following key: + scores(:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_cls)`): + Predicted scores of all classes for each query. + Examples: + >>> from modelscope.hub.snapshot_download import snapshot_download + >>> from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor + >>> from modelscope.models.nlp import SbertForFaqQuestionAnswering + >>> cache_path = snapshot_download('damo/nlp_structbert_faq-question-answering_chinese-base') + >>> preprocessor = FaqQuestionAnsweringPreprocessor.from_pretrained(cache_path) + >>> model = SbertForFaqQuestionAnswering.from_pretrained(cache_path) + >>> param = { + >>> 'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'], + >>> 'support_set': [{ + >>> 'text': '卖品代金券怎么用', + >>> 'label': '6527856' + >>> }, { + >>> 'text': '怎么使用优惠券', + >>> 'label': '6527856' + >>> }, { + >>> 'text': '这个可以一起领吗', + >>> 'label': '1000012000' + >>> }, { + >>> 'text': '付款时送的优惠券哪里领', + >>> 'label': '1000012000' + >>> }, { + >>> 'text': '购物等级怎么长', + >>> 'label': '13421097' + >>> }, { + >>> 'text': '购物等级二心', + >>> 'label': '13421097' + >>> }] + >>> } + >>> result = model(preprocessor(param)) + """ + assert not self.training + query = input['query'] + support = input['support'] + if isinstance(query, list): + query = torch.stack(query) + if isinstance(support, list): + support = torch.stack(support) + n_query = query.shape[0] + n_support = support.shape[0] + query_mask = torch.ne(query, 0).view([n_query, -1]) + support_mask = torch.ne(support, 0).view([n_support, -1]) + + support_labels = input['support_labels'] + num_cls = torch.max(support_labels) + 1 + onehot_labels = self._get_onehot_labels(support_labels, n_support, + num_cls) + + input_ids = torch.cat([query, support]) + input_mask = torch.cat([query_mask, support_mask], dim=0) + pooled_representation = self.forward_sentence_embedding({ + 'input_ids': + input_ids, + 'attention_mask': + input_mask + }) + z_query = pooled_representation[:n_query] + z_support = pooled_representation[n_query:] + cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5 + protos = torch.matmul(onehot_labels.transpose(0, 1), + z_support) / cls_n_support.unsqueeze(-1) + scores = self.metrics_layer(z_query, protos).view([n_query, num_cls]) + if self.metrics_layer.name == 'relation': + scores = torch.sigmoid(scores) + return {'scores': scores} + + def _get_onehot_labels(self, labels, support_size, num_cls): + labels_ = labels.view(support_size, 1) + target_oh = torch.zeros(support_size, num_cls).to(labels) + target_oh.scatter_(dim=1, index=labels_, value=1) + return target_oh.view(support_size, num_cls).float() + + def forward_sentence_embedding(self, inputs: Dict[str, Tensor]): + input_ids = inputs['input_ids'] + input_mask = inputs['attention_mask'] + if not isinstance(input_ids, Tensor): + input_ids = torch.IntTensor(input_ids) + if not isinstance(input_mask, Tensor): + input_mask = torch.IntTensor(input_mask) + rst = self.bert(input_ids, input_mask) + last_hidden_states = rst.last_hidden_state + if len(input_mask.shape) == 2: + input_mask = input_mask.unsqueeze(-1) + pooled_representation = self.pooling(last_hidden_states, input_mask) + return pooled_representation diff --git a/modelscope/models/nlp/structbert/fill_mask.py b/modelscope/models/nlp/structbert/fill_mask.py new file mode 100644 index 00000000..e611aa88 --- /dev/null +++ b/modelscope/models/nlp/structbert/fill_mask.py @@ -0,0 +1,284 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +class SbertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class SbertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = SbertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class SbertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = SbertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class SbertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = SbertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) +class SbertForMaskedLM(SbertPreTrainedModel): + r"""StructBERT Model with a `language modeling` head on top. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + + if config.is_decoder: + logger.warning( + 'If you want to use `SbertForMaskedLM` make sure `config.is_decoder=False` for ' + 'bi-directional self-attention.') + + self.bert = SbertModel(config) + self.cls = SbertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor, NLPPreprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_fill-mask_chinese-large') + >>> preprocessor = NLPPreprocessor('damo/nlp_structbert_fill-mask_chinese-large') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。')) + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[2:-1] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return AttentionFillMaskModelOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=input_ids, + ) + + def prepare_inputs_for_generation(self, + input_ids, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, 'The PAD token should be defined for generation' + attention_mask_zero = attention_mask.new_zeros( + (attention_mask.shape[0], 1)) + attention_mask = torch.cat([attention_mask, attention_mask_zero], + dim=-1) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/modelscope/models/nlp/structbert/text_classification.py b/modelscope/models/nlp/structbert/text_classification.py new file mode 100644 index 00000000..8797beb3 --- /dev/null +++ b/modelscope/models/nlp/structbert/text_classification.py @@ -0,0 +1,236 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .adv_utils import compute_adv_loss +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +@MODELS.register_module( + Tasks.text_classification, module_name=Models.structbert) +@MODELS.register_module(Tasks.nli, module_name=Models.structbert) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.structbert) +@MODELS.register_module( + Tasks.sentence_similarity, module_name=Models.structbert) +@MODELS.register_module( + Tasks.zero_shot_classification, module_name=Models.structbert) +class SbertForSequenceClassification(SbertPreTrainedModel): + r"""StructBERT Model transformer with a sequence classification/regression head on top + (a linear layer on top of the pooled output) e.g. for GLUE tasks. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the text classification model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in ModelScope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + if self.config.adv_grad_factor is None: + logger.warning( + 'Adv parameters not set, skipping compute_adv_loss.') + + SbertForSequenceClassification.base_model_prefix = getattr( + config, 'base_model_prefix', + SbertForSequenceClassification.base_model_prefix) + setattr(self, self.base_model_prefix, SbertModel(config)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + def _forward_call(self, **kwargs): + outputs = self.base_model(**kwargs) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + outputs['logits'] = logits + outputs.kwargs = kwargs + return outputs + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + *args, + **kwargs): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor(('这是个测试', '这也是个测试')))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins(('这是个测试', '这也是个测试'))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + logger.error('Return tuple in sbert is not supported now.') + outputs = self._forward_call( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + return self.compute_loss(outputs, labels, **outputs.kwargs) + + def compute_loss(self, outputs, labels, **kwargs): + logits = outputs.logits + embedding_output = outputs.embedding_output + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + if self.config.adv_grad_factor is not None and self.training: + loss = compute_adv_loss( + embedding=embedding_output, + model=self._forward_call, + ori_logits=logits, + ori_loss=loss, + adv_bound=self.config.adv_bound, + adv_grad_factor=self.config.adv_grad_factor, + sigma=self.config.sigma, + **kwargs) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/structbert/token_classification.py b/modelscope/models/nlp/structbert/token_classification.py new file mode 100644 index 00000000..a040ff3e --- /dev/null +++ b/modelscope/models/nlp/structbert/token_classification.py @@ -0,0 +1,229 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .adv_utils import compute_adv_loss +from .backbone import SbertModel, SbertPreTrainedModel +from .configuration import SbertConfig + +logger = logging.get_logger(__name__) + + +@MODELS.register_module( + Tasks.token_classification, module_name=Models.structbert) +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) +@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) +class SbertForTokenClassification(SbertPreTrainedModel): + r"""StructBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) + e.g. for Named-Entity-Recognition (NER) tasks. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the token-classification model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.TokenClassificationPreprocessor`. + + Trainer: + This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer, + NlpEpochBasedTrainer, or trainers from other frameworks. + The preferred trainer in modelscope is NlpEpochBasedTrainer. + + Parameters: + config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with + all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. + """ + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config: SbertConfig, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + if self.config.adv_grad_factor is None: + logger.warning( + 'Adv parameters not set, skipping compute_adv_loss.') + setattr(self, self.base_model_prefix, + SbertModel(config, add_pooling_layer=False)) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None + else config.hidden_dropout_prob) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def _forward_call(self, **kwargs): + outputs = self.bert(**kwargs) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + outputs['logits'] = logits + outputs.kwargs = kwargs + return outputs + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + logger.error('Return tuple in sbert is not supported now.') + + outputs = self._forward_call( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + logits = outputs.logits + embedding_output = outputs.embedding_output + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), + torch.tensor(loss_fct.ignore_index).type_as(labels)) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + if self.config.adv_grad_factor is not None and self.training: + loss = compute_adv_loss( + embedding=embedding_output, + model=self._forward_call, + ori_logits=logits, + ori_loss=loss, + adv_bound=self.config.adv_bound, + adv_grad_factor=self.config.adv_grad_factor, + sigma=self.config.sigma, + with_attention_mask=attention_mask is not None, + **outputs.kwargs) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + offset_mapping=offset_mapping, + ) diff --git a/modelscope/models/nlp/structbert/tokenization.py b/modelscope/models/nlp/structbert/tokenization.py new file mode 100644 index 00000000..3171e31d --- /dev/null +++ b/modelscope/models/nlp/structbert/tokenization.py @@ -0,0 +1,519 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Sbert. mainly copied from :module:`~transformers.tokenization_bert`""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from transformers.tokenization_utils import (PreTrainedTokenizer, _is_control, + _is_punctuation, _is_whitespace) + +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': ModelFile.VOCAB_FILE} + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'nlp_structbert_backbone_large_std': 512, + 'nlp_structbert_backbone_base_std': 512, + 'nlp_structbert_backbone_lite_std': 512, + 'nlp_structbert_backbone_tiny_std': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'english_sbert-large-std-512': { + 'do_lower_case': True + }, +} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, 'r', encoding='utf-8') as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip('\n') + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class SbertTokenizer(PreTrainedTokenizer): + r""" + Construct a SBERT tokenizer. Based on WordPiece. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to do basic tokenization before WordPiece. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs): + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = SbertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=self.unk_token) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = ' '.join(tokens).replace(' ##', '').strip() + return out_string + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A SBERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A SBERT sequence + pair mask has the following format: + + :: + + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, + (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + else: + vocab_file = (filename_prefix + + '-' if filename_prefix else '') + save_directory + with open(vocab_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted( + self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f'Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.' + ' Please check that the vocabulary is not corrupted!') + index = token_index + writer.write(token + '\n') + index += 1 + return (vocab_file, ) + + +class BasicTokenizer(object): + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + def __init__(self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see + WordPieceTokenizer. + + Args: + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + :func:`PreTrainedTokenizer.tokenize`) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union( + set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) + or (0x20000 <= cp <= 0x2A6DF) or (0x2A700 <= cp <= 0x2B73F) + or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF) + or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens diff --git a/modelscope/models/nlp/structbert/tokenization_fast.py b/modelscope/models/nlp/structbert/tokenization_fast.py new file mode 100644 index 00000000..6f7b7ba7 --- /dev/null +++ b/modelscope/models/nlp/structbert/tokenization_fast.py @@ -0,0 +1,203 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for Sbert. mainly copied from :module:`~transformers.tokenization_bert_fast`""" + +from typing import List, Optional, Tuple + +import json +import transformers +from tokenizers import normalizers +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from .tokenization import SbertTokenizer + +logger = get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': ModelFile.VOCAB_FILE, + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': {}, + 'tokenizer_file': {}, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'nlp_structbert_backbone_large_std': 512, + 'nlp_structbert_backbone_base_std': 512, + 'nlp_structbert_backbone_lite_std': 512, + 'nlp_structbert_backbone_tiny_std': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'english_sbert-large-std-512': { + 'do_lower_case': True + }, +} + +transformers.SLOW_TO_FAST_CONVERTERS[ + 'SbertTokenizer'] = transformers.SLOW_TO_FAST_CONVERTERS['BertTokenizer'] + + +class SbertTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" SBERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see `this + issue `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + wordpieces_prefix: (:obj:`str`, `optional`, defaults to :obj:`"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = SbertTokenizer + + def __init__(self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + pre_tok_state = json.loads( + self.backend_tokenizer.normalizer.__getstate__()) + if (pre_tok_state.get('lowercase', do_lower_case) != do_lower_case + or pre_tok_state.get('strip_accents', + strip_accents) != strip_accents): + pre_tok_class = getattr(normalizers, pre_tok_state.pop('type')) + pre_tok_state['lowercase'] = do_lower_case + pre_tok_state['strip_accents'] = strip_accents + self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A SBERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A SBERT sequence + pair mask has the following format: + + :: + + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save( + save_directory, name=filename_prefix) + return tuple(files) diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py new file mode 100644 index 00000000..b8722a36 --- /dev/null +++ b/modelscope/models/nlp/task_models/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .information_extraction import InformationExtractionModel + from .feature_extraction import FeatureExtractionModel + from .fill_mask import FillMaskModel + from .nncrf_for_named_entity_recognition import ( + LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition, + ) + from .nncrf_for_word_segmentation import ( + LSTMCRFForWordSegmentation, + TransformerCRFForWordSegmentation, + ) + from .sequence_classification import SequenceClassificationModel + from .task_model import SingleBackboneTaskModelBase + from .token_classification import TokenClassificationModel + from .text_generation import TaskModelForTextGeneration + +else: + _import_structure = { + 'information_extraction': ['InformationExtractionModel'], + 'feature_extraction': ['FeatureExtractionModel'], + 'fill_mask': ['FillMaskModel'], + 'nncrf_for_named_entity_recognition': [ + 'TransformerCRFForNamedEntityRecognition', + 'LSTMCRFForNamedEntityRecognition' + ], + 'nncrf_for_word_segmentation': + ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], + 'sequence_classification': ['SequenceClassificationModel'], + 'task_model': ['SingleBackboneTaskModelBase'], + 'token_classification': ['TokenClassificationModel'], + 'text_generation': ['TaskModelForTextGeneration'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/task_models/feature_extraction.py b/modelscope/models/nlp/task_models/feature_extraction.py new file mode 100644 index 00000000..9360ec08 --- /dev/null +++ b/modelscope/models/nlp/task_models/feature_extraction.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.bert import BertConfig +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping + +__all__ = ['FeatureExtractionModel'] + + +@MODELS.register_module( + Tasks.feature_extraction, module_name=TaskModels.feature_extraction) +class FeatureExtractionModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the fill mask model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + self.build_backbone(self.backbone_cfg) + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + # backbone do not need labels, only head need for loss compute + input.pop(OutputKeys.LABELS, None) + outputs = super().forward(input) + sequence_output = outputs.last_hidden_state + return {OutputKeys.TEXT_EMBEDDING: sequence_output} diff --git a/modelscope/models/nlp/task_models/fill_mask.py b/modelscope/models/nlp/task_models/fill_mask.py new file mode 100644 index 00000000..0f7d3345 --- /dev/null +++ b/modelscope/models/nlp/task_models/fill_mask.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.bert import BertConfig +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping + +__all__ = ['FillMaskModel'] + + +@MODELS.register_module(Tasks.fill_mask, module_name=TaskModels.fill_mask) +class FillMaskModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the fill mask model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + + # backbone do not need labels, only head need for loss compute + labels = input.pop(OutputKeys.LABELS, None) + + outputs = super().forward(input) + sequence_output = outputs.last_hidden_state + outputs = self.head.forward(sequence_output) + + if labels is not None: + input[OutputKeys.LABELS] = labels + loss = self.compute_loss(outputs, labels) + outputs.update(loss) + outputs[OutputKeys.INPUT_IDS] = input[OutputKeys.INPUT_IDS] + return outputs diff --git a/modelscope/models/nlp/task_models/information_extraction.py b/modelscope/models/nlp/task_models/information_extraction.py new file mode 100644 index 00000000..ce0e21a3 --- /dev/null +++ b/modelscope/models/nlp/task_models/information_extraction.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['InformationExtractionModel'] + + +@MODELS.register_module( + Tasks.information_extraction, + module_name=TaskModels.information_extraction) +@MODELS.register_module( + Tasks.relation_extraction, module_name=TaskModels.information_extraction) +class InformationExtractionModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the information extraction model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + outputs = super().forward(input) + sequence_output = outputs.last_hidden_state + outputs = self.head.forward(sequence_output, input['text'], + input['offsets']) + return {OutputKeys.SPO_LIST: outputs} diff --git a/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py new file mode 100644 index 00000000..017e35e5 --- /dev/null +++ b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py @@ -0,0 +1,727 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) +# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. + +import os +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierWithPredictionsOutput +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = [ + 'TransformerCRFForNamedEntityRecognition', + 'LSTMCRFForNamedEntityRecognition' +] + + +class SequenceLabelingForNamedEntityRecognition(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.model = self.init_model(model_dir, *args, **kwargs) + + model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + self.model.load_state_dict( + torch.load(model_ckpt, map_location=torch.device('cpu'))) + + def init_model(self, model_dir, *args, **kwargs): + raise NotImplementedError + + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + offset_mapping=None, + label_mask=None, + ) -> Dict[str, Any]: + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask + values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + Returns `modelscope.outputs.TokenClassifierOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') + >>> print(model(**preprocessor(('This is a test', 'This is also a test')))) + """ + input_tensor = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + } + output = { + 'offset_mapping': offset_mapping, + **input_tensor, + **self.model(input_tensor) + } + return output + + def postprocess(self, input: Dict[str, Any], **kwargs): + predicts = self.model.decode(input) + offset_len = len(input['offset_mapping']) + predictions = torch.narrow( + predicts, 1, 0, + offset_len) # index_select only move loc, not resize + return TokenClassifierWithPredictionsOutput( + loss=None, + logits=None, + hidden_states=None, + attentions=None, + offset_mapping=input['offset_mapping'], + predictions=predictions, + ) + + +@MODELS.register_module( + Tasks.named_entity_recognition, module_name=Models.tcrf) +class TransformerCRFForNamedEntityRecognition( + SequenceLabelingForNamedEntityRecognition): + """This model wraps the TransformerCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + num_labels = self.config.num_labels + + model = TransformerCRF(model_dir, num_labels) + return model + + +@MODELS.register_module( + Tasks.named_entity_recognition, module_name=Models.lcrf) +class LSTMCRFForNamedEntityRecognition( + SequenceLabelingForNamedEntityRecognition): + """This model wraps the LSTMCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + vocab_size = self.config.vocab_size + embed_width = self.config.embed_width + num_labels = self.config.num_labels + lstm_hidden_size = self.config.lstm_hidden_size + + model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) + return model + + +class TransformerCRF(nn.Module): + """A transformer based model to NER tasks. + + This model will use transformers' backbones as its backbone. + """ + + def __init__(self, model_dir, num_labels, **kwargs): + super(TransformerCRF, self).__init__() + + self.encoder = AutoModel.from_pretrained(model_dir) + self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embed = self.encoder( + inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] + logits = self.linear(embed) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + return predicts + + +class LSTMCRF(nn.Module): + """ + A standard bilstm-crf model for fast prediction. + """ + + def __init__(self, + vocab_size, + embed_width, + num_labels, + lstm_hidden_size=100, + **kwargs): + super(LSTMCRF, self).__init__() + self.embedding = Embedding(vocab_size, embed_width) + self.lstm = nn.LSTM( + embed_width, + lstm_hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embedding = self.embedding(inputs['input_ids']) + lstm_output, _ = self.lstm(embedding) + logits = self.ffn(lstm_output) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + return predicts + + +class CRF(nn.Module): + """Conditional random field. + This module implements a conditional random field [LMP01]_. The forward computation + of this class computes the log likelihood of the given sequence of tags and + emission score tensor. This class also has `~CRF.decode` method which finds + the best tag sequence given an emission score tensor using `Viterbi algorithm`_. + Args: + num_tags: Number of tags. + batch_first: Whether the first dimension corresponds to the size of a minibatch. + Attributes: + start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size + ``(num_tags,)``. + end_transitions (`~torch.nn.Parameter`): End transition score tensor of size + ``(num_tags,)``. + transitions (`~torch.nn.Parameter`): Transition score tensor of size + ``(num_tags, num_tags)``. + .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). + "Conditional random fields: Probabilistic models for segmenting and + labeling sequence data". *Proc. 18th International Conf. on Machine + Learning*. Morgan Kaufmann. pp. 282–289. + .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm + + """ + + def __init__(self, num_tags: int, batch_first: bool = False) -> None: + if num_tags <= 0: + raise ValueError(f'invalid number of tags: {num_tags}') + super().__init__() + self.num_tags = num_tags + self.batch_first = batch_first + self.start_transitions = nn.Parameter(torch.empty(num_tags)) + self.end_transitions = nn.Parameter(torch.empty(num_tags)) + self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize the transition parameters. + The parameters will be initialized randomly from a uniform distribution + between -0.1 and 0.1. + """ + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_tags={self.num_tags})' + + def forward(self, + emissions: torch.Tensor, + tags: torch.LongTensor, + mask: Optional[torch.ByteTensor] = None, + reduction: str = 'mean') -> torch.Tensor: + """Compute the conditional log likelihood of a sequence of tags given emission scores. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + tags (`~torch.LongTensor`): Sequence of tags tensor of size + ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + reduction: Specifies the reduction to apply to the output: + ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. + ``sum``: the output will be summed over batches. ``mean``: the output will be + averaged over batches. ``token_mean``: the output will be averaged over tokens. + Returns: + `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if + reduction is ``none``, ``()`` otherwise. + """ + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError(f'invalid reduction: {reduction}') + if mask is None: + mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, tags=tags, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = numerator - denominator + + if reduction == 'none': + return llh + if reduction == 'sum': + return llh.sum() + if reduction == 'mean': + return llh.mean() + return llh.sum() / mask.float().sum() + + def decode(self, + emissions: torch.Tensor, + mask: Optional[torch.ByteTensor] = None, + nbest: Optional[int] = None, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + """Find the most likely tag sequence using Viterbi algorithm. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + nbest (`int`): Number of most probable paths for each sequence + pad_tag (`int`): Tag at padded positions. Often input varies in length and + the length will be padded to the maximum length in the batch. Tags at + the padded positions will be assigned with a padding tag, i.e. `pad_tag` + Returns: + A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + """ + if nbest is None: + nbest = 1 + if mask is None: + mask = torch.ones( + emissions.shape[:2], + dtype=torch.uint8, + device=emissions.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + mask = mask.transpose(0, 1) + + if nbest == 1: + return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) + return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) + + def _validate(self, + emissions: torch.Tensor, + tags: Optional[torch.LongTensor] = None, + mask: Optional[torch.ByteTensor] = None) -> None: + if emissions.dim() != 3: + raise ValueError( + f'emissions must have dimension of 3, got {emissions.dim()}') + if emissions.size(2) != self.num_tags: + raise ValueError( + f'expected last dimension of emissions is {self.num_tags}, ' + f'got {emissions.size(2)}') + + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + 'the first two dimensions of emissions and tags must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' + ) + + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + 'the first two dimensions of emissions and mask must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' + ) + no_empty_seq = not self.batch_first and mask[0].all() + no_empty_seq_bf = self.batch_first and mask[:, 0].all() + if not no_empty_seq and not no_empty_seq_bf: + raise ValueError('mask of the first timestep must all be on') + + def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + seq_length, batch_size = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[0]] + score += emissions[0, torch.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[i - 1], tags[i]] * mask[i] + + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + # shape: (batch_size,) + last_tags = tags[seq_ends, torch.arange(batch_size)] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer(self, emissions: torch.Tensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + seq_length = emissions.size(0) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode(self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + pad_tag: Optional[int] = None) -> List[List[int]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros((seq_length, batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size), + pad_tag, + dtype=torch.long, + device=device) + + # - score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # - history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + # Broadcast viterbi score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emission = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the score of the best + # tag sequence so far that ends with transitioning from tag i to tag j and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + + # Find the maximum score over all possible current tag + # shape: (batch_size, num_tags) + next_score, indices = next_score.max(dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(-1), next_score, score) + indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) + history_idx[i - 1] = indices + + # End transition score + # shape: (batch_size, num_tags) + end_score = score + self.end_transitions + _, end_tag = end_score.max(dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), + end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size), + dtype=torch.long, + device=device) + best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx], 1, best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size) + + return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) + + def _viterbi_decode_nbest( + self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + nbest: int, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (nbest, batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros( + (seq_length, batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size, nbest), + pad_tag, + dtype=torch.long, + device=device) + + # + score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # + history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + if i == 1: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1) + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + else: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) + # shape: (batch_size, num_tags, nbest, num_tags) + next_score = broadcast_score + self.transitions.unsqueeze( + 1) + broadcast_emission + + # Find the top `nbest` maximum score over all possible current tag + # shape: (batch_size, nbest, num_tags) + next_score, indices = next_score.view(batch_size, -1, + self.num_tags).topk( + nbest, dim=1) + + if i == 1: + score = score.unsqueeze(-1).expand(-1, -1, nbest) + indices = indices * nbest + + # convert to shape: (batch_size, num_tags, nbest) + next_score = next_score.transpose(2, 1) + indices = indices.transpose(2, 1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags, nbest) + score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), + next_score, score) + indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, + oor_idx) + history_idx[i - 1] = indices + + # End transition score shape: (batch_size, num_tags, nbest) + end_score = score + self.end_transitions.unsqueeze(-1) + _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), + end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size, nbest), + dtype=torch.long, + device=device) + best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ + .view(1, -1).expand(batch_size, -1) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, + best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest + + return torch.where(mask.unsqueeze(-1), best_tags_arr, + oor_tag).permute(2, 1, 0) + + +class Embedding(nn.Module): + + def __init__(self, vocab_size, embed_width): + super(Embedding, self).__init__() + + self.embedding = nn.Embedding(vocab_size, embed_width) + + def forward(self, input_ids): + return self.embedding(input_ids) diff --git a/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py new file mode 100644 index 00000000..2a3f6cf4 --- /dev/null +++ b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py @@ -0,0 +1,639 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) +# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. + +import os +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import TokenClassifierWithPredictionsOutput +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'] + + +class SequenceLabelingForWordSegmentation(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.model = self.init_model(model_dir, *args, **kwargs) + + model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) + self.model.load_state_dict( + torch.load(model_ckpt, map_location=torch.device('cpu'))) + + def init_model(self, model_dir, *args, **kwargs): + raise NotImplementedError + + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + input_tensor = { + 'input_ids': input['input_ids'], + 'attention_mask': input['attention_mask'], + 'label_mask': input['label_mask'], + } + output = { + 'offset_mapping': input['offset_mapping'], + **input_tensor, + **self.model(input_tensor) + } + return output + + def postprocess(self, input: Dict[str, Any], **kwargs): + predicts = self.model.decode(input) + offset_len = len(input['offset_mapping']) + predictions = torch.narrow( + predicts, 1, 0, + offset_len) # index_select only move loc, not resize + return TokenClassifierWithPredictionsOutput( + loss=None, + logits=None, + hidden_states=None, + attentions=None, + offset_mapping=input['offset_mapping'], + predictions=predictions, + ) + + +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) +class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation): + """This model wraps the TransformerCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + num_labels = self.config.num_labels + + model = TransformerCRF(model_dir, num_labels) + return model + + +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) +class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation): + """This model wraps the LSTMCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + vocab_size = self.config.vocab_size + embed_width = self.config.embed_width + num_labels = self.config.num_labels + lstm_hidden_size = self.config.lstm_hidden_size + + model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) + return model + + +class TransformerCRF(nn.Module): + """A transformer based model to NER tasks. + + This model will use transformers' backbones as its backbone. + """ + + def __init__(self, model_dir, num_labels, **kwargs): + super(TransformerCRF, self).__init__() + + self.encoder = AutoModel.from_pretrained(model_dir) + self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embed = self.encoder( + inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] + logits = self.linear(embed) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + + return predicts + + +class LSTMCRF(nn.Module): + """ + A standard bilstm-crf model for fast prediction. + """ + + def __init__(self, + vocab_size, + embed_width, + num_labels, + lstm_hidden_size=100, + **kwargs): + super(LSTMCRF, self).__init__() + self.embedding = Embedding(vocab_size, embed_width) + self.lstm = nn.LSTM( + embed_width, + lstm_hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embedding = self.embedding(inputs['input_ids']) + lstm_output, _ = self.lstm(embedding) + logits = self.ffn(lstm_output) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + outputs = {'predicts': predicts} + return outputs + + +class CRF(nn.Module): + """Conditional random field. + This module implements a conditional random field [LMP01]_. The forward computation + of this class computes the log likelihood of the given sequence of tags and + emission score tensor. This class also has `~CRF.decode` method which finds + the best tag sequence given an emission score tensor using `Viterbi algorithm`_. + Args: + num_tags: Number of tags. + batch_first: Whether the first dimension corresponds to the size of a minibatch. + Attributes: + start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size + ``(num_tags,)``. + end_transitions (`~torch.nn.Parameter`): End transition score tensor of size + ``(num_tags,)``. + transitions (`~torch.nn.Parameter`): Transition score tensor of size + ``(num_tags, num_tags)``. + .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). + "Conditional random fields: Probabilistic models for segmenting and + labeling sequence data". *Proc. 18th International Conf. on Machine + Learning*. Morgan Kaufmann. pp. 282–289. + .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm + + """ + + def __init__(self, num_tags: int, batch_first: bool = False) -> None: + if num_tags <= 0: + raise ValueError(f'invalid number of tags: {num_tags}') + super().__init__() + self.num_tags = num_tags + self.batch_first = batch_first + self.start_transitions = nn.Parameter(torch.empty(num_tags)) + self.end_transitions = nn.Parameter(torch.empty(num_tags)) + self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Initialize the transition parameters. + The parameters will be initialized randomly from a uniform distribution + between -0.1 and 0.1. + """ + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_tags={self.num_tags})' + + def forward(self, + emissions: torch.Tensor, + tags: torch.LongTensor, + mask: Optional[torch.ByteTensor] = None, + reduction: str = 'mean') -> torch.Tensor: + """Compute the conditional log likelihood of a sequence of tags given emission scores. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + tags (`~torch.LongTensor`): Sequence of tags tensor of size + ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + reduction: Specifies the reduction to apply to the output: + ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. + ``sum``: the output will be summed over batches. ``mean``: the output will be + averaged over batches. ``token_mean``: the output will be averaged over tokens. + Returns: + `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if + reduction is ``none``, ``()`` otherwise. + """ + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError(f'invalid reduction: {reduction}') + if mask is None: + mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, tags=tags, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = numerator - denominator + + if reduction == 'none': + return llh + if reduction == 'sum': + return llh.sum() + if reduction == 'mean': + return llh.mean() + return llh.sum() / mask.float().sum() + + def decode(self, + emissions: torch.Tensor, + mask: Optional[torch.ByteTensor] = None, + nbest: Optional[int] = None, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + """Find the most likely tag sequence using Viterbi algorithm. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + nbest (`int`): Number of most probable paths for each sequence + pad_tag (`int`): Tag at padded positions. Often input varies in length and + the length will be padded to the maximum length in the batch. Tags at + the padded positions will be assigned with a padding tag, i.e. `pad_tag` + Returns: + A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + """ + if nbest is None: + nbest = 1 + if mask is None: + mask = torch.ones( + emissions.shape[:2], + dtype=torch.uint8, + device=emissions.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + mask = mask.transpose(0, 1) + + if nbest == 1: + return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) + return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) + + def _validate(self, + emissions: torch.Tensor, + tags: Optional[torch.LongTensor] = None, + mask: Optional[torch.ByteTensor] = None) -> None: + if emissions.dim() != 3: + raise ValueError( + f'emissions must have dimension of 3, got {emissions.dim()}') + if emissions.size(2) != self.num_tags: + raise ValueError( + f'expected last dimension of emissions is {self.num_tags}, ' + f'got {emissions.size(2)}') + + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + 'the first two dimensions of emissions and tags must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' + ) + + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + 'the first two dimensions of emissions and mask must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' + ) + no_empty_seq = not self.batch_first and mask[0].all() + no_empty_seq_bf = self.batch_first and mask[:, 0].all() + if not no_empty_seq and not no_empty_seq_bf: + raise ValueError('mask of the first timestep must all be on') + + def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + seq_length, batch_size = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[0]] + score += emissions[0, torch.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[i - 1], tags[i]] * mask[i] + + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + # shape: (batch_size,) + last_tags = tags[seq_ends, torch.arange(batch_size)] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer(self, emissions: torch.Tensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + seq_length = emissions.size(0) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode(self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + pad_tag: Optional[int] = None) -> List[List[int]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros((seq_length, batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size), + pad_tag, + dtype=torch.long, + device=device) + + # - score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # - history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + # Broadcast viterbi score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emission = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the score of the best + # tag sequence so far that ends with transitioning from tag i to tag j and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + + # Find the maximum score over all possible current tag + # shape: (batch_size, num_tags) + next_score, indices = next_score.max(dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(-1), next_score, score) + indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) + history_idx[i - 1] = indices + + # End transition score + # shape: (batch_size, num_tags) + end_score = score + self.end_transitions + _, end_tag = end_score.max(dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), + end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size), + dtype=torch.long, + device=device) + best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx], 1, best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size) + + return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) + + def _viterbi_decode_nbest( + self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + nbest: int, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (nbest, batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros( + (seq_length, batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size, nbest), + pad_tag, + dtype=torch.long, + device=device) + + # + score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # + history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + if i == 1: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1) + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + else: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) + # shape: (batch_size, num_tags, nbest, num_tags) + next_score = broadcast_score + self.transitions.unsqueeze( + 1) + broadcast_emission + + # Find the top `nbest` maximum score over all possible current tag + # shape: (batch_size, nbest, num_tags) + next_score, indices = next_score.view(batch_size, -1, + self.num_tags).topk( + nbest, dim=1) + + if i == 1: + score = score.unsqueeze(-1).expand(-1, -1, nbest) + indices = indices * nbest + + # convert to shape: (batch_size, num_tags, nbest) + next_score = next_score.transpose(2, 1) + indices = indices.transpose(2, 1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags, nbest) + score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), + next_score, score) + indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, + oor_idx) + history_idx[i - 1] = indices + + # End transition score shape: (batch_size, num_tags, nbest) + end_score = score + self.end_transitions.unsqueeze(-1) + _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), + end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size, nbest), + dtype=torch.long, + device=device) + best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ + .view(1, -1).expand(batch_size, -1) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, + best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest + + return torch.where(mask.unsqueeze(-1), best_tags_arr, + oor_tag).permute(2, 1, 0) + + +class Embedding(nn.Module): + + def __init__(self, vocab_size, embed_width): + super(Embedding, self).__init__() + + self.embedding = nn.Embedding(vocab_size, embed_width) + + def forward(self, input_ids): + return self.embedding(input_ids) diff --git a/modelscope/models/nlp/task_models/sequence_classification.py b/modelscope/models/nlp/task_models/sequence_classification.py new file mode 100644 index 00000000..6c0c09a2 --- /dev/null +++ b/modelscope/models/nlp/task_models/sequence_classification.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping + +__all__ = ['SequenceClassificationModel'] + + +@MODELS.register_module( + Tasks.text_classification, module_name=TaskModels.text_classification) +class SequenceClassificationModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the sequence classification model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + # get the num_labels from label_mapping.json + self.id2label = {} + # get the num_labels + num_labels = kwargs.get('num_labels') + if num_labels is None: + label2id = parse_label_mapping(model_dir) + if label2id is not None and len(label2id) > 0: + num_labels = len(label2id) + self.id2label = {id: label for label, id in label2id.items()} + self.head_cfg['num_labels'] = num_labels + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + # backbone do not need labels, only head need for loss compute + labels = input.pop(OutputKeys.LABELS, None) + + outputs = super().forward(input) + pooled_output = outputs.pooler_output + outputs = self.head.forward(pooled_output) + if labels is not None: + input[OutputKeys.LABELS] = labels + loss = self.compute_loss(outputs, labels) + outputs.update(loss) + return outputs diff --git a/modelscope/models/nlp/task_models/task_model.py b/modelscope/models/nlp/task_models/task_model.py new file mode 100644 index 00000000..8c83517a --- /dev/null +++ b/modelscope/models/nlp/task_models/task_model.py @@ -0,0 +1,510 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path +import re +from abc import ABC +from collections import OrderedDict +from typing import Any, Dict + +import torch +from torch import nn + +from modelscope.models.base import TorchModel +from modelscope.models.builder import build_backbone, build_head +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +__all__ = ['EncoderDecoderTaskModelBase', 'SingleBackboneTaskModelBase'] + + +def _repr(modules, depth=1): + # model name log level control + if depth == 0: + return modules._get_name() + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = modules.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + + def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + for key, module in modules._modules.items(): + mod_str = _repr(module, depth - 1) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + lines = extra_lines + child_lines + + main_str = modules._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + return main_str + + +class BaseTaskModel(TorchModel, ABC): + """ Base task model interface for nlp + + """ + # keys to ignore when load missing + _keys_to_ignore_on_load_missing = None + # keys to ignore when load unexpected + _keys_to_ignore_on_load_unexpected = None + # backbone prefix, default None + _backbone_prefix = None + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = ConfigDict(kwargs) + + def __repr__(self): + # only log backbone and head name + depth = 1 + return _repr(self, depth) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.get('model_dir') + model = cls(**kwargs) + model.load_checkpoint(model_local_dir=model_dir, **kwargs) + return model + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pass + + def load_checkpoint(self, + model_local_dir, + default_dtype=None, + load_state_fn=None, + **kwargs): + """ + Load model checkpoint file and feed the parameters into the model. + Args: + model_local_dir: The actual checkpoint dir on local disk. + default_dtype: Set the default float type by 'torch.set_default_dtype' + load_state_fn: An optional load_state_fn used to load state_dict into the model. + + Returns: + + """ + # TODO Sharded ckpt + ckpt_file = os.path.join(model_local_dir, 'pytorch_model.bin') + state_dict = torch.load(ckpt_file, map_location='cpu') + if default_dtype is not None: + torch.set_default_dtype(default_dtype) + + missing_keys, unexpected_keys, mismatched_keys, error_msgs = self._load_checkpoint( + state_dict, + load_state_fn=load_state_fn, + ignore_mismatched_sizes=True, + _fast_init=True, + ) + + return { + 'missing_keys': missing_keys, + 'unexpected_keys': unexpected_keys, + 'mismatched_keys': mismatched_keys, + 'error_msgs': error_msgs, + } + + def _load_checkpoint( + self, + state_dict, + load_state_fn, + ignore_mismatched_sizes, + _fast_init, + ): + # Retrieve missing & unexpected_keys + model_state_dict = self.state_dict() + prefix = self._backbone_prefix + + # add head prefix + new_state_dict = OrderedDict() + for name, module in state_dict.items(): + if not name.startswith(prefix) and not name.startswith('head'): + new_state_dict['.'.join(['head', name])] = module + else: + new_state_dict[name] = module + state_dict = new_state_dict + + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + + def _fix_key(key): + if 'beta' in key: + return key.replace('beta', 'bias') + if 'gamma' in key: + return key.replace('gamma', 'weight') + return key + + original_loaded_keys = loaded_keys + loaded_keys = [_fix_key(key) for key in loaded_keys] + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any( + s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + expected_keys_not_prefixed = [ + s for s in expected_keys if not s.startswith(prefix) + ] + expected_keys = [ + '.'.join(s.split('.')[1:]) if s.startswith(prefix) else s + for s in expected_keys + ] + elif add_prefix_to_model: + expected_keys = ['.'.join([prefix, s]) for s in expected_keys] + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + if self._keys_to_ignore_on_load_missing is not None: + for pat in self._keys_to_ignore_on_load_missing: + missing_keys = [ + k for k in missing_keys if re.search(pat, k) is None + ] + + if self._keys_to_ignore_on_load_unexpected is not None: + for pat in self._keys_to_ignore_on_load_unexpected: + unexpected_keys = [ + k for k in unexpected_keys if re.search(pat, k) is None + ] + + if _fast_init: + # retrieve unintialized modules and initialize + uninitialized_modules = self.retrieve_modules_from_names( + missing_keys, + prefix=prefix, + add_prefix=add_prefix_to_model, + remove_prefix=remove_prefix_from_model) + for module in uninitialized_modules: + self._init_weights(module) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = '' + model_to_load = self + if len(prefix) > 0 and not hasattr(self, prefix) and has_prefix_module: + start_prefix = prefix + '.' + if len(prefix) > 0 and hasattr(self, prefix) and not has_prefix_module: + model_to_load = getattr(self, prefix) + if any(key in expected_keys_not_prefixed for key in loaded_keys): + raise ValueError( + 'The state dictionary of the model you are trying to load is corrupted. Are you sure it was ' + 'properly saved?') + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f'{prefix}.{checkpoint_key}' + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = '.'.join(checkpoint_key.split('.')[1:]) + + if (model_key in model_state_dict): + model_shape = model_state_dict[model_key].shape + checkpoint_shape = state_dict[checkpoint_key].shape + if (checkpoint_shape != model_shape): + mismatched_keys.append( + (checkpoint_key, + state_dict[checkpoint_key].shape, + model_state_dict[model_key].shape)) + del state_dict[checkpoint_key] + return mismatched_keys + + def _load_state_dict_into_model(model_to_load, state_dict, + start_prefix): + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + if load_state_fn is not None: + load_state_fn( + model_to_load, + state_dict, + prefix=start_prefix, + local_metadata=None, + error_msgs=error_msgs) + else: + + def load(module: nn.Module, prefix=''): + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], + error_msgs) + module._load_from_state_dict(*args) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model_to_load, prefix=start_prefix) + + return error_msgs + + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, + start_prefix) + + if len(error_msgs) > 0: + error_msg = '\n\t'.join(error_msgs) + raise RuntimeError( + f'Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{error_msg}' + ) + + if len(unexpected_keys) > 0: + logger.warning( + f'Some weights of the model checkpoint were not used when' + f' initializing {self.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are' + f' initializing {self.__class__.__name__} from the checkpoint of a model trained on another task or' + ' with another architecture (e.g. initializing a BertForSequenceClassification model from a' + ' BertForPreTraining model).\n- This IS NOT expected if you are initializing' + f' {self.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical' + ' (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).' + ) + else: + logger.info( + f'All model checkpoint weights were used when initializing {self.__class__.__name__}.\n' + ) + if len(missing_keys) > 0: + logger.warning( + f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' + f' and are newly initialized: {missing_keys}\nYou should probably' + ' TRAIN this model on a down-stream task to be able to use it for predictions and inference.' + ) + elif len(mismatched_keys) == 0: + logger.info( + f'All the weights of {self.__class__.__name__} were initialized from the model checkpoint ' + f'If your task is similar to the task the model of the checkpoint' + f' was trained on, you can already use {self.__class__.__name__} for predictions without further' + ' training.') + if len(mismatched_keys) > 0: + mismatched_warning = '\n'.join([ + f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated' + for key, shape1, shape2 in mismatched_keys + ]) + logger.warning( + f'Some weights of {self.__class__.__name__} were not initialized from the model checkpoint' + f' and are newly initialized because the shapes did not' + f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able' + ' to use it for predictions and inference.') + + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def retrieve_modules_from_names(self, + names, + prefix=None, + add_prefix=False, + remove_prefix=False): + module_keys = set(['.'.join(key.split('.')[:-1]) for key in names]) + + # torch.nn.ParameterList is a special case where two parameter keywords + # are appended to the module name, *e.g.* bert.special_embeddings.0 + module_keys = module_keys.union( + set([ + '.'.join(key.split('.')[:-2]) for key in names + if key[-1].isdigit() + ])) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + name = '.'.join( + name.split('.')[1:]) if name.startswith(prefix) else name + elif add_prefix: + name = '.'.join([prefix, name]) if len(name) > 0 else prefix + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + + +class SingleBackboneTaskModelBase(BaseTaskModel): + """ + This is the base class of any single backbone nlp task classes. + """ + # The backbone prefix defaults to "bert" + _backbone_prefix = 'bert' + + # The head prefix defaults to "head" + _head_prefix = 'head' + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.backbone_cfg = self.config.get('backbone', None) + assert self.backbone_cfg is not None + self.head_cfg = self.config.get('head', None) + + def build_backbone(self, cfg): + if 'prefix' in cfg: + self._backbone_prefix = cfg['prefix'] + backbone = build_backbone(cfg) + setattr(self, cfg['prefix'], backbone) + + def build_head(self, cfg): + if cfg is None: + raise ValueError( + 'Head config is missing, check if this was a backbone-only model' + ) + if 'prefix' in cfg: + self._head_prefix = cfg['prefix'] + head = build_head(cfg, task_name=self.group_key) + setattr(self, self._head_prefix, head) + return head + + @property + def backbone(self): + if 'backbone' != self._backbone_prefix: + return getattr(self, self._backbone_prefix) + return super().__getattr__('backbone') + + @property + def head(self): + if 'head' != self._head_prefix: + return getattr(self, self._head_prefix) + return super().__getattr__('head') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """default forward method is the backbone-only forward""" + if func_receive_dict_inputs(self.backbone.forward): + outputs = self.backbone.forward(input) + else: + outputs = self.backbone.forward(**input) + return outputs + + def compute_loss(self, outputs, labels): + loss = self.head.compute_loss(outputs, labels) + return loss + + def extract_backbone_outputs(self, outputs): + sequence_output = None + pooled_output = None + if hasattr(self.backbone, 'extract_sequence_outputs'): + sequence_output = self.backbone.extract_sequence_outputs(outputs) + if hasattr(self.backbone, 'extract_pooled_outputs'): + pooled_output = self.backbone.extract_pooled_outputs(outputs) + return sequence_output, pooled_output + + +class EncoderDecoderTaskModelBase(BaseTaskModel): + """ + This is the base class of encoder-decoder nlp task classes. + """ + # The encoder backbone prefix, default to "encoder" + _encoder_prefix = 'encoder' + # The decoder backbone prefix, default to "decoder" + _decoder_prefix = 'decoder' + # The key in cfg specifing the encoder type + _encoder_key_in_cfg = 'encoder_type' + # The key in cfg specifing the decoder type + _decoder_key_in_cfg = 'decoder_type' + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + + def build_encoder(self): + encoder = build_backbone( + self.config, + type_name=self._encoder_key_in_cfg, + task_name=Tasks.backbone) + setattr(self, self._encoder_prefix, encoder) + return encoder + + def build_decoder(self): + decoder = build_backbone( + self.config, + type_name=self._decoder_key_in_cfg, + task_name=Tasks.backbone) + setattr(self, self._decoder_prefix, decoder) + return decoder + + @property + def encoder_(self): + return getattr(self, self._encoder_prefix) + + @property + def decoder_(self): + return getattr(self, self._decoder_prefix) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if func_receive_dict_inputs(self.encoder_.forward): + encoder_outputs = self.encoder_.forward(input) + else: + encoder_outputs = self.encoder_.forward(**input) + decoder_inputs = self.project_decoder_inputs_and_mediate( + input, encoder_outputs) + if func_receive_dict_inputs(self.decoder_.forward): + outputs = self.decoder_.forward(decoder_inputs) + else: + outputs = self.decoder_.forward(**decoder_inputs) + + return outputs + + def project_decoder_inputs_and_mediate(self, input, encoder_outputs): + return {**input, **encoder_outputs} diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py new file mode 100644 index 00000000..b886f124 --- /dev/null +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -0,0 +1,88 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np +from transformers.modeling_utils import GenerationMixin + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import (OutputKeys, TextGenerationModelOutput, + TokenGeneratorOutput) +from modelscope.utils.constant import Tasks + +__all__ = ['TaskModelForTextGeneration'] + + +@MODELS.register_module( + Tasks.text_generation, module_name=TaskModels.text_generation) +class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin): + main_input_name = 'input_ids' + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + if self.config.get('shared_embedding', False): + input_embeddings = self.backbone.get_input_embeddings() + output_embeddings = self.head.get_output_embeddings() + output_embeddings.weight = input_embeddings.weight + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + # backbone do not need labels, only head need for loss compute + labels = input.pop(OutputKeys.LABELS, None) + + backbone_outputs = super().forward(input) + hidden_states = backbone_outputs[0] + + logits = self.head.forward(hidden_states) + loss = None + if labels is not None: + input[OutputKeys.LABELS] = labels + loss = self.compute_loss(logits, labels) + return TextGenerationModelOutput(logits=logits, loss=loss) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get('attention_mask', None) + position_ids = kwargs.get('position_ids', None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + 'input_ids': input_ids, + 'past_key_values': past, + 'use_cache': kwargs.get('use_cache'), + 'position_ids': position_ids, + 'attention_mask': attention_mask, + } + + def generate(self, inputs, *args, **kwargs): + input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs + generate_output = super().generate(input_ids, *args, **kwargs) + if isinstance(generate_output, Dict): + return TokenGeneratorOutput( + sequences=generate_output.sequences, + scores=generate_output.scores, + attentions=generate_output.attentions, + hidden_states=generate_output.hidden_states) + else: + return TokenGeneratorOutput(sequences=generate_output) diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py new file mode 100644 index 00000000..8b523baf --- /dev/null +++ b/modelscope/models/nlp/task_models/token_classification.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np +import torch + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys, TokenClassifierOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) + +__all__ = ['TokenClassificationModel'] + + +@MODELS.register_module( + Tasks.token_classification, module_name=TaskModels.token_classification) +@MODELS.register_module( + Tasks.part_of_speech, module_name=TaskModels.token_classification) +class TokenClassificationModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the token classification model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + # get the num_labels + num_labels = kwargs.get('num_labels') + if num_labels is None: + label2id = parse_label_mapping(model_dir) + if label2id is not None and len(label2id) > 0: + num_labels = len(label2id) + self.id2label = {id: label for label, id in label2id.items()} + self.head_cfg['num_labels'] = num_labels + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + labels = None + if OutputKeys.LABEL in input: + labels = input.pop(OutputKeys.LABEL) + elif OutputKeys.LABELS in input: + labels = input.pop(OutputKeys.LABELS) + + outputs = super().forward(input) + sequence_output = outputs[0] + logits = self.head.forward(sequence_output) + loss = None + if labels in input: + loss = self.compute_loss(outputs, labels) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + offset_mapping=input['offset_mapping'], + ) + + def extract_logits(self, outputs): + return outputs[OutputKeys.LOGITS].cpu().detach() diff --git a/modelscope/models/nlp/veco/__init__.py b/modelscope/models/nlp/veco/__init__.py new file mode 100644 index 00000000..0774e9b4 --- /dev/null +++ b/modelscope/models/nlp/veco/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import VecoConfig + from .backbone import VecoModel + from .text_classification import VecoForSequenceClassification + from .token_classification import VecoForTokenClassification + from .fill_mask import VecoForMaskedLM + from .tokenization import VecoTokenizer + from .tokenization_fast import VecoTokenizerFast +else: + _import_structure = { + 'configuration': ['VecoConfig'], + 'backbone': ['VecoModel'], + 'text_classification': ['VecoForSequenceClassification'], + 'fill_mask': ['VecoForMaskedLM'], + 'token_classification': ['VecoForTokenClassification'], + 'tokenization': ['VecoTokenizer'], + 'tokenization_fast': ['VecoTokenizerFast'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/veco/backbone.py b/modelscope/models/nlp/veco/backbone.py new file mode 100644 index 00000000..98d8c30a --- /dev/null +++ b/modelscope/models/nlp/veco/backbone.py @@ -0,0 +1,96 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Veco model. mainly copied from :module:`~transformers.modeling_xlm_roberta`""" + +from transformers import RobertaModel + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionBackboneModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .configuration import VecoConfig + +logger = logging.get_logger(__name__) + +VECO_PRETRAINED_MODEL_ARCHIVE_LIST = [] + + +@MODELS.register_module(Tasks.backbone, module_name=Models.veco) +class VecoModel(TorchModel, RobertaModel): + """The bare Veco Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaModel`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large', task='backbone') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> print(model(**preprocessor('这是个测试'))) + + """ + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionBackboneModelOutput( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = VecoConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model diff --git a/modelscope/models/nlp/veco/configuration.py b/modelscope/models/nlp/veco/configuration.py new file mode 100644 index 00000000..396755dc --- /dev/null +++ b/modelscope/models/nlp/veco/configuration.py @@ -0,0 +1,33 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors. +# Copyright 2020 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Veco configuration, mainly copied from :class:`~transformers.configuration_xlm_roberta` """ + +from transformers import RobertaConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class VecoConfig(RobertaConfig): + """ + This class overrides [`RobertaConfig`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + model_type = 'veco' diff --git a/modelscope/models/nlp/veco/fill_mask.py b/modelscope/models/nlp/veco/fill_mask.py new file mode 100644 index 00000000..de2cdb4a --- /dev/null +++ b/modelscope/models/nlp/veco/fill_mask.py @@ -0,0 +1,99 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForMaskedLM + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionFillMaskModelOutput +from modelscope.utils.constant import Tasks +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) +class VecoForMaskedLM(TorchModel, RobertaForMaskedLM): + """Veco Model transformer with a masked language model head on top (a linear layer on top of the + pooled output). + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the fill_mask model of StructBERT, the preprocessor of this model + is `modelscope.preprocessors.NLPPreprocessor`. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForMaskedLM`]. Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionFillMaskModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large') + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('你师父差得动你,你师父可不动我。'))) + >>> # Call the pipeline + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('你师父差得动你,你师父可不动我。')) + """ + + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionFillMaskModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + input_ids=kwargs['input_ids'], + ) + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + ponet_config = VecoConfig(**kwargs) + model = cls(ponet_config) + else: + model = super( + Model, + cls).from_pretrained(pretrained_model_name_or_path=model_dir) + return model diff --git a/modelscope/models/nlp/veco/text_classification.py b/modelscope/models/nlp/veco/text_classification.py new file mode 100644 index 00000000..e4e74d8f --- /dev/null +++ b/modelscope/models/nlp/veco/text_classification.py @@ -0,0 +1,150 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForSequenceClassification + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.nli, module_name=Models.veco) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.veco) +@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.veco) +@MODELS.register_module(Tasks.text_classification, module_name=Models.veco) +class VecoForSequenceClassification(TorchModel, + RobertaForSequenceClassification): + """Veco Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Preprocessor: + This is the text classification model of Veco, the preprocessor of this model + is `modelscope.preprocessors.SequenceClassificationPreprocessor`. + + Trainer: + This model should be trained by dataset which has mixed languages, + and evaluated by datasets of languages one by one. + For example, if the training dataset is xnli (which has sub datasets of multiple languages), then you + should mix the sub-datasets with the languages you want to train to one training dataset, and evaluate + the model one sub-dataset by one sub-dataset of different languages. + This procedure can be done by custom code. If you are using trainer of ModelScope, + the `VecoTrainer` is suggested to use to train this model. This trainer overrides the basic evaluation + loop, and will call the evaluation dataset one by one. Besides, this trainer will use the `VecoTaskDataset` + to mix the input datasets to one, you can check the API Doc for the details. + + To check the complete example please + view the unittest `test_veco_xnli` in `tests.trainers.test_finetune_sequence_classification.py` + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForSequenceClassification`]. Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + """ + Returns: + Returns `modelscope.outputs.AttentionTextClassificationModelOutput` + + Examples: + >>> from modelscope.models import Model + >>> from modelscope.preprocessors import Preprocessor + >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large', + >>> task='text-classification', num_labels=2) + >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large', + >>> label2id={'0': 0, '1': 1}) + >>> # Call the model, return some tensors + >>> print(model(**preprocessor('这是个测试'))) + >>> # Call the pipeline, the result may be incorrect + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', pipeline_name='text-classification', + >>> model=model, preprocessor=preprocessor) + >>> print(pipeline_ins('这是个测试')) + """ + + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionTextClassificationModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = VecoConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model diff --git a/modelscope/models/nlp/veco/token_classification.py b/modelscope/models/nlp/veco/token_classification.py new file mode 100644 index 00000000..f6252209 --- /dev/null +++ b/modelscope/models/nlp/veco/token_classification.py @@ -0,0 +1,107 @@ +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import RobertaForTokenClassification + +from modelscope.metainfo import Models +from modelscope.models import Model, TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTokenClassificationModelOutput +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from .configuration import VecoConfig + + +@MODELS.register_module(Tasks.token_classification, module_name=Models.veco) +class VecoForTokenClassification(TorchModel, RobertaForTokenClassification): + """Veco Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`VecoConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. + + This class overrides [`RobertaForTokenClassification`]. Please check the superclass for the + appropriate documentation alongside usage examples. + """ + + config_class = VecoConfig + + def __init__(self, config, **kwargs): + super().__init__(config.name_or_path, **kwargs) + super(Model, self).__init__(config) + + def forward(self, *args, **kwargs): + kwargs['return_dict'] = True + outputs = super(Model, self).forward(*args, **kwargs) + return AttentionTokenClassificationModelOutput( + loss=outputs.loss, + logits=outputs.logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + if model_dir is None: + config = VecoConfig(**kwargs) + model = cls(config) + else: + model_kwargs = {} + label2id = kwargs.get('label2id', parse_label_mapping(model_dir)) + id2label = kwargs.get( + 'id2label', None if label2id is None else + {id: label + for label, id in label2id.items()}) + if id2label is not None and label2id is None: + label2id = {label: id for id, label in id2label.items()} + + num_labels = kwargs.get( + 'num_labels', None if label2id is None else len(label2id)) + if num_labels is not None: + model_kwargs['num_labels'] = num_labels + if label2id is not None: + model_kwargs['label2id'] = label2id + if id2label is not None: + model_kwargs['id2label'] = id2label + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_kwargs) + return model diff --git a/modelscope/models/nlp/veco/tokenization.py b/modelscope/models/nlp/veco/tokenization.py new file mode 100644 index 00000000..21711456 --- /dev/null +++ b/modelscope/models/nlp/veco/tokenization.py @@ -0,0 +1,321 @@ +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for Veco. mainly copied from :module:`~transformers.tokenization_xlm_roberta`""" + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = '▁' + +VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + + +class VecoTokenizer(PreTrainedTokenizer): + """ + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of + sequence. The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + additional_special_tokens (`List[str]`, *optional*, defaults to `["NOTUSED", "NOTUSED"]`): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. + The [Python wrapper for SentencePiece](https://github.com/google/sentencepiece/tree/master/python) + can be used, among other things, to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ['input_ids', 'attention_mask'] + + def __init__(self, + vocab_file, + bos_token='', + eos_token='', + sep_token='', + cls_token='', + unk_token='', + pad_token='', + mask_token='', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs) -> None: + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken( + mask_token, lstrip=True, rstrip=False) if isinstance( + mask_token, str) else mask_token + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ---- + # fairseq | '' | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' + # spm | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' + + # Mimic fairseq token-to-id alignment for the first 4 token + self.fairseq_tokens_to_ids = { + '': 0, + '': 1, + '': 2, + '': 3 + } + + # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.fairseq_tokens_to_ids[''] = len( + self.sp_model) + self.fairseq_offset + self.fairseq_ids_to_tokens = { + v: k + for k, v in self.fairseq_tokens_to_ids.items() + } + + def __getstate__(self): + state = self.__dict__.copy() + state['sp_model'] = None + state['sp_model_proto'] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, 'sp_model_kwargs'): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An Veco sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ( + [0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. Veco does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + def vocab_size(self): + return len( + self.sp_model) + self.fairseq_offset + 1 # Add the token + + def get_vocab(self): + vocab = { + self.convert_ids_to_tokens(i): i + for i in range(self.vocab_size) + } + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() + return out_string + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error( + f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file, ) diff --git a/modelscope/models/nlp/veco/tokenization_fast.py b/modelscope/models/nlp/veco/tokenization_fast.py new file mode 100644 index 00000000..b41a5c3b --- /dev/null +++ b/modelscope/models/nlp/veco/tokenization_fast.py @@ -0,0 +1,213 @@ +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Fast Tokenization classes for Veco. mainly copied from :module:`~transformers.tokenization_xlm_roberta_fast`""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +import transformers +from transformers.file_utils import is_sentencepiece_available +from transformers.tokenization_utils import AddedToken +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modelscope.utils import logger as logging + +if is_sentencepiece_available(): + from .tokenization import VecoTokenizer +else: + VecoTokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': 'sentencepiece.bpe.model', + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': {}, + 'tokenizer_file': {}, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + +transformers.SLOW_TO_FAST_CONVERTERS[ + 'VecoTokenizer'] = transformers.SLOW_TO_FAST_CONVERTERS[ + 'XLMRobertaTokenizer'] + + +class VecoTokenizerFast(PreTrainedTokenizerFast): + """ + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. + Based on [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of + sequence. The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + additional_special_tokens (`List[str]`, *optional*, defaults to `["NOTUSED", "NOTUSED"]`): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ['input_ids', 'attention_mask'] + slow_tokenizer_class = VecoTokenizer + + def __init__(self, + vocab_file=None, + tokenizer_file=None, + bos_token='', + eos_token='', + sep_token='', + cls_token='', + unk_token='', + pad_token='', + mask_token='', + **kwargs): + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken( + mask_token, lstrip=True, rstrip=False) if isinstance( + mask_token, str) else mask_token + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + self.vocab_file = vocab_file + self.can_save_slow_tokenizer = False if not self.vocab_file else True + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An Veco sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. Veco does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' + 'tokenizer.') + + if not os.path.isdir(save_directory): + logger.error( + f'Vocabulary path ({save_directory}) should be a directory.') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file, ) diff --git a/modelscope/models/science/__init__.py b/modelscope/models/science/__init__.py new file mode 100644 index 00000000..50ab55d7 --- /dev/null +++ b/modelscope/models/science/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .unifold import UnifoldForProteinStructrue + +else: + _import_structure = {'unifold': ['UnifoldForProteinStructrue']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/science/unifold/__init__.py b/modelscope/models/science/unifold/__init__.py new file mode 100644 index 00000000..75435fed --- /dev/null +++ b/modelscope/models/science/unifold/__init__.py @@ -0,0 +1 @@ +from .model import UnifoldForProteinStructrue diff --git a/modelscope/models/science/unifold/config.py b/modelscope/models/science/unifold/config.py new file mode 100644 index 00000000..e760fbf9 --- /dev/null +++ b/modelscope/models/science/unifold/config.py @@ -0,0 +1,636 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy +from typing import Any + +import ml_collections as mlc + +N_RES = 'number of residues' +N_MSA = 'number of MSA sequences' +N_EXTRA_MSA = 'number of extra MSA sequences' +N_TPL = 'number of templates' + +d_pair = mlc.FieldReference(128, field_type=int) +d_msa = mlc.FieldReference(256, field_type=int) +d_template = mlc.FieldReference(64, field_type=int) +d_extra_msa = mlc.FieldReference(64, field_type=int) +d_single = mlc.FieldReference(384, field_type=int) +max_recycling_iters = mlc.FieldReference(3, field_type=int) +chunk_size = mlc.FieldReference(4, field_type=int) +aux_distogram_bins = mlc.FieldReference(64, field_type=int) +eps = mlc.FieldReference(1e-8, field_type=float) +inf = mlc.FieldReference(3e4, field_type=float) +use_templates = mlc.FieldReference(True, field_type=bool) +is_multimer = mlc.FieldReference(False, field_type=bool) + + +def base_config(): + return mlc.ConfigDict({ + 'data': { + 'common': { + 'features': { + 'aatype': [N_RES], + 'all_atom_mask': [N_RES, None], + 'all_atom_positions': [N_RES, None, None], + 'alt_chi_angles': [N_RES, None], + 'atom14_alt_gt_exists': [N_RES, None], + 'atom14_alt_gt_positions': [N_RES, None, None], + 'atom14_atom_exists': [N_RES, None], + 'atom14_atom_is_ambiguous': [N_RES, None], + 'atom14_gt_exists': [N_RES, None], + 'atom14_gt_positions': [N_RES, None, None], + 'atom37_atom_exists': [N_RES, None], + 'frame_mask': [N_RES], + 'true_frame_tensor': [N_RES, None, None], + 'bert_mask': [N_MSA, N_RES], + 'chi_angles_sin_cos': [N_RES, None, None], + 'chi_mask': [N_RES, None], + 'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES], + 'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES], + 'extra_msa': [N_EXTRA_MSA, N_RES], + 'extra_msa_mask': [N_EXTRA_MSA, N_RES], + 'extra_msa_row_mask': [N_EXTRA_MSA], + 'is_distillation': [], + 'msa_feat': [N_MSA, N_RES, None], + 'msa_mask': [N_MSA, N_RES], + 'msa_chains': [N_MSA, None], + 'msa_row_mask': [N_MSA], + 'num_recycling_iters': [], + 'pseudo_beta': [N_RES, None], + 'pseudo_beta_mask': [N_RES], + 'residue_index': [N_RES], + 'residx_atom14_to_atom37': [N_RES, None], + 'residx_atom37_to_atom14': [N_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [N_RES, None, None, None], + 'rigidgroups_group_exists': [N_RES, None], + 'rigidgroups_group_is_ambiguous': [N_RES, None], + 'rigidgroups_gt_exists': [N_RES, None], + 'rigidgroups_gt_frames': [N_RES, None, None, None], + 'seq_length': [], + 'seq_mask': [N_RES], + 'target_feat': [N_RES, None], + 'template_aatype': [N_TPL, N_RES], + 'template_all_atom_mask': [N_TPL, N_RES, None], + 'template_all_atom_positions': [N_TPL, N_RES, None, None], + 'template_alt_torsion_angles_sin_cos': [ + N_TPL, + N_RES, + None, + None, + ], + 'template_frame_mask': [N_TPL, N_RES], + 'template_frame_tensor': [N_TPL, N_RES, None, None], + 'template_mask': [N_TPL], + 'template_pseudo_beta': [N_TPL, N_RES, None], + 'template_pseudo_beta_mask': [N_TPL, N_RES], + 'template_sum_probs': [N_TPL, None], + 'template_torsion_angles_mask': [N_TPL, N_RES, None], + 'template_torsion_angles_sin_cos': + [N_TPL, N_RES, None, None], + 'true_msa': [N_MSA, N_RES], + 'use_clamped_fape': [], + 'assembly_num_chains': [1], + 'asym_id': [N_RES], + 'sym_id': [N_RES], + 'entity_id': [N_RES], + 'num_sym': [N_RES], + 'asym_len': [None], + 'cluster_bias_mask': [N_MSA], + }, + 'masked_msa': { + 'profile_prob': 0.1, + 'same_prob': 0.1, + 'uniform_prob': 0.1, + }, + 'block_delete_msa': { + 'msa_fraction_per_block': 0.3, + 'randomize_num_blocks': False, + 'num_blocks': 5, + 'min_num_msa': 16, + }, + 'random_delete_msa': { + 'max_msa_entry': 1 << 25, # := 33554432 + }, + 'v2_feature': + False, + 'gumbel_sample': + False, + 'max_extra_msa': + 1024, + 'msa_cluster_features': + True, + 'reduce_msa_clusters_by_max_templates': + True, + 'resample_msa_in_recycling': + True, + 'template_features': [ + 'template_all_atom_positions', + 'template_sum_probs', + 'template_aatype', + 'template_all_atom_mask', + ], + 'unsupervised_features': [ + 'aatype', + 'residue_index', + 'msa', + 'msa_chains', + 'num_alignments', + 'seq_length', + 'between_segment_residues', + 'deletion_matrix', + 'num_recycling_iters', + 'crop_and_fix_size_seed', + ], + 'recycling_features': [ + 'msa_chains', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_feat', + 'extra_msa_deletion_value', + 'extra_msa_has_deletion', + 'extra_msa', + 'extra_msa_mask', + 'extra_msa_row_mask', + 'is_distillation', + ], + 'multimer_features': [ + 'assembly_num_chains', + 'asym_id', + 'sym_id', + 'num_sym', + 'entity_id', + 'asym_len', + 'cluster_bias_mask', + ], + 'use_templates': + use_templates, + 'is_multimer': + is_multimer, + 'use_template_torsion_angles': + use_templates, + 'max_recycling_iters': + max_recycling_iters, + }, + 'supervised': { + 'use_clamped_fape_prob': + 1.0, + 'supervised_features': [ + 'all_atom_mask', + 'all_atom_positions', + 'resolution', + 'use_clamped_fape', + 'is_distillation', + ], + }, + 'predict': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 2, + 'crop': False, + 'crop_size': None, + 'supervised': False, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'eval': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': False, + 'crop_size': None, + 'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'train': { + 'fixed_size': True, + 'subsample_templates': True, + 'block_delete_msa': True, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': True, + 'crop_size': 256, + 'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'use_clamped_fape_prob': 1.0, + 'max_distillation_msa_clusters': 1000, + 'biased_msa_by_chain': True, + 'share_mask': True, + }, + }, + 'globals': { + 'chunk_size': chunk_size, + 'block_size': None, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'd_template': d_template, + 'd_extra_msa': d_extra_msa, + 'd_single': d_single, + 'eps': eps, + 'inf': inf, + 'max_recycling_iters': max_recycling_iters, + 'alphafold_original_mode': False, + }, + 'model': { + 'is_multimer': is_multimer, + 'input_embedder': { + 'tf_dim': 22, + 'msa_dim': 49, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'relpos_k': 32, + 'max_relative_chain': 2, + }, + 'recycling_embedder': { + 'd_pair': d_pair, + 'd_msa': d_msa, + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15, + 'inf': 1e8, + }, + 'template': { + 'distogram': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39, + }, + 'template_angle_embedder': { + 'd_in': 57, + 'd_out': d_msa, + }, + 'template_pair_embedder': { + 'd_in': 88, + 'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1], + 'd_pair': d_pair, + 'd_out': d_template, + 'v2_feature': False, + }, + 'template_pair_stack': { + 'd_template': d_template, + 'd_hid_tri_att': 16, + 'd_hid_tri_mul': 64, + 'num_blocks': 2, + 'num_heads': 4, + 'pair_transition_n': 2, + 'dropout_rate': 0.25, + 'inf': 1e9, + 'tri_attn_first': True, + }, + 'template_pointwise_attention': { + 'enabled': True, + 'd_template': d_template, + 'd_pair': d_pair, + 'd_hid': 16, + 'num_heads': 4, + 'inf': 1e5, + }, + 'inf': 1e5, + 'eps': 1e-6, + 'enabled': use_templates, + 'embed_angles': use_templates, + }, + 'extra_msa': { + 'extra_msa_embedder': { + 'd_in': 25, + 'd_out': d_extra_msa, + }, + 'extra_msa_stack': { + 'd_msa': d_extra_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 8, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 4, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'enabled': True, + }, + 'evoformer_stack': { + 'd_msa': d_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 32, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'd_single': d_single, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 48, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'structure_module': { + 'd_single': d_single, + 'd_pair': d_pair, + 'd_ipa': 16, + 'd_angle': 128, + 'num_heads_ipa': 12, + 'num_qk_points': 4, + 'num_v_points': 8, + 'dropout_rate': 0.1, + 'num_blocks': 8, + 'no_transition_layers': 1, + 'num_resnet_blocks': 2, + 'num_angles': 7, + 'trans_scale_factor': 10, + 'epsilon': 1e-12, + 'inf': 1e5, + 'separate_kv': False, + 'ipa_bias': True, + }, + 'heads': { + 'plddt': { + 'num_bins': 50, + 'd_in': d_single, + 'd_hid': 128, + }, + 'distogram': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'disable_enhance_head': False, + }, + 'pae': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'enabled': False, + 'iptm_weight': 0.8, + 'disable_enhance_head': False, + }, + 'masked_msa': { + 'd_msa': d_msa, + 'd_out': 23, + 'disable_enhance_head': False, + }, + 'experimentally_resolved': { + 'd_single': d_single, + 'd_out': 37, + 'enabled': False, + 'disable_enhance_head': False, + }, + }, + }, + 'loss': { + 'distogram': { + 'min_bin': 2.3125, + 'max_bin': 21.6875, + 'num_bins': 64, + 'eps': 1e-6, + 'weight': 0.3, + }, + 'experimentally_resolved': { + 'eps': 1e-8, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'fape': { + 'backbone': { + 'clamp_distance': 10.0, + 'clamp_distance_between_chains': 30.0, + 'loss_unit_distance': 10.0, + 'loss_unit_distance_between_chains': 20.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'sidechain': { + 'clamp_distance': 10.0, + 'length_scale': 10.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'weight': 1.0, + }, + 'plddt': { + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'cutoff': 15.0, + 'num_bins': 50, + 'eps': 1e-10, + 'weight': 0.01, + }, + 'masked_msa': { + 'eps': 1e-8, + 'weight': 2.0, + }, + 'supervised_chi': { + 'chi_weight': 0.5, + 'angle_norm_weight': 0.01, + 'eps': 1e-6, + 'weight': 1.0, + }, + 'violation': { + 'violation_tolerance_factor': 12.0, + 'clash_overlap_tolerance': 1.5, + 'bond_angle_loss_weight': 0.3, + 'eps': 1e-6, + 'weight': 0.0, + }, + 'pae': { + 'max_bin': 31, + 'num_bins': 64, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'eps': 1e-8, + 'weight': 0.0, + }, + 'repr_norm': { + 'weight': 0.01, + 'tolerance': 1.0, + }, + 'chain_centre_mass': { + 'weight': 0.0, + 'eps': 1e-8, + }, + }, + }) + + +def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None): + with c.unlocked(): + for k, v in c.items(): + if ignore is not None and k == ignore: + continue + if isinstance(v, mlc.ConfigDict): + recursive_set(v, key, value) + elif k == key: + c[k] = value + + +def model_config(name, train=False): + c = copy.deepcopy(base_config()) + + def model_2_v2(c): + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.template.template_angle_embedder.d_in = 34 + return c + + def multimer(c): + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 128) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + # we forget to enable it in our training, so disable it here + c.model.heads.pae.disable_enhance_head = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.model.input_embedder.tf_dim = 21 + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.chain_centre_mass.weight = 1.0 + return c + + if name == 'model_1': + pass + elif name == 'model_1_ft': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_1_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2': + pass + elif name == 'model_init': + pass + elif name == 'model_init_af2': + c.globals.alphafold_original_mode = True + pass + elif name == 'model_2_ft': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_2_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2_v2': + c = model_2_v2(c) + elif name == 'model_2_v2_ft': + c = model_2_v2(c) + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_3_af2' or name == 'model_4_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'model_5_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'multimer': + c = multimer(c) + elif name == 'multimer_ft': + c = multimer(c) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.5 + elif name == 'multimer_af2': + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + c.model.heads.experimentally_resolved.enabled = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.loss.violation.weight = 0.5 + c.loss.experimentally_resolved.weight = 0.01 + c.model.input_embedder.tf_dim = 21 + c.globals.alphafold_original_mode = True + c.data.train.crop_size = 384 + c.loss.repr_norm.weight = 0 + c.loss.chain_centre_mass.weight = 1.0 + recursive_set(c, 'outer_product_mean_first', True) + else: + raise ValueError(f'invalid --model-name: {name}.') + if train: + c.globals.chunk_size = None + recursive_set(c, 'inf', 3e4) + recursive_set(c, 'eps', 1e-5, 'loss') + return c diff --git a/modelscope/models/science/unifold/data/__init__.py b/modelscope/models/science/unifold/data/__init__.py new file mode 100644 index 00000000..9821d212 --- /dev/null +++ b/modelscope/models/science/unifold/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data pipeline for model features.""" diff --git a/modelscope/models/science/unifold/data/data_ops.py b/modelscope/models/science/unifold/data/data_ops.py new file mode 100644 index 00000000..637aa0cd --- /dev/null +++ b/modelscope/models/science/unifold/data/data_ops.py @@ -0,0 +1,1397 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import itertools +from functools import reduce, wraps +from operator import add +from typing import List, MutableMapping, Optional + +import numpy as np +import torch +from unicore.data import data_utils +from unicore.utils import batched_gather, one_hot, tensor_tree_map, tree_map + +from modelscope.models.science.unifold.config import (N_EXTRA_MSA, N_MSA, + N_RES, N_TPL) +from modelscope.models.science.unifold.data import residue_constants as rc +from modelscope.models.science.unifold.modules.frame import Frame, Rotation + +NumpyDict = MutableMapping[str, np.ndarray] +TorchDict = MutableMapping[str, np.ndarray] + +protein: TorchDict + +MSA_FEATURE_NAMES = [ + 'msa', + 'deletion_matrix', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_chains', +] + + +def cast_to_64bit_ints(protein): + # We keep all ints as int64 + for k, v in protein.items(): + if k.endswith('_mask'): + protein[k] = v.type(torch.float32) + elif v.dtype in (torch.int32, torch.uint8, torch.int8): + protein[k] = v.type(torch.int64) + + return protein + + +def make_seq_mask(protein): + protein['seq_mask'] = torch.ones( + protein['aatype'].shape, dtype=torch.float32) + return protein + + +def make_template_mask(protein): + protein['template_mask'] = torch.ones( + protein['template_aatype'].shape[0], dtype=torch.float32) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + + @wraps(f) + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as rc.""" + protein['msa'] = protein['msa'].long() + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = ( + torch.tensor(new_order_list, dtype=torch.int8).unsqueeze(-1).expand( + -1, protein['msa'].shape[1])) + protein['msa'] = torch.gather(new_order, 0, protein['msa']).long() + + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + if len(protein['aatype'].shape) == 2: + protein['aatype'] = torch.argmax(protein['aatype'], dim=-1) + if 'resolution' in protein and len(protein['resolution'].shape) == 1: + # use tensor for resolution + protein['resolution'] = protein['resolution'][0] + for k in [ + 'domain_name', + 'msa', + 'num_alignments', + 'seq_length', + 'sequence', + 'superfamily', + 'deletion_matrix', + 'between_segment_residues', + 'residue_index', + 'template_all_atom_mask', + ]: + if k in protein and len(protein[k].shape): + final_dim = protein[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + if torch.is_tensor(protein[k]): + protein[k] = torch.squeeze(protein[k], dim=-1) + else: + protein[k] = np.squeeze(protein[k], axis=-1) + + for k in ['seq_length', 'num_alignments']: + if k in protein and len(protein[k].shape): + protein[k] = protein[k][0] + + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a portion of the MSA with 'X'.""" + if replace_proportion > 0.0: + msa_mask = np.random.rand(protein['msa'].shape) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx) + protein['msa'] = torch.where(msa_mask, + torch.ones_like(protein['msa']) * x_idx, + protein['msa']) + aatype_mask = np.random.rand( + protein['aatype'].shape) < replace_proportion + + protein['aatype'] = torch.where( + aatype_mask, + torch.ones_like(protein['aatype']) * x_idx, + protein['aatype'], + ) + return protein + + +def gumbel_noise(shape): + """Generate Gumbel Noise of given Shape. + This generates samples from Gumbel(0, 1). + Args: + shape: Shape of noise to return. + Returns: + Gumbel noise of given shape. + """ + epsilon = 1e-6 + uniform_noise = torch.from_numpy(np.random.uniform(0, 1, shape)) + gumbel = -torch.log(-torch.log(uniform_noise + epsilon) + epsilon) + return gumbel + + +def gumbel_max_sample(logits): + """Samples from a probability distribution given by 'logits'. + This uses Gumbel-max trick to implement the sampling in an efficient manner. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + Returns: + Sample from logprobs in one-hot form. + """ + z = gumbel_noise(logits.shape) + return torch.argmax(logits + z, dim=-1) + + +def gumbel_argsort_sample_idx(logits): + """Samples with replacement from a distribution given by 'logits'. + This uses Gumbel trick to implement the sampling an efficient manner. For a + distribution over k items this samples k times without replacement, so this + is effectively sampling a random permutation with probabilities over the + permutations derived from the logprobs. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + Returns: + Sample from logprobs in index + """ + z = gumbel_noise(logits.shape) + return torch.argsort(logits + z, dim=-1, descending=True) + + +def uniform_permutation(num_seq): + shuffled = torch.from_numpy(np.random.permutation(num_seq - 1) + 1) + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +def gumbel_permutation(msa_mask, msa_chains=None): + has_msa = torch.sum(msa_mask.long(), dim=-1) > 0 + # default logits is zero + logits = torch.zeros_like(has_msa, dtype=torch.float32) + logits[~has_msa] = -1e6 + # one sample only + assert len(logits.shape) == 1 + # skip first row + logits = logits[1:] + has_msa = has_msa[1:] + if logits.shape[0] == 0: + return torch.tensor([0]) + if msa_chains is not None: + # skip first row + msa_chains = msa_chains[1:].reshape(-1) + msa_chains[~has_msa] = 0 + keys, counts = np.unique(msa_chains, return_counts=True) + num_has_msa = has_msa.sum() + num_pair = (msa_chains == 1).sum() + num_unpair = num_has_msa - num_pair + num_chains = (keys > 1).sum() + logits[has_msa] = 1.0 / (num_has_msa + 1e-6) + logits[~has_msa] = 0 + for k in keys: + if k > 1: + cur_mask = msa_chains == k + cur_cnt = cur_mask.sum() + if cur_cnt > 0: + logits[cur_mask] *= num_unpair / (num_chains * cur_cnt) + logits = torch.log(logits + 1e-6) + shuffled = gumbel_argsort_sample_idx(logits) + 1 + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +@curry1 +def sample_msa(protein, + max_seq, + keep_extra, + gumbel_sample=False, + biased_msa_by_chain=False): + """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" + num_seq = protein['msa'].shape[0] + num_sel = min(max_seq, num_seq) + if not gumbel_sample: + index_order = uniform_permutation(num_seq) + else: + msa_chains = ( + protein['msa_chains'] if + (biased_msa_by_chain and 'msa_chains' in protein) else None) + index_order = gumbel_permutation(protein['msa_mask'], msa_chains) + num_sel = min(max_seq, num_seq) + sel_seq, not_sel_seq = torch.split(index_order, + [num_sel, num_seq - num_sel]) + + for k in MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein['extra_' + k] = torch.index_select( + protein[k], 0, not_sel_seq) + protein[k] = torch.index_select(protein[k], 0, sel_seq) + + return protein + + +@curry1 +def sample_msa_distillation(protein, max_seq): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + protein = sample_msa(max_seq, keep_extra=False)(protein) + return protein + + +@curry1 +def random_delete_msa(protein, config): + # to reduce the cost of msa features + num_seq = protein['msa'].shape[0] + seq_len = protein['msa'].shape[1] + max_seq = config.max_msa_entry // seq_len + if num_seq > max_seq: + keep_index = ( + torch.from_numpy( + np.random.choice(num_seq - 1, max_seq - 1, + replace=False)).long() + 1) + keep_index = torch.sort(keep_index)[0] + keep_index = torch.cat((torch.tensor([0]), keep_index), dim=0) + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, keep_index) + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + num_seq = protein['extra_msa'].shape[0] + num_sel = min(max_extra_msa, num_seq) + select_indices = torch.from_numpy(np.random.permutation(num_seq)[:num_sel]) + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + protein['extra_' + k] = torch.index_select(protein['extra_' + k], + 0, select_indices) + + return protein + + +def delete_extra_msa(protein): + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + del protein['extra_' + k] + return protein + + +@curry1 +def block_delete_msa(protein, config): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + return protein + num_seq = protein['msa'].shape[0] + if num_seq <= config.min_num_msa: + return protein + block_num_seq = torch.floor( + torch.tensor(num_seq, dtype=torch.float32) + * config.msa_fraction_per_block).to(torch.int32) + + if config.randomize_num_blocks: + nb = np.random.randint(0, config.num_blocks + 1) + else: + nb = config.num_blocks + + del_block_starts = torch.from_numpy(np.random.randint(0, num_seq, [nb])) + del_blocks = del_block_starts[:, None] + torch.arange(0, block_num_seq) + del_blocks = torch.clip(del_blocks, 0, num_seq - 1) + del_indices = torch.unique(del_blocks.view(-1)) + # add zeros to ensure cnt_zero > 1 + combined = torch.hstack((torch.arange(0, num_seq)[None], del_indices[None], + torch.zeros(2)[None])).long() + uniques, counts = combined.unique(return_counts=True) + difference = uniques[counts == 1] + # intersection = uniques[counts > 1] + keep_indices = difference.view(-1) + keep_indices = torch.hstack( + [torch.zeros(1).long()[None], keep_indices[None]]).view(-1) + assert int(keep_indices[0]) == 0 + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, index=keep_indices) + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0): + weights = torch.cat( + [torch.ones(21), gap_agreement_weight * torch.ones(1), + torch.zeros(1)], + 0, + ) + + msa_one_hot = one_hot(protein['msa'], 23) + sample_one_hot = protein['msa_mask'][:, :, None] * msa_one_hot + extra_msa_one_hot = one_hot(protein['extra_msa'], 23) + extra_one_hot = protein['extra_msa_mask'][:, :, None] * extra_msa_one_hot + + num_seq, num_res, _ = sample_one_hot.shape + extra_num_seq, _, _ = extra_one_hot.shape + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. + a = extra_one_hot.view(extra_num_seq, num_res * 23) + b = (sample_one_hot * weights).view(num_seq, num_res * 23).transpose(0, 1) + agreement = a @ b + # Assign each sequence in the extra sequences to the closest MSA sample + protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).long() + + return protein + + +def unsorted_segment_sum(data, segment_ids, num_segments): + assert len( + segment_ids.shape) == 1 and segment_ids.shape[0] == data.shape[0] + segment_ids = segment_ids.view(segment_ids.shape[0], + *((1, ) * len(data.shape[1:]))) + segment_ids = segment_ids.expand(data.shape) + shape = [num_segments] + list(data.shape[1:]) + tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float()) + tensor = tensor.type(data.dtype) + return tensor + + +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = protein['msa'].shape[0] + + def csum(x): + return unsorted_segment_sum(x, protein['extra_cluster_assignment'], + num_seq) + + mask = protein['extra_msa_mask'] + mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center + + # TODO: this line is very slow + msa_sum = csum(mask[:, :, None] * one_hot(protein['extra_msa'], 23)) + msa_sum += one_hot(protein['msa'], 23) # Original sequence + protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] + del msa_sum + + del_sum = csum(mask * protein['extra_deletion_matrix']) + del_sum += protein['deletion_matrix'] # Original sequence + protein['cluster_deletion_mean'] = del_sum / mask_counts + del del_sum + + return protein + + +@curry1 +def nearest_neighbor_clusters_v2(batch, gap_agreement_weight=0.0): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask. + + weights = torch.tensor( + [1.0] * 21 + [gap_agreement_weight] + [0.0], dtype=torch.float32) + + msa_mask = batch['msa_mask'] + extra_mask = batch['extra_msa_mask'] + msa_one_hot = one_hot(batch['msa'], 23) + extra_one_hot = one_hot(batch['extra_msa'], 23) + + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot + + t1 = weights * msa_one_hot_masked + t1 = t1.view(t1.shape[0], t1.shape[1] * t1.shape[2]) + t2 = extra_one_hot_masked.view( + extra_one_hot.shape[0], + extra_one_hot.shape[1] * extra_one_hot.shape[2]) + agreement = t1 @ t2.T + + cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0) + cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask) + + cluster_count = torch.sum(cluster_assignment, dim=-1) + cluster_count += 1.0 # We always include the sequence itself. + + msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, + extra_one_hot_masked) + msa_sum += msa_one_hot_masked + + cluster_profile = msa_sum / cluster_count[:, None, None] + + deletion_matrix = batch['deletion_matrix'] + extra_deletion_matrix = batch['extra_deletion_matrix'] + + del_sum = torch.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence. + cluster_deletion_mean = del_sum / cluster_count[:, None] + batch['cluster_profile'] = cluster_profile + batch['cluster_deletion_mean'] = cluster_deletion_mean + + return batch + + +def make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + if 'msa_mask' not in protein: + protein['msa_mask'] = torch.ones( + protein['msa'].shape, dtype=torch.float32) + protein['msa_row_mask'] = torch.ones((protein['msa'].shape[0]), + dtype=torch.float32) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + if aatype.shape[0] > 0: + is_gly = torch.eq(aatype, rc.restype_order['G']) + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + else: + pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3) + if all_atom_mask is not None: + if aatype.shape[0] > 0: + pseudo_beta_mask = torch.where(is_gly, all_atom_mask[..., ca_idx], + all_atom_mask[..., cb_idx]) + else: + pseudo_beta_mask = torch.zeros_like(aatype).float() + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=''): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ['', 'template_'] + ( + protein[prefix + 'pseudo_beta'], + protein[prefix + 'pseudo_beta_mask'], + ) = pseudo_beta_fn( + protein['template_aatype' if prefix else 'aatype'], + protein[prefix + 'all_atom_positions'], + protein['template_all_atom_mask' if prefix else 'all_atom_mask'], + ) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = torch.tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = probs.shape + num_classes = ds[-1] + probs = torch.reshape(probs + epsilon, [-1, num_classes]) + gen = torch.Generator() + gen.manual_seed(np.random.randint(65535)) + counts = torch.multinomial(probs, 1, generator=gen) + return torch.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). + msa_one_hot = one_hot(protein['msa'], 22) + + protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0) + return protein + + +def make_msa_profile(batch): + """Compute the MSA profile.""" + # Compute the profile for every residue (over all MSA sequences). + oh = one_hot(batch['msa'], 22) + mask = batch['msa_mask'][:, :, None] + oh *= mask + return oh.sum(dim=0) / (mask.sum(dim=0) + 1e-10) + + +def make_hhblits_profile_v2(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + protein['hhblits_profile'] = make_msa_profile(protein) + return protein + + +def share_mask_by_entity(mask_position, protein): # new in unifold + if 'num_sym' not in protein: + return mask_position + entity_id = protein['entity_id'] + sym_id = protein['sym_id'] + num_sym = protein['num_sym'] + unique_entity_ids = entity_id.unique() + first_sym_mask = sym_id == 1 + for cur_entity_id in unique_entity_ids: + cur_entity_mask = entity_id == cur_entity_id + cur_num_sym = int(num_sym[cur_entity_mask][0]) + if cur_num_sym > 1: + cur_sym_mask = first_sym_mask & cur_entity_mask + cur_sym_bert_mask = mask_position[:, cur_sym_mask] + mask_position[:, cur_entity_mask] = cur_sym_bert_mask.repeat( + 1, cur_num_sym) + return mask_position + + +@curry1 +def make_masked_msa(protein, + config, + replace_fraction, + gumbel_sample=False, + share_mask=False): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly. + random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein['hhblits_profile'] + + config.same_prob * one_hot(protein['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = list( + reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])) + pad_shapes[1] = 1 + mask_prob = 1.0 - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0.0 + categorical_probs = torch.nn.functional.pad( + categorical_probs, pad_shapes, value=mask_prob) + sh = protein['msa'].shape + mask_position = torch.from_numpy(np.random.rand(*sh) < replace_fraction) + mask_position &= protein['msa_mask'].bool() + + if 'bert_mask' in protein: + mask_position &= protein['bert_mask'].bool() + + if share_mask: + mask_position = share_mask_by_entity(mask_position, protein) + if gumbel_sample: + logits = torch.log(categorical_probs + 1e-6) + bert_msa = gumbel_max_sample(logits) + else: + bert_msa = shaped_categorical(categorical_probs) + bert_msa = torch.where(mask_position, bert_msa, protein['msa']) + bert_msa *= protein['msa_mask'].long() + + # Mix real and masked MSA + protein['bert_mask'] = mask_position.to(torch.float32) + protein['true_msa'] = protein['msa'] + protein['msa'] = bert_msa + + return protein + + +@curry1 +def make_fixed_size( + protein, + shape_schema, + msa_cluster_size, + extra_msa_size, + num_res=0, + num_templates=0, +): + """Guess at the MSA and sequence dimension to make fixed size.""" + + def get_pad_size(cur_size, multiplier=4): + return max(multiplier, + ((cur_size + multiplier - 1) // multiplier) * multiplier) + + if num_res is not None: + input_num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + if input_num_res != num_res: + num_res = get_pad_size(input_num_res, 4) + if 'extra_msa_mask' in protein: + input_extra_msa_size = protein['extra_msa_mask'].shape[0] + if input_extra_msa_size != extra_msa_size: + extra_msa_size = get_pad_size(input_extra_msa_size, 8) + pad_size_map = { + N_RES: num_res, + N_MSA: msa_cluster_size, + N_EXTRA_MSA: extra_msa_size, + N_TPL: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. + if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = 'Rank mismatch between shape and shape schema for' + assert len(shape) == len(schema), f'{msg} {k}: {shape} vs {schema}' + pad_size = [ + pad_size_map.get(s2, None) or s1 + for (s1, s2) in zip(shape, schema) + ] + + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + padding.reverse() + padding = list(itertools.chain(*padding)) + if padding: + protein[k] = torch.nn.functional.pad(v, padding) + protein[k] = torch.reshape(protein[k], pad_size) + + return protein + + +def make_target_feat(protein): + """Create and concatenate MSA features.""" + protein['aatype'] = protein['aatype'].long() + + if 'between_segment_residues' in protein: + has_break = torch.clip( + protein['between_segment_residues'].to(torch.float32), 0, 1) + else: + has_break = torch.zeros_like(protein['aatype'], dtype=torch.float32) + if 'asym_len' in protein: + asym_len = protein['asym_len'] + entity_ends = torch.cumsum(asym_len, dim=-1)[:-1] + has_break[entity_ends] = 1.0 + has_break = has_break.float() + aatype_1hot = one_hot(protein['aatype'], 21) + target_feat = [ + torch.unsqueeze(has_break, dim=-1), + aatype_1hot, # Everyone gets the original sequence. + ] + protein['target_feat'] = torch.cat(target_feat, dim=-1) + return protein + + +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(protein['msa'], 23) + has_deletion = torch.clip(protein['deletion_matrix'], 0.0, 1.0) + deletion_value = torch.atan( + protein['deletion_matrix'] / 3.0) * (2.0 / np.pi) + msa_feat = [ + msa_1hot, + torch.unsqueeze(has_deletion, dim=-1), + torch.unsqueeze(deletion_value, dim=-1), + ] + if 'cluster_profile' in protein: + deletion_mean_value = torch.atan( + protein['cluster_deletion_mean'] / 3.0) * (2.0 / np.pi) + msa_feat.extend([ + protein['cluster_profile'], + torch.unsqueeze(deletion_mean_value, dim=-1), + ]) + + if 'extra_deletion_matrix' in protein: + protein['extra_msa_has_deletion'] = torch.clip( + protein['extra_deletion_matrix'], 0.0, 1.0) + protein['extra_msa_deletion_value'] = torch.atan( + protein['extra_deletion_matrix'] / 3.0) * (2.0 / np.pi) + + protein['msa_feat'] = torch.cat(msa_feat, dim=-1) + return protein + + +def make_msa_feat_v2(batch): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(batch['msa'], 23) + deletion_matrix = batch['deletion_matrix'] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi))[..., + None] + + deletion_mean_value = ( + torch.arctan(batch['cluster_deletion_mean'] / 3.0) * # noqa W504 + (2.0 / np.pi))[..., None] + + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + batch['cluster_profile'], + deletion_mean_value, + ] + batch['msa_feat'] = torch.concat(msa_feat, dim=-1) + return batch + + +@curry1 +def make_extra_msa_feat(batch, num_extra_msa): + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0) + deletion_value = torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi) + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + batch['extra_msa'] = extra_msa + batch['extra_msa_mask'] = extra_msa_mask + batch['extra_msa_has_deletion'] = has_deletion + batch['extra_msa_deletion_value'] = deletion_value + return batch + + +@curry1 +def select_feat(protein, feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + + if 'atom14_atom_exists' in protein: # lazy move + return protein + + restype_atom14_to_atom37 = torch.tensor( + rc.restype_atom14_to_atom37, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom37_to_atom14 = torch.tensor( + rc.restype_atom37_to_atom14, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom14_mask = torch.tensor( + rc.restype_atom14_mask, + dtype=torch.float32, + device=protein['aatype'].device, + ) + restype_atom37_mask = torch.tensor( + rc.restype_atom37_mask, + dtype=torch.float32, + device=protein['aatype'].device) + + protein_aatype = protein['aatype'].long() + protein['residx_atom14_to_atom37'] = restype_atom14_to_atom37[ + protein_aatype].long() + protein['residx_atom37_to_atom14'] = restype_atom37_to_atom14[ + protein_aatype].long() + protein['atom14_atom_exists'] = restype_atom14_mask[protein_aatype] + protein['atom37_atom_exists'] = restype_atom37_mask[protein_aatype] + + return protein + + +def make_atom14_masks_np(batch): + batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) + out = make_atom14_masks(batch) + out = tensor_tree_map(lambda t: np.array(t), out) + return out + + +def make_atom14_positions(protein): + """Constructs denser atom positions (14 dimensions instead of 37).""" + protein['aatype'] = protein['aatype'].long() + protein['all_atom_mask'] = protein['all_atom_mask'].float() + protein['all_atom_positions'] = protein['all_atom_positions'].float() + residx_atom14_mask = protein['atom14_atom_exists'] + residx_atom14_to_atom37 = protein['residx_atom14_to_atom37'] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein['all_atom_mask'], + residx_atom14_to_atom37, + dim=-1, + num_batch_dims=len(protein['all_atom_mask'].shape[:-1]), + ) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein['all_atom_positions'], + residx_atom14_to_atom37, + dim=-2, + num_batch_dims=len(protein['all_atom_positions'].shape[:-2]), + )) + + protein['atom14_atom_exists'] = residx_atom14_mask + protein['atom14_gt_exists'] = residx_atom14_gt_mask + protein['atom14_gt_positions'] = residx_atom14_gt_positions + + renaming_matrices = torch.tensor( + rc.renaming_matrices, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein['aatype']] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum('...rac,...rab->...rbc', + residx_atom14_gt_positions, + renaming_transform) + protein['atom14_alt_gt_positions'] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = torch.einsum('...ra,...rab->...rb', + residx_atom14_gt_mask, + renaming_transform) + protein['atom14_alt_gt_exists'] = alternative_gt_mask + + restype_atom14_is_ambiguous = torch.tensor( + rc.restype_atom14_is_ambiguous, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + # From this create an ambiguous_mask for the given sequence. + protein['atom14_atom_is_ambiguous'] = restype_atom14_is_ambiguous[ + protein['aatype']] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + # TODO: extract common part and put them into residue constants. + aatype = protein['aatype'] + all_atom_positions = protein['all_atom_positions'] + all_atom_mask = protein['all_atom_mask'] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[restype, + chi_idx + 4, :] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :20, + 4:] = all_atom_mask.new_tensor(rc.chi_angles_mask) + + lookuptable = rc.atom_order.copy() + lookuptable[''] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, ) + restype_rigidgroup_base_atom37_idx = restype_rigidgroup_base_atom37_idx.view( + *((1, ) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + num_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + num_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Frame.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + num_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1, ) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + rots = Rotation(mat=rots) + + gt_frames = gt_frames.compose(Frame(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1, ) * batch_dims), 21, 8) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1, ) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + mat=residx_rigidgroup_ambiguity_rot) + alt_gt_frames = gt_frames.compose( + Frame(residx_rigidgroup_ambiguity_rot, None)) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein['rigidgroups_gt_frames'] = gt_frames_tensor + protein['rigidgroups_gt_exists'] = gt_exists + protein['rigidgroups_group_exists'] = group_exists + protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous + protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor + + return protein + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix='', +): + aatype = protein[prefix + 'aatype'] + all_atom_positions = protein[prefix + 'all_atom_positions'] + all_atom_mask = protein[prefix + 'all_atom_mask'] + if aatype.shape[-1] == 0: + base_shape = aatype.shape + protein[prefix + + 'torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + + 'alt_torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + 'torsion_angles_mask'] = all_atom_positions.new_zeros( + *base_shape, 7) + return protein + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3]) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1) * torch.prod( + all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4]) + + chi_atom_indices = torch.as_tensor( + rc.chi_atom_indices, device=aatype.device) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather(all_atom_positions, atom_indices, -2, + len(atom_indices.shape[:-2])) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + num_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], + psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Frame.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :]) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + 1e-8) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = ( + torsion_angles_sin_cos + * all_atom_mask.new_tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0], )[ + ((None, ) * len(torsion_angles_sin_cos.shape[:-2])) + + (slice(None), None)]) + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, )[aatype, ...] + + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None]) + + if prefix == '': + # consistent to uni-fold. use [1, 0] placeholder + placeholder_torsions = torch.stack( + [ + torch.ones(torsion_angles_sin_cos.shape[:-1]), + torch.zeros(torsion_angles_sin_cos.shape[:-1]), + ], + dim=-1, + ) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + protein[prefix + 'torsion_angles_sin_cos'] = torsion_angles_sin_cos + protein[prefix + 'alt_torsion_angles_sin_cos'] = alt_torsion_angles_sin_cos + protein[prefix + 'torsion_angles_mask'] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + protein['true_frame_tensor'] = protein['rigidgroups_gt_frames'][..., + 0, :, :] + protein['frame_mask'] = protein['rigidgroups_gt_exists'][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein['all_atom_mask'].dtype + protein['chi_angles_sin_cos'] = ( + protein['torsion_angles_sin_cos'][..., 3:, :]).to(dtype) + protein['chi_mask'] = protein['torsion_angles_mask'][..., 3:].to(dtype) + + return protein + + +@curry1 +def crop_templates( + protein, + max_templates, + subsample_templates=False, +): + if 'template_mask' in protein: + num_templates = protein['template_mask'].shape[-1] + else: + num_templates = 0 + + # don't sample when there are no templates + if num_templates > 0: + if subsample_templates: + # af2's sampling, min(4, uniform[0, n]) + max_templates = min(max_templates, + np.random.randint(0, num_templates + 1)) + template_idx = torch.tensor( + np.random.choice(num_templates, max_templates, replace=False), + dtype=torch.int64, + ) + else: + # use top templates + template_idx = torch.arange( + min(num_templates, max_templates), dtype=torch.int64) + for k, v in protein.items(): + if k.startswith('template'): + try: + v = v[template_idx] + except Exception as ex: + print(ex.__class__, ex) + print('num_templates', num_templates) + print(k, v.shape) + print('protein:', protein) + print( + 'protein_shape:', + { + k: v.shape + for k, v in protein.items() if 'shape' in dir(v) + }, + ) + protein[k] = v + + return protein + + +@curry1 +def crop_to_size_single(protein, crop_size, shape_schema, seed): + """crop to size.""" + num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + crop_idx = get_single_crop_idx(num_res, crop_size, seed) + protein = apply_crop_idx(protein, shape_schema, crop_idx) + return protein + + +@curry1 +def crop_to_size_multimer(protein, crop_size, shape_schema, seed, + spatial_crop_prob, ca_ca_threshold): + """crop to size.""" + with data_utils.numpy_seed(seed, key='multimer_crop'): + use_spatial_crop = np.random.rand() < spatial_crop_prob + is_distillation = 'is_distillation' in protein and protein[ + 'is_distillation'] == 1 + if is_distillation: + return crop_to_size_single( + crop_size=crop_size, shape_schema=shape_schema, seed=seed)( + protein) + elif use_spatial_crop: + crop_idx = get_spatial_crop_idx(protein, crop_size, seed, + ca_ca_threshold) + else: + crop_idx = get_contiguous_crop_idx(protein, crop_size, seed) + return apply_crop_idx(protein, shape_schema, crop_idx) + + +def get_single_crop_idx(num_res: NumpyDict, crop_size: int, + random_seed: Optional[int]) -> torch.Tensor: + + if num_res < crop_size: + return torch.arange(num_res) + with data_utils.numpy_seed(random_seed): + crop_start = int(np.random.randint(0, num_res - crop_size + 1)) + return torch.arange(crop_start, crop_start + crop_size) + + +def get_crop_sizes_each_chain( + asym_len: torch.Tensor, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + """get crop sizes for contiguous crop""" + if not use_multinomial: + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_perm'): + shuffle_idx = np.random.permutation(len(asym_len)) + num_left = asym_len.sum() + num_budget = torch.tensor(crop_size) + crop_sizes = [0 for _ in asym_len] + for j, idx in enumerate(shuffle_idx): + this_len = asym_len[idx] + num_left -= this_len + # num res at most we can keep in this ent + max_size = min(num_budget, this_len) + # num res at least we shall keep in this ent + min_size = min(this_len, max(0, num_budget - num_left)) + with data_utils.numpy_seed( + random_seed, j, key='multimer_contiguous_crop_size'): + this_crop_size = int( + np.random.randint( + low=int(min_size), high=int(max_size) + 1)) + num_budget -= this_crop_size + crop_sizes[idx] = this_crop_size + crop_sizes = torch.tensor(crop_sizes) + else: # use multinomial + # TODO: better multimer + entity_probs = asym_len / torch.sum(asym_len) + crop_sizes = torch.from_numpy( + np.random.multinomial(crop_size, pvals=entity_probs)) + crop_sizes = torch.min(crop_sizes, asym_len) + return crop_sizes + + +def get_contiguous_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + + num_res = protein['aatype'].shape[0] + if num_res <= crop_size: + return torch.arange(num_res) + + assert 'asym_len' in protein + asym_len = protein['asym_len'] + + crop_sizes = get_crop_sizes_each_chain(asym_len, crop_size, random_seed, + use_multinomial) + crop_idxs = [] + asym_offset = torch.tensor(0, dtype=torch.int64) + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_crop_start_idx'): + for ll, csz in zip(asym_len, crop_sizes): + this_start = np.random.randint(0, int(ll - csz) + 1) + crop_idxs.append( + torch.arange(asym_offset + this_start, + asym_offset + this_start + csz)) + asym_offset += ll + + return torch.concat(crop_idxs) + + +def get_spatial_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: int, + ca_ca_threshold: float, + inf: float = 3e4, +) -> List[int]: + + ca_idx = rc.atom_order['CA'] + ca_coords = protein['all_atom_positions'][..., ca_idx, :] + ca_mask = protein['all_atom_mask'][..., ca_idx].bool() + # if there are not enough atoms to construct interface, use contiguous crop + if (ca_mask.sum(dim=-1) <= 1).all(): + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + pair_mask = ca_mask[..., None] * ca_mask[..., None, :] + ca_distances = get_pairwise_distances(ca_coords) + + interface_candidates = get_interface_candidates(ca_distances, + protein['asym_id'], + pair_mask, ca_ca_threshold) + + if torch.any(interface_candidates): + with data_utils.numpy_seed(random_seed, key='multimer_spatial_crop'): + target_res = int(np.random.choice(interface_candidates)) + else: + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + to_target_distances = ca_distances[target_res] + # set inf to non-position residues + to_target_distances[~ca_mask] = inf + break_tie = ( + torch.arange( + 0, + to_target_distances.shape[-1], + device=to_target_distances.device).float() * 1e-3) + to_target_distances += break_tie + ret = torch.argsort(to_target_distances)[:crop_size] + return ret.sort().values + + +def get_pairwise_distances(coords: torch.Tensor) -> torch.Tensor: + coord_diff = coords.unsqueeze(-2) - coords.unsqueeze(-3) + return torch.sqrt(torch.sum(coord_diff**2, dim=-1)) + + +def get_interface_candidates( + ca_distances: torch.Tensor, + asym_id: torch.Tensor, + pair_mask: torch.Tensor, + ca_ca_threshold, +) -> torch.Tensor: + + in_same_asym = asym_id[..., None] == asym_id[..., None, :] + # set distance in the same entity to zero + ca_distances = ca_distances * (1.0 - in_same_asym.float()) * pair_mask + cnt_interfaces = torch.sum( + (ca_distances > 0) & (ca_distances < ca_ca_threshold), dim=-1) + interface_candidates = cnt_interfaces.nonzero(as_tuple=True)[0] + return interface_candidates + + +def apply_crop_idx(protein, shape_schema, crop_idx): + cropped_protein = {} + for k, v in protein.items(): + if k not in shape_schema: # skip items with unknown shape schema + continue + for i, dim_size in enumerate(shape_schema[k]): + if dim_size == N_RES: + v = torch.index_select(v, i, crop_idx) + cropped_protein[k] = v + return cropped_protein diff --git a/modelscope/models/science/unifold/data/msa_pairing.py b/modelscope/models/science/unifold/data/msa_pairing.py new file mode 100644 index 00000000..cc65962c --- /dev/null +++ b/modelscope/models/science/unifold/data/msa_pairing.py @@ -0,0 +1,526 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pairing logic for multimer data """ + +import collections +from typing import Dict, Iterable, List, Sequence + +import numpy as np +import pandas as pd +import scipy.linalg + +from .data_ops import NumpyDict +from .residue_constants import restypes_with_x_and_gap + +MSA_GAP_IDX = restypes_with_x_and_gap.index('-') +SEQUENCE_GAP_CUTOFF = 0.5 +SEQUENCE_SIMILARITY_CUTOFF = 0.9 + +MSA_PAD_VALUES = { + 'msa_all_seq': MSA_GAP_IDX, + 'msa_mask_all_seq': 1, + 'deletion_matrix_all_seq': 0, + 'deletion_matrix_int_all_seq': 0, + 'msa': MSA_GAP_IDX, + 'msa_mask': 1, + 'deletion_matrix': 0, + 'deletion_matrix_int': 0, +} + +MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') +SEQ_FEATURES = ( + 'residue_index', + 'aatype', + 'all_atom_positions', + 'all_atom_mask', + 'seq_mask', + 'between_segment_residues', + 'has_alt_locations', + 'has_hetatoms', + 'asym_id', + 'entity_id', + 'sym_id', + 'entity_mask', + 'deletion_mean', + 'prediction_atom_mask', + 'literature_positions', + 'atom_indices_to_group_indices', + 'rigid_group_default_frame', + # zy + 'num_sym', +) +TEMPLATE_FEATURES = ( + 'template_aatype', + 'template_all_atom_positions', + 'template_all_atom_mask', +) +CHAIN_FEATURES = ('num_alignments', 'seq_length') + + +def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]: + """Returns the original chains with paired NUM_SEQ features. + + Args: + chains: A list of feature dictionaries for each chain. + + Returns: + A list of feature dictionaries with sequence features including only + rows to be paired. + """ + chains = list(chains) + chain_keys = chains[0].keys() + + if len(chains) < 2: + return chains + else: + updated_chains = [] + paired_chains_to_paired_row_indices = pair_sequences(chains) + paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices) + + for chain_num, chain in enumerate(chains): + new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} + for feature_name in chain_keys: + if feature_name.endswith('_all_seq'): + feats_padded = pad_features(chain[feature_name], + feature_name) + new_chain[feature_name] = feats_padded[ + paired_rows[:, chain_num]] + new_chain['num_alignments_all_seq'] = np.asarray( + len(paired_rows[:, chain_num])) + updated_chains.append(new_chain) + return updated_chains + + +def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: + """Add a 'padding' row at the end of the features list. + + The padding row will be selected as a 'paired' row in the case of partial + alignment - for the chain that doesn't have paired alignment. + + Args: + feature: The feature to be padded. + feature_name: The name of the feature to be padded. + + Returns: + The feature with an additional padding row. + """ + assert feature.dtype != np.dtype(np.string_) + if feature_name in ( + 'msa_all_seq', + 'msa_mask_all_seq', + 'deletion_matrix_all_seq', + 'deletion_matrix_int_all_seq', + ): + num_res = feature.shape[1] + padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], + feature.dtype) + elif feature_name == 'msa_species_identifiers_all_seq': + padding = [b''] + else: + return feature + feats_padded = np.concatenate([feature, padding], axis=0) + return feats_padded + + +def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame: + """Makes dataframe with msa features needed for msa pairing.""" + chain_msa = chain_features['msa_all_seq'] + query_seq = chain_msa[0] + per_seq_similarity = np.sum( + query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) + per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) + msa_df = pd.DataFrame({ + 'msa_species_identifiers': + chain_features['msa_species_identifiers_all_seq'], + 'msa_row': + np.arange(len(chain_features['msa_species_identifiers_all_seq'])), + 'msa_similarity': + per_seq_similarity, + 'gap': + per_seq_gap, + }) + return msa_df + + +def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: + """Creates mapping from species to msa dataframe of that species.""" + species_lookup = {} + for species, species_df in msa_df.groupby('msa_species_identifiers'): + species_lookup[species] = species_df + return species_lookup + + +def _match_rows_by_sequence_similarity( + this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]: # noqa + """Finds MSA sequence pairings across chains based on sequence similarity. + + Each chain's MSA sequences are first sorted by their sequence similarity to + their respective target sequence. The sequences are then paired, starting + from the sequences most similar to their target sequence. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. + """ + all_paired_msa_rows = [] + + num_seqs = [ + len(species_df) for species_df in this_species_msa_dfs + if species_df is not None + ] + take_num_seqs = np.min(num_seqs) + + # sort_by_similarity = lambda x: x.sort_values( + # 'msa_similarity', axis=0, ascending=False) + + def sort_by_similarity(x): + return x.sort_values('msa_similarity', axis=0, ascending=False) + + for species_df in this_species_msa_dfs: + if species_df is not None: + species_df_sorted = sort_by_similarity(species_df) + msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values + else: + msa_rows = [-1] * take_num_seqs # take the last 'padding' row + all_paired_msa_rows.append(msa_rows) + all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) + return all_paired_msa_rows + + +def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]: + """Returns indices for paired MSA sequences across chains.""" + + num_examples = len(examples) + + all_chain_species_dict = [] + common_species = set() + for chain_features in examples: + msa_df = _make_msa_df(chain_features) + species_dict = _create_species_dict(msa_df) + all_chain_species_dict.append(species_dict) + common_species.update(set(species_dict)) + + common_species = sorted(common_species) + common_species.remove(b'') # Remove target sequence species. + + all_paired_msa_rows = [np.zeros(len(examples), int)] + all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} + all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] + + for species in common_species: + if not species: + continue + this_species_msa_dfs = [] + species_dfs_present = 0 + for species_dict in all_chain_species_dict: + if species in species_dict: + this_species_msa_dfs.append(species_dict[species]) + species_dfs_present += 1 + else: + this_species_msa_dfs.append(None) + + # Skip species that are present in only one chain. + if species_dfs_present <= 1: + continue + + if np.any( + np.array([ + len(species_df) for species_df in this_species_msa_dfs + if isinstance(species_df, pd.DataFrame) + ]) > 600): + continue + + paired_msa_rows = _match_rows_by_sequence_similarity( + this_species_msa_dfs) + all_paired_msa_rows.extend(paired_msa_rows) + all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) + all_paired_msa_rows_dict = { + num_examples: np.array(paired_msa_rows) + for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() + } + return all_paired_msa_rows_dict + + +def reorder_paired_rows( + all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray: + """Creates a list of indices of paired MSA rows across chains. + + Args: + all_paired_msa_rows_dict: a mapping from the number of paired chains to the + paired indices. + + Returns: + a list of lists, each containing indices of paired MSA rows across chains. + The paired-index lists are ordered by: + 1) the number of chains in the paired alignment, i.e, all-chain pairings + will come first. + 2) e-values + """ + all_paired_msa_rows = [] + + for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): + paired_rows = all_paired_msa_rows_dict[num_pairings] + paired_rows_product = np.abs( + np.array( + [np.prod(rows.astype(np.float64)) for rows in paired_rows])) + paired_rows_sort_index = np.argsort(paired_rows_product) + all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) + + return np.array(all_paired_msa_rows) + + +def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: + """Like scipy.linalg.block_diag but with an optional padding value.""" + ones_arrs = [np.ones_like(x) for x in arrs] + off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs) + diag = scipy.linalg.block_diag(*arrs) + diag += (off_diag_mask * pad_value).astype(diag.dtype) + return diag + + +def _correct_post_merged_feats(np_example: NumpyDict, + np_chains_list: Sequence[NumpyDict], + pair_msa_sequences: bool) -> NumpyDict: + """Adds features that need to be computed/recomputed post merging.""" + + np_example['seq_length'] = np.asarray( + np_example['aatype'].shape[0], dtype=np.int32) + np_example['num_alignments'] = np.asarray( + np_example['msa'].shape[0], dtype=np.int32) + + if not pair_msa_sequences: + # Generate a bias that is 1 for the first row of every block in the + # block diagonal MSA - i.e. make sure the cluster stack always includes + # the query sequences for each chain (since the first row is the query + # sequence). + cluster_bias_masks = [] + for chain in np_chains_list: + mask = np.zeros(chain['msa'].shape[0]) + mask[0] = 1 + cluster_bias_masks.append(mask) + np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + + np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0) + else: + np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) + np_example['cluster_bias_mask'][0] = 1 + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + msa_masks_all_seq = [ + np.ones(x['msa_all_seq'].shape, dtype=np.int8) + for x in np_chains_list + ] + + msa_mask_block_diag = block_diag(*msa_masks, pad_value=0) + msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) + np_example['bert_mask'] = np.concatenate( + [msa_mask_all_seq, msa_mask_block_diag], axis=0) + return np_example + + +def _pad_templates(chains: Sequence[NumpyDict], + max_templates: int) -> Sequence[NumpyDict]: + """For each chain pad the number of templates to a fixed size. + + Args: + chains: A list of protein chains. + max_templates: Each chain will be padded to have this many templates. + + Returns: + The list of chains, updated to have template features padded to + max_templates. + """ + for chain in chains: + for k, v in chain.items(): + if k in TEMPLATE_FEATURES: + padding = np.zeros_like(v.shape) + padding[0] = max_templates - v.shape[0] + padding = [(0, p) for p in padding] + chain[k] = np.pad(v, padding, mode='constant') + return chains + + +def _merge_features_from_multiple_chains( + chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict: + """Merge features from multiple chains. + + Args: + chains: A list of feature dictionaries that we want to merge. + pair_msa_sequences: Whether to concatenate MSA features along the + num_res dimension (if True), or to block diagonalize them (if False). + + Returns: + A feature dictionary for the merged example. + """ + merged_example = {} + for feature_name in chains[0]: + feats = [x[feature_name] for x in chains] + feature_name_split = feature_name.split('_all_seq')[0] + if feature_name_split in MSA_FEATURES: + if pair_msa_sequences or '_all_seq' in feature_name: + merged_example[feature_name] = np.concatenate(feats, axis=1) + if feature_name_split == 'msa': + merged_example['msa_chains_all_seq'] = np.ones( + merged_example[feature_name].shape[0]).reshape(-1, 1) + else: + merged_example[feature_name] = block_diag( + *feats, pad_value=MSA_PAD_VALUES[feature_name]) + if feature_name_split == 'msa': + msa_chains = [] + for i, feat in enumerate(feats): + cur_shape = feat.shape[0] + vals = np.ones(cur_shape) * (i + 2) + msa_chains.append(vals) + merged_example['msa_chains'] = np.concatenate( + msa_chains).reshape(-1, 1) + elif feature_name_split in SEQ_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=0) + elif feature_name_split in TEMPLATE_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=1) + elif feature_name_split in CHAIN_FEATURES: + merged_example[feature_name] = np.sum(feats).astype(np.int32) + else: + merged_example[feature_name] = feats[0] + return merged_example + + +def _merge_homomers_dense_msa( + chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]: + """Merge all identical chains, making the resulting MSA dense. + + Args: + chains: An iterable of features for each chain. + + Returns: + A list of feature dictionaries. All features with the same entity_id + will be merged - MSA features will be concatenated along the num_res + dimension - making them dense. + """ + entity_chains = collections.defaultdict(list) + for chain in chains: + entity_id = chain['entity_id'][0] + entity_chains[entity_id].append(chain) + + grouped_chains = [] + for entity_id in sorted(entity_chains): + chains = entity_chains[entity_id] + grouped_chains.append(chains) + chains = [ + _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) + for chains in grouped_chains + ] + return chains + + +def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict: + """Merges paired and block-diagonalised features.""" + features = MSA_FEATURES + ('msa_chains', ) + for feature_name in features: + if feature_name in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + try: + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + except Exception as ex: + raise Exception( + 'concat failed.', + feature_name, + feat_all_seq.shape, + feat.shape, + ex.__class__, + ex, + ) + example[feature_name] = merged_feat + example['num_alignments'] = np.array( + example['msa'].shape[0], dtype=np.int32) + return example + + +def merge_chain_features(np_chains_list: List[NumpyDict], + pair_msa_sequences: bool, + max_templates: int) -> NumpyDict: + """Merges features for multiple chains to single FeatureDict. + + Args: + np_chains_list: List of FeatureDicts for each chain. + pair_msa_sequences: Whether to merge paired MSAs. + max_templates: The maximum number of templates to include. + + Returns: + Single FeatureDict for entire complex. + """ + np_chains_list = _pad_templates( + np_chains_list, max_templates=max_templates) + np_chains_list = _merge_homomers_dense_msa(np_chains_list) + # Unpaired MSA features will be always block-diagonalised; paired MSA + # features will be concatenated. + np_example = _merge_features_from_multiple_chains( + np_chains_list, pair_msa_sequences=False) + if pair_msa_sequences: + np_example = _concatenate_paired_and_unpaired_features(np_example) + np_example = _correct_post_merged_feats( + np_example=np_example, + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + ) + + return np_example + + +def deduplicate_unpaired_sequences( + np_chains: List[NumpyDict]) -> List[NumpyDict]: + """Removes unpaired sequences which duplicate a paired sequence.""" + + feature_names = np_chains[0].keys() + msa_features = MSA_FEATURES + cache_msa_features = {} + for chain in np_chains: + entity_id = int(chain['entity_id'][0]) + if entity_id not in cache_msa_features: + sequence_set = set(s.tobytes() for s in chain['msa_all_seq']) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. + for row_num, seq in enumerate(chain['msa']): + if seq.tobytes() not in sequence_set: + keep_rows.append(row_num) + new_msa_features = {} + for feature_name in feature_names: + if feature_name in msa_features: + if keep_rows: + new_msa_features[feature_name] = chain[feature_name][ + keep_rows] + else: + new_shape = list(chain[feature_name].shape) + new_shape[0] = 0 + new_msa_features[feature_name] = np.zeros( + new_shape, dtype=chain[feature_name].dtype) + cache_msa_features[entity_id] = new_msa_features + for feature_name in cache_msa_features[entity_id]: + chain[feature_name] = cache_msa_features[entity_id][feature_name] + chain['num_alignments'] = np.array( + chain['msa'].shape[0], dtype=np.int32) + return np_chains diff --git a/modelscope/models/science/unifold/data/process.py b/modelscope/models/science/unifold/data/process.py new file mode 100644 index 00000000..3987cb1c --- /dev/null +++ b/modelscope/models/science/unifold/data/process.py @@ -0,0 +1,264 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Optional + +import numpy as np +import torch + +from modelscope.models.science.unifold.data import data_ops + + +def nonensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that are not ensembled.""" + v2_feature = common_cfg.v2_feature + operators = [] + if mode_cfg.random_delete_msa: + operators.append( + data_ops.random_delete_msa(common_cfg.random_delete_msa)) + operators.extend([ + data_ops.cast_to_64bit_ints, + data_ops.correct_msa_restypes, + data_ops.squeeze_features, + data_ops.randomly_replace_msa_with_unknown(0.0), + data_ops.make_seq_mask, + data_ops.make_msa_mask, + ]) + operators.append(data_ops.make_hhblits_profile_v2 + if v2_feature else data_ops.make_hhblits_profile) + if common_cfg.use_templates: + operators.extend([ + data_ops.make_template_mask, + data_ops.make_pseudo_beta('template_'), + ]) + operators.append( + data_ops.crop_templates( + max_templates=mode_cfg.max_templates, + subsample_templates=mode_cfg.subsample_templates, + )) + + if common_cfg.use_template_torsion_angles: + operators.extend([ + data_ops.atom37_to_torsion_angles('template_'), + ]) + + operators.append(data_ops.make_atom14_masks) + operators.append(data_ops.make_target_feat) + + return operators + + +def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): + operators = [] + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + crop_feats = dict(common_cfg.features) + if mode_cfg.fixed_size: + if mode_cfg.crop: + if common_cfg.is_multimer: + crop_fn = data_ops.crop_to_size_multimer( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + spatial_crop_prob=mode_cfg.spatial_crop_prob, + ca_ca_threshold=mode_cfg.ca_ca_threshold, + ) + else: + crop_fn = data_ops.crop_to_size_single( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + ) + operators.append(crop_fn) + + operators.append(data_ops.select_feat(crop_feats)) + + operators.append( + data_ops.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + mode_cfg.crop_size, + mode_cfg.max_templates, + )) + return operators + + +def ensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that can be ensembled and averaged.""" + operators = [] + multimer_mode = common_cfg.is_multimer + v2_feature = common_cfg.v2_feature + # multimer don't use block delete msa + if mode_cfg.block_delete_msa and not multimer_mode: + operators.append( + data_ops.block_delete_msa(common_cfg.block_delete_msa)) + if 'max_distillation_msa_clusters' in mode_cfg: + operators.append( + data_ops.sample_msa_distillation( + mode_cfg.max_distillation_msa_clusters)) + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + assert common_cfg.resample_msa_in_recycling + gumbel_sample = common_cfg.gumbel_sample + operators.append( + data_ops.sample_msa( + max_msa_clusters, + keep_extra=True, + gumbel_sample=gumbel_sample, + biased_msa_by_chain=mode_cfg.biased_msa_by_chain, + )) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + operators.append( + data_ops.make_masked_msa( + common_cfg.masked_msa, + mode_cfg.masked_msa_replace_fraction, + gumbel_sample=gumbel_sample, + share_mask=mode_cfg.share_mask, + )) + + if common_cfg.msa_cluster_features: + if v2_feature: + operators.append(data_ops.nearest_neighbor_clusters_v2()) + else: + operators.append(data_ops.nearest_neighbor_clusters()) + operators.append(data_ops.summarize_clusters) + + if v2_feature: + operators.append(data_ops.make_msa_feat_v2) + else: + operators.append(data_ops.make_msa_feat) + # Crop after creating the cluster profiles. + if max_extra_msa: + if v2_feature: + operators.append(data_ops.make_extra_msa_feat(max_extra_msa)) + else: + operators.append(data_ops.crop_extra_msa(max_extra_msa)) + else: + operators.append(data_ops.delete_extra_msa) + # operators.append(data_operators.select_feat(common_cfg.recycling_features)) + return operators + + +def process_features(tensors, common_cfg, mode_cfg): + """Based on the config, apply filters and transformations to the data.""" + is_distillation = bool(tensors.get('is_distillation', 0)) + multimer_mode = common_cfg.is_multimer + crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed']) + crop_fn = crop_and_fix_size_fns( + common_cfg, + mode_cfg, + crop_and_fix_size_seed, + ) + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_fns( + common_cfg, + mode_cfg, + ) + new_d = compose(fns)(d) + if not multimer_mode or is_distillation: + new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d) + return compose(crop_fn)(new_d) + else: # select after crop for spatial cropping + d = compose(crop_fn)(d) + d = data_ops.select_feat(common_cfg.recycling_features)(d) + return d + + nonensembled = nonensembled_fns(common_cfg, mode_cfg) + + if mode_cfg.supervised and (not multimer_mode or is_distillation): + nonensembled.extend(label_transform_fn()) + + tensors = compose(nonensembled)(tensors) + + num_recycling = int(tensors['num_recycling_iters']) + 1 + num_ensembles = mode_cfg.num_ensembles + + ensemble_tensors = map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + torch.arange(num_recycling * num_ensembles), + ) + tensors = compose(crop_fn)(tensors) + # add a dummy dim to align with recycling features + tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors} + tensors.update(ensemble_tensors) + return tensors + + +@data_ops.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x + + +def pad_then_stack(values, ): + if len(values[0].shape) >= 1: + size = max(v.shape[0] for v in values) + new_values = [] + for v in values: + if v.shape[0] < size: + res = values[0].new_zeros(size, *v.shape[1:]) + res[:v.shape[0], ...] = v + else: + res = v + new_values.append(res) + else: + new_values = values + return torch.stack(new_values, dim=0) + + +def map_fn(fun, x): + ensembles = [fun(elem) for elem in x] + features = ensembles[0].keys() + ensembled_dict = {} + for feat in features: + ensembled_dict[feat] = pad_then_stack( + [dict_i[feat] for dict_i in ensembles]) + return ensembled_dict + + +def process_single_label(label: dict, + num_ensemble: Optional[int] = None) -> dict: + assert 'aatype' in label + assert 'all_atom_positions' in label + assert 'all_atom_mask' in label + label = compose(label_transform_fn())(label) + if num_ensemble is not None: + label = { + k: torch.stack([v for _ in range(num_ensemble)]) + for k, v in label.items() + } + return label + + +def process_labels(labels_list, num_ensemble: Optional[int] = None): + return [process_single_label(ll, num_ensemble) for ll in labels_list] + + +def label_transform_fn(): + return [ + data_ops.make_atom14_masks, + data_ops.make_atom14_positions, + data_ops.atom37_to_frames, + data_ops.atom37_to_torsion_angles(''), + data_ops.make_pseudo_beta(''), + data_ops.get_backbone_frames, + data_ops.get_chi_angles, + ] diff --git a/modelscope/models/science/unifold/data/process_multimer.py b/modelscope/models/science/unifold/data/process_multimer.py new file mode 100644 index 00000000..04572d2d --- /dev/null +++ b/modelscope/models/science/unifold/data/process_multimer.py @@ -0,0 +1,417 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature processing logic for multimer data """ + +import collections +from typing import Iterable, List, MutableMapping + +import numpy as np + +from modelscope.models.science.unifold.data import (msa_pairing, + residue_constants) +from .utils import correct_template_restypes + +FeatureDict = MutableMapping[str, np.ndarray] + +REQUIRED_FEATURES = frozenset({ + 'aatype', + 'all_atom_mask', + 'all_atom_positions', + 'all_chains_entity_ids', + 'all_crops_all_chains_mask', + 'all_crops_all_chains_positions', + 'all_crops_all_chains_residue_ids', + 'assembly_num_chains', + 'asym_id', + 'bert_mask', + 'cluster_bias_mask', + 'deletion_matrix', + 'deletion_mean', + 'entity_id', + 'entity_mask', + 'mem_peak', + 'msa', + 'msa_mask', + 'num_alignments', + 'num_templates', + 'queue_size', + 'residue_index', + 'resolution', + 'seq_length', + 'seq_mask', + 'sym_id', + 'template_aatype', + 'template_all_atom_mask', + 'template_all_atom_positions', + # zy added: + 'asym_len', + 'template_sum_probs', + 'num_sym', + 'msa_chains', +}) + +MAX_TEMPLATES = 4 +MSA_CROP_SIZE = 2048 + + +def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool: + """Checks if a list of chains represents a homomer/monomer example.""" + # Note that an entity_id of 0 indicates padding. + num_unique_chains = len( + np.unique( + np.concatenate([ + np.unique(chain['entity_id'][chain['entity_id'] > 0]) + for chain in chains + ]))) + return num_unique_chains == 1 + + +def pair_and_merge( + all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict: + """Runs processing on features to augment, pair and merge. + + Args: + all_chain_features: A MutableMap of dictionaries of features for each chain. + + Returns: + A dictionary of features. + """ + + process_unmerged_features(all_chain_features) + + np_chains_list = all_chain_features + + pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) + + if pair_msa_sequences: + np_chains_list = msa_pairing.create_paired_features( + chains=np_chains_list) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences( + np_chains_list) + np_chains_list = crop_chains( + np_chains_list, + msa_crop_size=MSA_CROP_SIZE, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = msa_pairing.merge_chain_features( + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = process_final(np_example) + return np_example + + +def crop_chains( + chains_list: List[FeatureDict], + msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int, +) -> List[FeatureDict]: + """Crops the MSAs for a set of chains. + + Args: + chains_list: A list of chains to be cropped. + msa_crop_size: The total number of sequences to crop from the MSA. + pair_msa_sequences: Whether we are operating in sequence-pairing mode. + max_templates: The maximum templates to use per chain. + + Returns: + The chains cropped. + """ + + # Apply the cropping. + cropped_chains = [] + for chain in chains_list: + cropped_chain = _crop_single_chain( + chain, + msa_crop_size=msa_crop_size, + pair_msa_sequences=pair_msa_sequences, + max_templates=max_templates, + ) + cropped_chains.append(cropped_chain) + + return cropped_chains + + +def _crop_single_chain(chain: FeatureDict, msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int) -> FeatureDict: + """Crops msa sequences to `msa_crop_size`.""" + msa_size = chain['num_alignments'] + + if pair_msa_sequences: + msa_size_all_seq = chain['num_alignments_all_seq'] + msa_crop_size_all_seq = np.minimum(msa_size_all_seq, + msa_crop_size // 2) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chain's MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] + num_non_gapped_pairs = np.sum( + np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) + num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, + msa_crop_size_all_seq) + + # Restrict the unpaired crop size so that paired+unpaired sequences do not + # exceed msa_seqs_per_chain for each chain. + max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) + msa_crop_size = np.minimum(msa_size, max_msa_crop_size) + else: + msa_crop_size = np.minimum(msa_size, msa_crop_size) + + include_templates = 'template_aatype' in chain and max_templates + if include_templates: + num_templates = chain['template_aatype'].shape[0] + templates_crop_size = np.minimum(num_templates, max_templates) + + for k in chain: + k_split = k.split('_all_seq')[0] + if k_split in msa_pairing.TEMPLATE_FEATURES: + chain[k] = chain[k][:templates_crop_size, :] + elif k_split in msa_pairing.MSA_FEATURES: + if '_all_seq' in k and pair_msa_sequences: + chain[k] = chain[k][:msa_crop_size_all_seq, :] + else: + chain[k] = chain[k][:msa_crop_size, :] + + chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) + if include_templates: + chain['num_templates'] = np.asarray( + templates_crop_size, dtype=np.int32) + if pair_msa_sequences: + chain['num_alignments_all_seq'] = np.asarray( + msa_crop_size_all_seq, dtype=np.int32) + return chain + + +def process_final(np_example: FeatureDict) -> FeatureDict: + """Final processing steps in data pipeline, after merging and pairing.""" + np_example = _make_seq_mask(np_example) + np_example = _make_msa_mask(np_example) + np_example = _filter_features(np_example) + return np_example + + +def _make_seq_mask(np_example): + np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) + return np_example + + +def _make_msa_mask(np_example): + """Mask features are all ones, but will later be zero-padded.""" + + np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8) + + seq_mask = (np_example['entity_id'] > 0).astype(np.int8) + np_example['msa_mask'] *= seq_mask[None] + + return np_example + + +def _filter_features(np_example: FeatureDict) -> FeatureDict: + """Filters features of example to only those requested.""" + return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} + + +def process_unmerged_features(all_chain_features: MutableMapping[str, + FeatureDict]): + """Postprocessing stage for per-chain features before merging.""" + num_chains = len(all_chain_features) + for chain_features in all_chain_features: + # Convert deletion matrices to float. + if 'deletion_matrix_int' in chain_features: + chain_features['deletion_matrix'] = np.asarray( + chain_features.pop('deletion_matrix_int'), dtype=np.float32) + if 'deletion_matrix_int_all_seq' in chain_features: + chain_features['deletion_matrix_all_seq'] = np.asarray( + chain_features.pop('deletion_matrix_int_all_seq'), + dtype=np.float32) + + chain_features['deletion_mean'] = np.mean( + chain_features['deletion_matrix'], axis=0) + + if 'all_atom_positions' not in chain_features: + # Add all_atom_mask and dummy all_atom_positions based on aatype. + all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ + chain_features['aatype']] + chain_features['all_atom_mask'] = all_atom_mask + chain_features['all_atom_positions'] = np.zeros( + list(all_atom_mask.shape) + [3]) + + # Add assembly_num_chains. + chain_features['assembly_num_chains'] = np.asarray(num_chains) + + # Add entity_mask. + for chain_features in all_chain_features: + chain_features['entity_mask'] = ( + chain_features['entity_id'] != # noqa W504 + 0).astype(np.int32) + + +def empty_template_feats(n_res): + return { + 'template_aatype': + np.zeros((0, n_res)).astype(np.int64), + 'template_all_atom_positions': + np.zeros((0, n_res, 37, 3)).astype(np.float32), + 'template_sum_probs': + np.zeros((0, 1)).astype(np.float32), + 'template_all_atom_mask': + np.zeros((0, n_res, 37)).astype(np.float32), + } + + +def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + if monomer_features['template_aatype'].shape[0] == 0: + monomer_features.update( + empty_template_feats(monomer_features['aatype'].shape[0])) + converted = {} + unnecessary_leading_dim_feats = { + 'sequence', + 'domain_name', + 'num_alignments', + 'seq_length', + } + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + if feature.shape[0] > 0: + feature = correct_template_restypes(feature) + elif feature_name == 'template_all_atom_masks': + feature_name = 'template_all_atom_mask' + elif feature_name == 'msa': + feature = feature.astype(np.uint8) + + if feature_name.endswith('_mask'): + feature = feature.astype(np.float32) + + converted[feature_name] = feature + + if 'deletion_matrix_int' in monomer_features: + monomer_features['deletion_matrix'] = monomer_features.pop( + 'deletion_matrix_int').astype(np.float32) + + converted.pop( + 'template_sum_probs' + ) # zy: this input is checked to be dirty in shape. TODO: figure out why and make it right. + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features(all_chain_features, ): + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_features in all_chain_features: + assert 'sequence' in chain_features + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = [] + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + num_sym = len(group_chain_features) # zy + for sym_id, chain_features in enumerate(group_chain_features, start=1): + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_features['num_sym'] = num_sym * np.ones(seq_length) + chain_id += 1 + new_all_chain_features.append(chain_features) + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', + 'msa_chains'): + np_example[feat] = np.pad(np_example[feat], + ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), )) + return np_example + + +def post_process(np_example): + np_example = pad_msa(np_example, 512) + no_dim_keys = [ + 'num_alignments', + 'assembly_num_chains', + 'num_templates', + 'seq_length', + 'resolution', + ] + for k in no_dim_keys: + if k in np_example: + np_example[k] = np_example[k].reshape(-1) + return np_example + + +def merge_msas(msa, del_mat, new_msa, new_del_mat): + cur_msa_set = set([tuple(m) for m in msa]) + new_rows = [] + for i, s in enumerate(new_msa): + if tuple(s) not in cur_msa_set: + new_rows.append(i) + ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) + ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) + return ret_msa, ret_del_mat diff --git a/modelscope/models/science/unifold/data/protein.py b/modelscope/models/science/unifold/data/protein.py new file mode 100644 index 00000000..42308d04 --- /dev/null +++ b/modelscope/models/science/unifold/data/protein.py @@ -0,0 +1,322 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional + +import numpy as np +from Bio.PDB import PDBParser + +from modelscope.models.science.unifold.data import residue_constants + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.' + ) + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get( + res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num, )) + res_b_factors = np.zeros((residue_constants.atom_type_num, )) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1.0 + res_b_factors[residue_constants.atom_order[ + atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors), + ) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + + # res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + def res_1to3(r): + return residue_constants.restype_1to3.get(restypes[r], 'UNK') + + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append('MODEL 1') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[i - 1]), + chain_ids[chain_index[i - 1]], + residue_index[i - 1], + )) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip(atom_types, + atom_positions[i], + atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[ + 0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = ( + f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], + residue_index[-1], + )) + pdb_lines.append('ENDMDL') + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + fold_output: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(result['final_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=result['final_atom_positions'], + atom_mask=result['final_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) + + +def from_feature(features: FeatureDict, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a standard pdb from input atom positions & mask. + + Args: + features: Dictionary holding model inputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(features['all_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=features['all_atom_positions'], + atom_mask=features['all_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) diff --git a/modelscope/models/science/unifold/data/residue_constants.py b/modelscope/models/science/unifold/data/residue_constants.py new file mode 100644 index 00000000..beebfe89 --- /dev/null +++ b/modelscope/models/science/unifold/data/residue_constants.py @@ -0,0 +1,1212 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import List, Mapping, Tuple + +import numpy as np +from unicore.utils import tree_map + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], + ['CG', 'CD', 'NE', 'CZ'], + ], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLU': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], + ['CG', 'CD', 'CE', 'NZ'], + ], + 'MET': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE'], + ], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': [ + 'C', + 'CA', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + 'N', + 'NE1', + 'O', + ], + 'TYR': + ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': { + 'OD1': 'OD2' + }, + 'GLU': { + 'OE1': 'OE2' + }, + 'PHE': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, + 'TYR': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple('Bond', + ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +# def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], Mapping[ #noqa +# str, List[Bond]], Mapping[str, List[BondAngle]]]: +def load_stereo_chemical_props(): + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. + """ + stereo_chemical_props_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'stereo_chemical_props.txt') + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + )) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length + * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length + - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length + - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + 'N', + 'CA', + 'C', + 'CB', + 'O', + 'CG', + 'CG1', + 'CG2', + 'OG', + 'OG1', + 'SG', + 'CD', + 'CD1', + 'CD2', + 'ND1', + 'ND2', + 'OD1', + 'OD2', + 'SD', + 'CE', + 'CE1', + 'CE2', + 'CE3', + 'NE', + 'NE1', + 'NE2', + 'OE1', + 'OE2', + 'CH2', + 'NH1', + 'NH2', + 'OH', + 'CZ', + 'CZ2', + 'CZ3', + 'NZ', + 'OXT', +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD', + 'NE', + 'CZ', + 'NH1', + 'NH2', + '', + '', + '', + ], + 'ASN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'ND1', + 'CD2', + 'CE1', + 'NE2', + '', + '', + '', + '', + ], + 'ILE': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + '', + '', + '', + ], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': + ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'NE1', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + ], + 'TYR': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + 'OH', + '', + '', + ], + 'VAL': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', + 'R', + 'N', + 'D', + 'C', + 'Q', + 'E', + 'G', + 'H', + 'I', + 'L', + 'K', + 'M', + 'F', + 'P', + 'S', + 'T', + 'W', + 'Y', + 'V', +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = { + restype: i + for i, restype in enumerate(restypes_with_x) +} + + +def sequence_to_onehot(sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + 'The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError( + f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree_map( + lambda n: atom_order[n], chi_angles_atom_indices, leaf_type=str) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices +]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, + translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int_) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int_) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, + atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, + atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions['N'], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C'], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, + 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, + atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, + atom1_idx] = b.stddev + return { + 'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } + + +def _make_atom14_and_atom37_constants(): + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in restypes: + atom_names = restype_name_to_atom14_names[restype_1to3[rt]] + restype_atom14_to_atom37.append([(atom_order[name] if name else 0) + for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types + ]) + + restype_atom14_mask.append([(1.0 if name else 0.0) + for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = np.array( + restype_atom14_to_atom37, dtype=np.int32) + restype_atom37_to_atom14 = np.array( + restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask + + +( + restype_atom14_to_atom37, + restype_atom37_to_atom14, + restype_atom14_mask, +) = _make_atom14_and_atom37_constants() + + +def _make_renaming_matrices(): + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [restype_1to3[res] for res in restypes] + restype_3 += ['UNK'] + + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14) for res in restype_3} + for resname, swap in residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = restype_name_to_atom14_names[resname].index( + source_atom_swap) + target_index = restype_name_to_atom14_names[resname].index( + target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + renaming_matrices = np.stack( + [all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +renaming_matrices = _make_renaming_matrices() + + +def _make_atom14_is_ambiguous(): + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = np.zeros((21, 14)) + for resname, swap in residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = restype_order[restype_3to1[resname]] + atom_idx1 = restype_name_to_atom14_names[resname].index(atom_name1) + atom_idx2 = restype_name_to_atom14_names[resname].index(atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + return restype_atom14_is_ambiguous + + +restype_atom14_is_ambiguous = _make_atom14_is_ambiguous() + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, + 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return chi_atom_indices + + +chi_atom_indices = get_chi_atom_indices() diff --git a/modelscope/models/science/unifold/data/stereo_chemical_props.txt b/modelscope/models/science/unifold/data/stereo_chemical_props.txt new file mode 100644 index 00000000..25262efd --- /dev/null +++ b/modelscope/models/science/unifold/data/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 +CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 +CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 +CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 2.1 +N-CA-CB THR 110.3 1.9 +CB-CA-C THR 111.6 2.7 +CA-CB-OG1 THR 109.0 2.1 +CA-CB-CG2 THR 112.4 1.4 +OG1-CB-CG2 THR 110.0 2.3 +N-CA-C THR 111.0 2.7 +CA-C-O THR 120.1 2.1 +N-CA-CB TRP 110.6 1.8 +CB-CA-C TRP 110.4 2.0 +CA-CB-CG TRP 113.7 1.9 +CB-CG-CD1 TRP 127.0 1.3 +CB-CG-CD2 TRP 126.6 1.3 +CD1-CG-CD2 TRP 106.3 0.8 +CG-CD1-NE1 TRP 110.1 1.0 +CD1-NE1-CE2 TRP 109.0 0.9 +NE1-CE2-CD2 TRP 107.3 1.0 +CE2-CD2-CG TRP 107.3 0.8 +CG-CD2-CE3 TRP 133.9 0.9 +NE1-CE2-CZ2 TRP 130.4 1.1 +CE3-CD2-CE2 TRP 118.7 1.2 +CD2-CE2-CZ2 TRP 122.3 1.2 +CE2-CZ2-CH2 TRP 117.4 1.0 +CZ2-CH2-CZ3 TRP 121.6 1.2 +CH2-CZ3-CE3 TRP 121.2 1.1 +CZ3-CE3-CD2 TRP 118.8 1.3 +N-CA-C TRP 111.0 2.7 +CA-C-O TRP 120.1 2.1 +N-CA-CB TYR 110.6 1.8 +CB-CA-C TYR 110.4 2.0 +CA-CB-CG TYR 113.4 1.9 +CB-CG-CD1 TYR 121.0 0.6 +CB-CG-CD2 TYR 121.0 0.6 +CD1-CG-CD2 TYR 117.9 1.1 +CG-CD1-CE1 TYR 121.3 0.8 +CG-CD2-CE2 TYR 121.3 0.8 +CD1-CE1-CZ TYR 119.8 0.9 +CD2-CE2-CZ TYR 119.8 0.9 +CE1-CZ-CE2 TYR 119.8 1.6 +CE1-CZ-OH TYR 120.1 2.7 +CE2-CZ-OH TYR 120.1 2.7 +N-CA-C TYR 111.0 2.7 +CA-C-O TYR 120.1 2.1 +N-CA-CB VAL 111.5 2.2 +CB-CA-C VAL 111.4 1.9 +CA-CB-CG1 VAL 110.9 1.5 +CA-CB-CG2 VAL 110.9 1.5 +CG1-CB-CG2 VAL 110.9 1.6 +N-CA-C VAL 111.0 2.7 +CA-C-O VAL 120.1 2.1 +- + +Non-bonded distance Minimum Dist Tolerance +C-C 3.4 1.5 +C-N 3.25 1.5 +C-S 3.5 1.5 +C-O 3.22 1.5 +N-N 3.1 1.5 +N-S 3.35 1.5 +N-O 3.07 1.5 +O-S 3.32 1.5 +O-O 3.04 1.5 +S-S 2.03 1.0 +- diff --git a/modelscope/models/science/unifold/data/utils.py b/modelscope/models/science/unifold/data/utils.py new file mode 100644 index 00000000..2be91ef0 --- /dev/null +++ b/modelscope/models/science/unifold/data/utils.py @@ -0,0 +1,161 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy as copy_lib +import functools +import gzip +import pickle +from typing import Any, Dict + +import json +import numpy as np +from scipy import sparse as sp + +from . import residue_constants as rc +from .data_ops import NumpyDict + +# from typing import * + + +def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): + if deepcopy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.deepcopy(cached_func(*args, **kwargs)) + + return wrapper + + elif copy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.copy(cached_func(*args, **kwargs)) + + return wrapper + + else: + decorator = functools.lru_cache(maxsize, typed) + return decorator + + +@lru_cache(maxsize=8, deepcopy=True) +def load_pickle_safe(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' + open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +@lru_cache(maxsize=8, copy=True) +def load_pickle(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' + open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +def correct_template_restypes(feature): + """Correct template restype to have the same order as residue_constants.""" + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + return feature + + +def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict: + feature['msa'] = feature['msa'].astype(np.uint8) + if 'num_alignments' in feature: + feature.pop('num_alignments') + # make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k + + def make_all_seq_key(k): + if not k.endswith('_all_seq'): + return f'{k}_all_seq' + return k + + return {make_all_seq_key(k): v for k, v in feature.items()} + + +def to_dense_matrix(spmat_dict: NumpyDict): + spmat = sp.coo_matrix( + (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])), + shape=spmat_dict['shape'], + dtype=np.float32, + ) + return spmat.toarray() + + +FEATS_DTYPE = {'msa': np.int32} + + +def uncompress_features(feats: NumpyDict) -> NumpyDict: + if 'sparse_deletion_matrix_int' in feats: + v = feats.pop('sparse_deletion_matrix_int') + v = to_dense_matrix(v) + feats['deletion_matrix'] = v + return feats + + +def filter(feature: NumpyDict, **kwargs) -> NumpyDict: + assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}' + if 'desired_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k in kwargs['desired_keys'] + } + elif 'required_keys' in kwargs: + for k in kwargs['required_keys']: + assert k in feature, f'cannot find required key {k}.' + elif 'ignored_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k not in kwargs['ignored_keys'] + } + else: + raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}') + return feature + + +def compress_features(features: NumpyDict): + change_dtype = { + 'msa': np.uint8, + } + sparse_keys = ['deletion_matrix_int'] + + compressed_features = {} + for k, v in features.items(): + if k in change_dtype: + v = v.astype(change_dtype[k]) + if k in sparse_keys: + v = sp.coo_matrix(v, dtype=v.dtype) + sp_v = { + 'shape': v.shape, + 'row': v.row, + 'col': v.col, + 'data': v.data + } + k = f'sparse_{k}' + v = sp_v + compressed_features[k] = v + return compressed_features diff --git a/modelscope/models/science/unifold/dataset.py b/modelscope/models/science/unifold/dataset.py new file mode 100644 index 00000000..29e1a8b0 --- /dev/null +++ b/modelscope/models/science/unifold/dataset.py @@ -0,0 +1,517 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy +import logging +import os +# from typing import * +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import json +import ml_collections as mlc +import numpy as np +import torch +from unicore.data import UnicoreDataset, data_utils +from unicore.distributed import utils as distributed_utils + +from .data import utils +from .data.data_ops import NumpyDict, TorchDict +from .data.process import process_features, process_labels +from .data.process_multimer import (add_assembly_features, + convert_monomer_features, merge_msas, + pair_and_merge, post_process) + +Rotation = Iterable[Iterable] +Translation = Iterable +Operation = Union[str, Tuple[Rotation, Translation]] +NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]] +TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]] + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def make_data_config( + config: mlc.ConfigDict, + mode: str, + num_res: int, +) -> Tuple[mlc.ConfigDict, List[str]]: + cfg = copy.deepcopy(config) + mode_cfg = cfg[mode] + with cfg.unlocked(): + if mode_cfg.crop_size is None: + mode_cfg.crop_size = num_res + feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features + if cfg.common.use_templates: + feature_names += cfg.common.template_features + if cfg.common.is_multimer: + feature_names += cfg.common.multimer_features + if cfg[mode].supervised: + feature_names += cfg.supervised.supervised_features + + return cfg, feature_names + + +def process_label(all_atom_positions: np.ndarray, + operation: Operation) -> np.ndarray: + if operation == 'I': + return all_atom_positions + rot, trans = operation + rot = np.array(rot).reshape(3, 3) + trans = np.array(trans).reshape(3) + return all_atom_positions @ rot.T + trans + + +@utils.lru_cache(maxsize=8, copy=True) +def load_single_feature( + sequence_id: str, + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + is_monomer: bool = False, +) -> NumpyDict: + + monomer_feature = utils.load_pickle( + os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz')) + monomer_feature = convert_monomer_features(monomer_feature) + chain_feature = {**monomer_feature} + + if uniprot_msa_dir is not None: + all_seq_feature = utils.load_pickle( + os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz')) + if is_monomer: + chain_feature['msa'], chain_feature[ + 'deletion_matrix'] = merge_msas( + chain_feature['msa'], + chain_feature['deletion_matrix'], + all_seq_feature['msa'], + all_seq_feature['deletion_matrix'], + ) # noqa + else: + all_seq_feature = utils.convert_all_seq_feature(all_seq_feature) + for key in [ + 'msa_all_seq', + 'msa_species_identifiers_all_seq', + 'deletion_matrix_all_seq', + ]: + chain_feature[key] = all_seq_feature[key] + + return chain_feature + + +def load_single_label( + label_id: str, + label_dir: str, + symmetry_operation: Optional[Operation] = None, +) -> NumpyDict: + label = utils.load_pickle( + os.path.join(label_dir, f'{label_id}.label.pkl.gz')) + if symmetry_operation is not None: + label['all_atom_positions'] = process_label( + label['all_atom_positions'], symmetry_operation) + label = { + k: v + for k, v in label.items() if k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } + return label + + +def load( + sequence_ids: List[str], + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + label_ids: Optional[List[str]] = None, + label_dir: Optional[str] = None, + symmetry_operations: Optional[List[Operation]] = None, + is_monomer: bool = False, +) -> NumpyExample: + + all_chain_features = [ + load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, + is_monomer) for s in sequence_ids + ] + + if label_ids is not None: + # load labels + assert len(label_ids) == len(sequence_ids) + assert label_dir is not None + if symmetry_operations is None: + symmetry_operations = ['I' for _ in label_ids] + all_chain_labels = [ + load_single_label(ll, label_dir, o) + for ll, o in zip(label_ids, symmetry_operations) + ] + # update labels into features to calculate spatial cropping etc. + [f.update(ll) for f, ll in zip(all_chain_features, all_chain_labels)] + + all_chain_features = add_assembly_features(all_chain_features) + + # get labels back from features, as add_assembly_features may alter the order of inputs. + if label_ids is not None: + all_chain_labels = [{ + k: f[k] + for k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } for f in all_chain_features] + else: + all_chain_labels = None + + asym_len = np.array([c['seq_length'] for c in all_chain_features], + dtype=np.int64) + if is_monomer: + all_chain_features = all_chain_features[0] + else: + all_chain_features = pair_and_merge(all_chain_features) + all_chain_features = post_process(all_chain_features) + all_chain_features['asym_len'] = asym_len + + return all_chain_features, all_chain_labels + + +def process( + config: mlc.ConfigDict, + mode: str, + features: NumpyDict, + labels: Optional[List[NumpyDict]] = None, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, +) -> TorchExample: + + if mode == 'train': + assert batch_idx is not None + with data_utils.numpy_seed(seed, batch_idx, key='recycling'): + num_iters = np.random.randint( + 0, config.common.max_recycling_iters + 1) + use_clamped_fape = np.random.rand( + ) < config[mode].use_clamped_fape_prob + else: + num_iters = config.common.max_recycling_iters + use_clamped_fape = 1 + + features['num_recycling_iters'] = int(num_iters) + features['use_clamped_fape'] = int(use_clamped_fape) + features['is_distillation'] = int(is_distillation) + if is_distillation and 'msa_chains' in features: + features.pop('msa_chains') + + num_res = int(features['seq_length']) + cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) + + if labels is not None: + features['resolution'] = labels[0]['resolution'].reshape(-1) + + with data_utils.numpy_seed(seed, data_idx, key='protein_feature'): + features['crop_and_fix_size_seed'] = np.random.randint(0, 63355) + features = utils.filter(features, desired_keys=feature_names) + features = {k: torch.tensor(v) for k, v in features.items()} + with torch.no_grad(): + features = process_features(features, cfg.common, cfg[mode]) + + if labels is not None: + labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels] + with torch.no_grad(): + labels = process_labels(labels) + + return features, labels + + +def load_and_process( + config: mlc.ConfigDict, + mode: str, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, + **load_kwargs, +): + is_monomer = ( + is_distillation + if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer')) + features, labels = load(**load_kwargs, is_monomer=is_monomer) + features, labels = process(config, mode, features, labels, seed, batch_idx, + data_idx, is_distillation) + return features, labels + + +class UnifoldDataset(UnicoreDataset): + + def __init__( + self, + args, + seed, + config, + data_path, + mode='train', + max_step=None, + disable_sd=False, + json_prefix='', + ): + self.path = data_path + + def load_json(filename): + return json.load(open(filename, 'r')) + + sample_weight = load_json( + os.path.join(self.path, + json_prefix + mode + '_sample_weight.json')) + self.multi_label = load_json( + os.path.join(self.path, json_prefix + mode + '_multi_label.json')) + self.inverse_multi_label = self._inverse_map(self.multi_label) + self.sample_weight = {} + for chain in self.inverse_multi_label: + entity = self.inverse_multi_label[chain] + self.sample_weight[chain] = sample_weight[entity] + self.seq_sample_weight = sample_weight + logger.info('load {} chains (unique {} sequences)'.format( + len(self.sample_weight), len(self.seq_sample_weight))) + self.feature_path = os.path.join(self.path, 'pdb_features') + self.label_path = os.path.join(self.path, 'pdb_labels') + sd_sample_weight_path = os.path.join( + self.path, json_prefix + 'sd_train_sample_weight.json') + if mode == 'train' and os.path.isfile( + sd_sample_weight_path) and not disable_sd: + self.sd_sample_weight = load_json(sd_sample_weight_path) + logger.info('load {} self-distillation samples.'.format( + len(self.sd_sample_weight))) + self.sd_feature_path = os.path.join(self.path, 'sd_features') + self.sd_label_path = os.path.join(self.path, 'sd_labels') + else: + self.sd_sample_weight = None + self.batch_size = ( + args.batch_size * distributed_utils.get_data_parallel_world_size() + * args.update_freq[0]) + self.data_len = ( + max_step * self.batch_size + if max_step is not None else len(self.sample_weight)) + self.mode = mode + self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight( + self.seq_sample_weight) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + if self.sd_sample_weight is not None: + ( + self.sd_num_chain, + self.sd_chain_keys, + self.sd_sample_prob, + ) = self.cal_sample_weight(self.sd_sample_weight) + self.config = config.data + self.seed = seed + self.sd_prob = args.sd_prob + + def cal_sample_weight(self, sample_weight): + prot_keys = list(sample_weight.keys()) + sum_weight = sum(sample_weight.values()) + sample_prob = [sample_weight[k] / sum_weight for k in prot_keys] + num_prot = len(prot_keys) + return num_prot, prot_keys, sample_prob + + def sample_chain(self, idx, sample_by_seq=False): + is_distillation = False + if self.mode == 'train': + with data_utils.numpy_seed(self.seed, idx, key='data_sample'): + is_distillation = ((np.random.rand(1)[0] < self.sd_prob) + if self.sd_sample_weight is not None else + False) + if is_distillation: + prot_idx = np.random.choice( + self.sd_num_chain, p=self.sd_sample_prob) + label_name = self.sd_chain_keys[prot_idx] + seq_name = label_name + else: + if not sample_by_seq: + prot_idx = np.random.choice( + self.num_chain, p=self.sample_prob) + label_name = self.chain_keys[prot_idx] + seq_name = self.inverse_multi_label[label_name] + else: + seq_idx = np.random.choice( + self.num_seq, p=self.seq_sample_prob) + seq_name = self.seq_keys[seq_idx] + label_name = np.random.choice( + self.multi_label[seq_name]) + else: + label_name = self.chain_keys[idx] + seq_name = self.inverse_multi_label[label_name] + return seq_name, label_name, is_distillation + + def __getitem__(self, idx): + sequence_id, label_id, is_distillation = self.sample_chain( + idx, sample_by_seq=True) + feature_dir, label_dir = ((self.feature_path, + self.label_path) if not is_distillation else + (self.sd_feature_path, self.sd_label_path)) + features, _ = load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=[sequence_id], + monomer_feature_dir=feature_dir, + uniprot_msa_dir=None, + label_ids=[label_id], + label_dir=label_dir, + symmetry_operations=None, + is_monomer=True, + ) + return features + + def __len__(self): + return self.data_len + + @staticmethod + def collater(samples): + # first dim is recyling. bsz is at the 2nd dim + return data_utils.collate_dict(samples, dim=1) + + @staticmethod + def _inverse_map(mapping: Dict[str, List[str]]): + inverse_mapping = {} + for ent, refs in mapping.items(): + for ref in refs: + if ref in inverse_mapping: # duplicated ent for this ref. + ent_2 = inverse_mapping[ref] + assert ( + ent == ent_2 + ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.' + inverse_mapping[ref] = ent + return inverse_mapping + + +class UnifoldMultimerDataset(UnifoldDataset): + + def __init__( + self, + args: mlc.ConfigDict, + seed: int, + config: mlc.ConfigDict, + data_path: str, + mode: str = 'train', + max_step: Optional[int] = None, + disable_sd: bool = False, + json_prefix: str = '', + **kwargs, + ): + super().__init__(args, seed, config, data_path, mode, max_step, + disable_sd, json_prefix) + self.data_path = data_path + self.pdb_assembly = json.load( + open( + os.path.join(self.data_path, + json_prefix + 'pdb_assembly.json'))) + self.pdb_chains = self.get_chains(self.inverse_multi_label) + self.monomer_feature_path = os.path.join(self.data_path, + 'pdb_features') + self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots') + self.label_path = os.path.join(self.data_path, 'pdb_labels') + self.max_chains = args.max_chains + if self.mode == 'train': + self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains( + self.pdb_chains, self.pdb_assembly, self.sample_weight, + self.max_chains) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + + def __getitem__(self, idx): + seq_id, label_id, is_distillation = self.sample_chain(idx) + if is_distillation: + label_ids = [label_id] + sequence_ids = [seq_id] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.sd_feature_path, + None, + self.sd_label_path, + ) + symmetry_operations = None + else: + pdb_id = self.get_pdb_name(label_id) + if pdb_id in self.pdb_assembly and self.mode == 'train': + label_ids = [ + pdb_id + '_' + id + for id in self.pdb_assembly[pdb_id]['chains'] + ] + symmetry_operations = [ + t for t in self.pdb_assembly[pdb_id]['opers'] + ] + else: + label_ids = self.pdb_chains[pdb_id] + symmetry_operations = None + sequence_ids = [ + self.inverse_multi_label[chain_id] for chain_id in label_ids + ] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.monomer_feature_path, + self.uniprot_msa_path, + self.label_path, + ) + + return load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=sequence_ids, + monomer_feature_dir=monomer_feature_path, + uniprot_msa_dir=uniprot_msa_path, + label_ids=label_ids, + label_dir=label_path, + symmetry_operations=symmetry_operations, + is_monomer=False, + ) + + @staticmethod + def collater(samples): + # first dim is recyling. bsz is at the 2nd dim + if len(samples) <= 0: # tackle empty batch + return None + feats = [s[0] for s in samples] + labs = [s[1] for s in samples if s[1] is not None] + try: + feats = data_utils.collate_dict(feats, dim=1) + except BaseException: + raise ValueError('cannot collate features', feats) + if not labs: + labs = None + return feats, labs + + @staticmethod + def get_pdb_name(chain): + return chain.split('_')[0] + + @staticmethod + def get_chains(canon_chain_map): + pdb_chains = {} + for chain in canon_chain_map: + pdb = UnifoldMultimerDataset.get_pdb_name(chain) + if pdb not in pdb_chains: + pdb_chains[pdb] = [] + pdb_chains[pdb].append(chain) + return pdb_chains + + @staticmethod + def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, + max_chains): + new_pdb_chains = {} + for chain in pdb_chains: + if chain in pdb_assembly: + size = len(pdb_assembly[chain]['chains']) + if size <= max_chains: + new_pdb_chains[chain] = pdb_chains[chain] + else: + size = len(pdb_chains[chain]) + if size == 1: + new_pdb_chains[chain] = pdb_chains[chain] + new_sample_weight = { + k: sample_weight[k] + for k in sample_weight + if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains + } + logger.info( + f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs ' + f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) ' + f'by max_chains {max_chains}') + return new_pdb_chains, new_sample_weight diff --git a/modelscope/models/science/unifold/model.py b/modelscope/models/science/unifold/model.py new file mode 100644 index 00000000..7f28f18d --- /dev/null +++ b/modelscope/models/science/unifold/model.py @@ -0,0 +1,78 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import argparse +import os +from typing import Any + +import torch + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .config import model_config +from .modules.alphafold import AlphaFold + +__all__ = ['UnifoldForProteinStructrue'] + + +@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold) +class UnifoldForProteinStructrue(TorchModel): + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + '--model-name', + help='choose the model config', + ) + + def __init__(self, **kwargs): + super().__init__() + parser = argparse.ArgumentParser() + parse_comm = [] + for key in kwargs: + parser.add_argument(f'--{key}') + parse_comm.append(f'--{key}') + parse_comm.append(kwargs[key]) + args = parser.parse_args(parse_comm) + base_architecture(args) + self.args = args + config = model_config( + self.args.model_name, + train=True, + ) + self.model = AlphaFold(config) + self.config = config + + # load model state dict + param_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(param_path)['ema']['params'] + state_dict = { + '.'.join(k.split('.')[1:]): v + for k, v in state_dict.items() + } + self.model.load_state_dict(state_dict) + + def half(self): + self.model = self.model.half() + return self + + def bfloat16(self): + self.model = self.model.bfloat16() + return self + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args) + + def forward(self, batch, **kwargs): + outputs = self.model.forward(batch) + return outputs, self.config.loss + + +def base_architecture(args): + args.model_name = getattr(args, 'model_name', 'model_2') diff --git a/modelscope/models/science/unifold/modules/__init__.py b/modelscope/models/science/unifold/modules/__init__.py new file mode 100644 index 00000000..63aa84ed --- /dev/null +++ b/modelscope/models/science/unifold/modules/__init__.py @@ -0,0 +1,3 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. +"""Unifold Modules.""" diff --git a/modelscope/models/science/unifold/modules/alphafold.py b/modelscope/models/science/unifold/modules/alphafold.py new file mode 100644 index 00000000..71a1b310 --- /dev/null +++ b/modelscope/models/science/unifold/modules/alphafold.py @@ -0,0 +1,450 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import torch +import torch.nn as nn +from unicore.utils import tensor_tree_map + +from ..data import residue_constants +from .attentions import gen_msa_attn_mask, gen_tri_attn_mask +from .auxillary_heads import AuxiliaryHeads +from .common import residual +from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder, + TemplateAngleEmbedder, TemplatePairEmbedder) +from .evoformer import EvoformerStack, ExtraMSAStack +from .featurization import (atom14_to_atom37, build_extra_msa_feat, + build_template_angle_feat, + build_template_pair_feat, + build_template_pair_feat_v2, pseudo_beta_fn) +from .structure_module import StructureModule +from .template import (TemplatePairStack, TemplatePointwiseAttention, + TemplateProjection) + + +class AlphaFold(nn.Module): + + def __init__(self, config): + super(AlphaFold, self).__init__() + + self.globals = config.globals + config = config.model + template_config = config.template + extra_msa_config = config.extra_msa + + self.input_embedder = InputEmbedder( + **config['input_embedder'], + use_chain_relative=config.is_multimer, + ) + self.recycling_embedder = RecyclingEmbedder( + **config['recycling_embedder'], ) + if config.template.enabled: + self.template_angle_embedder = TemplateAngleEmbedder( + **template_config['template_angle_embedder'], ) + self.template_pair_embedder = TemplatePairEmbedder( + **template_config['template_pair_embedder'], ) + self.template_pair_stack = TemplatePairStack( + **template_config['template_pair_stack'], ) + else: + self.template_pair_stack = None + self.enable_template_pointwise_attention = template_config[ + 'template_pointwise_attention'].enabled + if self.enable_template_pointwise_attention: + self.template_pointwise_att = TemplatePointwiseAttention( + **template_config['template_pointwise_attention'], ) + else: + self.template_proj = TemplateProjection( + **template_config['template_pointwise_attention'], ) + self.extra_msa_embedder = ExtraMSAEmbedder( + **extra_msa_config['extra_msa_embedder'], ) + self.extra_msa_stack = ExtraMSAStack( + **extra_msa_config['extra_msa_stack'], ) + self.evoformer = EvoformerStack(**config['evoformer_stack'], ) + self.structure_module = StructureModule(**config['structure_module'], ) + + self.aux_heads = AuxiliaryHeads(config['heads'], ) + + self.config = config + self.dtype = torch.float + self.inf = self.globals.inf + if self.globals.alphafold_original_mode: + self.alphafold_original_mode() + + def __make_input_float__(self): + self.input_embedder = self.input_embedder.float() + self.recycling_embedder = self.recycling_embedder.float() + + def half(self): + super().half() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.half + return self + + def bfloat16(self): + super().bfloat16() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.bfloat16 + return self + + def alphafold_original_mode(self): + + def set_alphafold_original_mode(module): + if hasattr(module, 'apply_alphafold_original_mode'): + module.apply_alphafold_original_mode() + if hasattr(module, 'act'): + module.act = nn.ReLU() + + self.apply(set_alphafold_original_mode) + + def inference_mode(self): + + def set_inference_mode(module): + setattr(module, 'inference', True) + + self.apply(set_inference_mode) + + def __convert_input_dtype__(self, batch): + for key in batch: + # only convert features with mask + if batch[key].dtype != self.dtype and 'mask' in key: + batch[key] = batch[key].type(self.dtype) + return batch + + def embed_templates_pair_core(self, batch, z, pair_mask, + tri_start_attn_mask, tri_end_attn_mask, + templ_dim, multichain_mask_2d): + if self.config.template.template_pair_embedder.v2_feature: + t = build_template_pair_feat_v2( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + multichain_mask_2d=multichain_mask_2d, + **self.config.template.distogram, + ) + num_template = t[0].shape[-4] + single_templates = [ + self.template_pair_embedder([x[..., ti, :, :, :] + for x in t], z) + for ti in range(num_template) + ] + else: + t = build_template_pair_feat( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + **self.config.template.distogram, + ) + single_templates = [ + self.template_pair_embedder(x, z) + for x in torch.unbind(t, dim=templ_dim) + ] + + t = self.template_pair_stack( + single_templates, + pair_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + templ_dim=templ_dim, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + return_mean=not self.enable_template_pointwise_attention, + ) + return t + + def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim): + if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch: + multichain_mask_2d = ( + batch['asym_id'][..., :, None] == batch['asym_id'][..., + None, :]) + multichain_mask_2d = multichain_mask_2d.unsqueeze(0) + else: + multichain_mask_2d = None + + if self.training or self.enable_template_pointwise_attention: + t = self.embed_templates_pair_core(batch, z, pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, templ_dim, + multichain_mask_2d) + if self.enable_template_pointwise_attention: + t = self.template_pointwise_att( + t, + z, + template_mask=batch['template_mask'], + chunk_size=self.globals.chunk_size, + ) + t_mask = torch.sum( + batch['template_mask'], dim=-1, keepdims=True) > 0 + t_mask = t_mask[..., None, None].type(t.dtype) + t *= t_mask + else: + t = self.template_proj(t, z) + else: + template_aatype_shape = batch['template_aatype'].shape + # template_aatype is either [n_template, n_res] or [1, n_template_, n_res] + batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0 + n_templ = batch['template_aatype'].shape[batch_templ_dim] + + if n_templ <= 0: + t = None + else: + template_batch = { + k: v + for k, v in batch.items() if k.startswith('template_') + } + + def embed_one_template(i): + + def slice_template_tensor(t): + s = [slice(None) for _ in t.shape] + s[batch_templ_dim] = slice(i, i + 1) + return t[s] + + template_feats = tensor_tree_map( + slice_template_tensor, + template_batch, + ) + t = self.embed_templates_pair_core( + template_feats, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim, multichain_mask_2d) + return t + + t = embed_one_template(0) + # iterate templates one by one + for i in range(1, n_templ): + t += embed_one_template(i) + t /= n_templ + t = self.template_proj(t, z) + return t + + def embed_templates_angle(self, batch): + template_angle_feat, template_angle_mask = build_template_angle_feat( + batch, + v2_feature=self.config.template.template_pair_embedder.v2_feature) + t = self.template_angle_embedder(template_angle_feat) + return t, template_angle_mask + + def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): + batch_dims = feats['target_feat'].shape[:-2] + n = feats['target_feat'].shape[-2] + seq_mask = feats['seq_mask'] + pair_mask = seq_mask[..., None] * seq_mask[..., None, :] + msa_mask = feats['msa_mask'] + + m, z = self.input_embedder( + feats['target_feat'], + feats['msa_feat'], + ) + + if m_1_prev is None: + m_1_prev = m.new_zeros( + (*batch_dims, n, self.config.input_embedder.d_msa), + requires_grad=False, + ) + if z_prev is None: + z_prev = z.new_zeros( + (*batch_dims, n, n, self.config.input_embedder.d_pair), + requires_grad=False, + ) + if x_prev is None: + x_prev = z.new_zeros( + (*batch_dims, n, residue_constants.atom_type_num, 3), + requires_grad=False, + ) + x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None) + + z += self.recycling_embedder.recyle_pos(x_prev) + + m_1_prev_emb, z_prev_emb = self.recycling_embedder( + m_1_prev, + z_prev, + ) + + m[..., 0, :, :] += m_1_prev_emb + + z += z_prev_emb + + z += self.input_embedder.relpos_emb( + feats['residue_index'].long(), + feats.get('sym_id', None), + feats.get('asym_id', None), + feats.get('entity_id', None), + feats.get('num_sym', None), + ) + + m = m.type(self.dtype) + z = z.type(self.dtype) + tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask( + pair_mask, self.inf) + + if self.config.template.enabled: + template_mask = feats['template_mask'] + if torch.any(template_mask): + z = residual( + z, + self.embed_templates_pair( + feats, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim=-4, + ), + self.training, + ) + + if self.config.extra_msa.enabled: + a = self.extra_msa_embedder(build_extra_msa_feat(feats)) + extra_msa_row_mask = gen_msa_attn_mask( + feats['extra_msa_mask'], + inf=self.inf, + gen_col_mask=False, + ) + z = self.extra_msa_stack( + a, + z, + msa_mask=feats['extra_msa_mask'], + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + pair_mask=pair_mask, + msa_row_attn_mask=extra_msa_row_mask, + msa_col_attn_mask=None, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + ) + + if self.config.template.embed_angles: + template_1d_feat, template_1d_mask = self.embed_templates_angle( + feats) + m = torch.cat([m, template_1d_feat], dim=-3) + msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2) + + msa_row_mask, msa_col_mask = gen_msa_attn_mask( + msa_mask, + inf=self.inf, + ) + + m, z, s = self.evoformer( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_mask, + msa_col_attn_mask=msa_col_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + ) + return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb + + def iteration_evoformer_structure_module(self, + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no, + num_recycling, + num_ensembles=1): + z, s = 0, 0 + n_seq = batch['msa_feat'].shape[-3] + assert num_ensembles >= 1 + for ensemble_no in range(num_ensembles): + idx = cycle_no * num_ensembles + ensemble_no + + # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] + def fetch_cur_batch(t): + return t[min(t.shape[0] - 1, idx), ...] + + feats = tensor_tree_map(fetch_cur_batch, batch) + m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( + feats, m_1_prev, z_prev, x_prev) + z += z0 + s += s0 + del z0, s0 + if num_ensembles > 1: + z /= float(num_ensembles) + s /= float(num_ensembles) + + outputs = {} + + outputs['msa'] = m[..., :n_seq, :, :] + outputs['pair'] = z + outputs['single'] = s + + # norm loss + if (not getattr(self, 'inference', + False)) and num_recycling == (cycle_no + 1): + delta_msa = m + delta_msa[..., + 0, :, :] = delta_msa[..., + 0, :, :] - m_1_prev_emb.detach() + delta_pair = z - z_prev_emb.detach() + outputs['delta_msa'] = delta_msa + outputs['delta_pair'] = delta_pair + outputs['msa_norm_mask'] = msa_mask + + outputs['sm'] = self.structure_module( + s, + z, + feats['aatype'], + mask=feats['seq_mask'], + ) + outputs['final_atom_positions'] = atom14_to_atom37( + outputs['sm']['positions'], feats) + outputs['final_atom_mask'] = feats['atom37_atom_exists'] + outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1] + + # use float32 for numerical stability + if (not getattr(self, 'inference', False)): + m_1_prev = m[..., 0, :, :].float() + z_prev = z.float() + x_prev = outputs['final_atom_positions'].float() + else: + m_1_prev = m[..., 0, :, :] + z_prev = z + x_prev = outputs['final_atom_positions'] + + return outputs, m_1_prev, z_prev, x_prev + + def forward(self, batch): + + m_1_prev = batch.get('m_1_prev', None) + z_prev = batch.get('z_prev', None) + x_prev = batch.get('x_prev', None) + + is_grad_enabled = torch.is_grad_enabled() + + num_iters = int(batch['num_recycling_iters']) + 1 + num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters + if self.training: + # don't use ensemble during training + assert num_ensembles == 1 + + # convert dtypes in batch + batch = self.__convert_input_dtype__(batch) + for cycle_no in range(num_iters): + is_final_iter = cycle_no == (num_iters - 1) + with torch.set_grad_enabled(is_grad_enabled and is_final_iter): + ( + outputs, + m_1_prev, + z_prev, + x_prev, + ) = self.iteration_evoformer_structure_module( + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no=cycle_no, + num_recycling=num_iters, + num_ensembles=num_ensembles, + ) + if not is_final_iter: + del outputs + + if 'asym_id' in batch: + outputs['asym_id'] = batch['asym_id'][0, ...] + outputs.update(self.aux_heads(outputs)) + return outputs diff --git a/modelscope/models/science/unifold/modules/attentions.py b/modelscope/models/science/unifold/modules/attentions.py new file mode 100644 index 00000000..d2319079 --- /dev/null +++ b/modelscope/models/science/unifold/modules/attentions.py @@ -0,0 +1,430 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import permute_final_dims + +from .common import Linear, chunk_layer + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +class Attention(nn.Module): + + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + head_dim: int, + num_heads: int, + gating: bool = True, + ): + super(Attention, self).__init__() + + self.num_heads = num_heads + total_dim = head_dim * self.num_heads + self.gating = gating + self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot') + self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot') + self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot') + self.linear_o = Linear(total_dim, q_dim, init='final') + self.linear_g = None + if self.gating: + self.linear_g = Linear(q_dim, total_dim, init='gating') + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + g = None + if self.linear_g is not None: + # gating, use raw query input + g = self.linear_g(q) + + q = self.linear_q(q) + q *= self.norm + k = self.linear_k(k) + v = self.linear_v(v) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) + o = torch.matmul(attn, v) + del attn, v + + o = o.transpose(-2, -3).contiguous() + o = o.view(*o.shape[:-2], -1) + + if g is not None: + o = torch.sigmoid(g) * o + + # merge heads + o = nn.functional.linear(o, self.linear_o.weight) + return o + + def get_output_bias(self): + return self.linear_o.bias + + +class GlobalAttention(nn.Module): + + def __init__(self, input_dim, head_dim, num_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.num_heads = num_heads + self.inf = inf + self.eps = eps + self.linear_q = Linear( + input_dim, head_dim * num_heads, bias=False, init='glorot') + self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating') + self.linear_o = Linear(head_dim * num_heads, input_dim, init='final') + self.sigmoid = nn.Sigmoid() + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + + # gating + g = self.sigmoid(self.linear_g(x)) + + k = self.linear_k(x) + v = self.linear_v(x) + + q = torch.sum( + x * mask.unsqueeze(-1), dim=-2) / ( + torch.sum(mask, dim=-1, keepdims=True) + self.eps) + q = self.linear_q(q) + q *= self.norm + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :] + attn = softmax_dropout(attn, 0, self.training, mask=attn_mask) + + o = torch.matmul( + attn, + v, + ) + del attn, v + + g = g.view(g.shape[:-1] + (self.num_heads, -1)) + o = o.unsqueeze(-3) * g + del g + + # merge heads + o = o.reshape(o.shape[:-2] + (-1, )) + return self.linear_o(o) + + +def gen_msa_attn_mask(mask, inf, gen_col_mask=True): + row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + if gen_col_mask: + col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return row_mask, col_mask + else: + return row_mask + + +class MSAAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + pair_bias=False, + d_pair=None, + ): + super(MSAAttention, self).__init__() + + self.pair_bias = pair_bias + self.layer_norm_m = LayerNorm(d_in) + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(d_pair) + self.linear_z = Linear( + d_pair, num_heads, bias=False, init='normal') + + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + @torch.jit.ignore + def _attn_chunk_forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = 2560, + ) -> torch.Tensor: + m = self.layer_norm_m(m) + num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size + outputs = [] + for i in range(num_chunk): + chunk_start = i * chunk_size + chunk_end = min(m.shape[-3], chunk_start + chunk_size) + cur_m = m[..., chunk_start:chunk_end, :, :] + cur_mask = ( + mask[..., chunk_start:chunk_end, :, :, :] + if mask is not None else None) + outputs.append( + self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) + return torch.concat(outputs, dim=-3) + + def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): + m = self.layer_norm_m(m) + return self.mha(q=m, k=m, v=m, mask=mask, bias=bias) + + def forward( + self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + bias = None + if self.pair_bias: + z = self.layer_norm_z(z) + bias = ( + permute_final_dims(self.linear_z(z), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + m = self._chunk(m, attn_mask, bias, chunk_size) + else: + attn_chunk_size = 2560 + if m.shape[-3] <= attn_chunk_size: + m = self._attn_forward(m, attn_mask, bias) + else: + # reduce the peak memory cost in extra_msa_stack + return self._attn_chunk_forward( + m, attn_mask, bias, chunk_size=attn_chunk_size) + + return m + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class MSARowAttentionWithPairBias(MSAAttention): + + def __init__(self, d_msa, d_pair, d_hid, num_heads): + super(MSARowAttentionWithPairBias, self).__init__( + d_msa, + d_hid, + num_heads, + pair_bias=True, + d_pair=d_pair, + ) + + +class MSAColumnAttention(MSAAttention): + + def __init__(self, d_msa, d_hid, num_heads): + super(MSAColumnAttention, self).__init__( + d_in=d_msa, + d_hid=d_hid, + num_heads=num_heads, + pair_bias=False, + d_pair=None, + ) + + def forward( + self, + m: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + m = m.transpose(-2, -3) + m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size) + m = m.transpose(-2, -3) + + return m + + +class MSAColumnGlobalAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + inf=1e9, + eps=1e-10, + ): + super(MSAColumnGlobalAttention, self).__init__() + + self.layer_norm_m = LayerNorm(d_in) + self.global_attention = GlobalAttention( + d_in, + d_hid, + num_heads, + inf=inf, + eps=eps, + ) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + def _attn_forward(self, m, mask): + m = self.layer_norm_m(m) + return self.global_attention(m, mask=mask) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = m.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self._attn_forward(m, mask=mask) + + m = m.transpose(-2, -3) + return m + + +def gen_tri_attn_mask(mask, inf): + start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return start_mask, end_mask + + +class TriangleAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + starting, + ): + super(TriangleAttention, self).__init__() + self.starting = starting + self.layer_norm = LayerNorm(d_in) + self.linear = Linear(d_in, num_heads, bias=False, init='normal') + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + { + 'q': x, + 'k': x, + 'v': x, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if not self.starting: + x = x.transpose(-2, -3) + + x = self.layer_norm(x) + triangle_bias = ( + permute_final_dims(self.linear(x), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + x = self._chunk(x, attn_mask, triangle_bias, chunk_size) + else: + x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias) + + if not self.starting: + x = x.transpose(-2, -3) + return x + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class TriangleAttentionStarting(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=True) + + +class TriangleAttentionEnding(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/modelscope/models/science/unifold/modules/auxillary_heads.py b/modelscope/models/science/unifold/modules/auxillary_heads.py new file mode 100644 index 00000000..2daf5d55 --- /dev/null +++ b/modelscope/models/science/unifold/modules/auxillary_heads.py @@ -0,0 +1,171 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Dict + +import torch.nn as nn +from unicore.modules import LayerNorm + +from .common import Linear +from .confidence import (predicted_aligned_error, predicted_lddt, + predicted_tm_score) + + +class AuxiliaryHeads(nn.Module): + + def __init__(self, config): + super(AuxiliaryHeads, self).__init__() + + self.plddt = PredictedLDDTHead(**config['plddt'], ) + + self.distogram = DistogramHead(**config['distogram'], ) + + self.masked_msa = MaskedMSAHead(**config['masked_msa'], ) + + if config.experimentally_resolved.enabled: + self.experimentally_resolved = ExperimentallyResolvedHead( + **config['experimentally_resolved'], ) + + if config.pae.enabled: + self.pae = PredictedAlignedErrorHead(**config.pae, ) + + self.config = config + + def forward(self, outputs): + aux_out = {} + plddt_logits = self.plddt(outputs['sm']['single']) + aux_out['plddt_logits'] = plddt_logits + + aux_out['plddt'] = predicted_lddt(plddt_logits.detach()) + + distogram_logits = self.distogram(outputs['pair']) + aux_out['distogram_logits'] = distogram_logits + + masked_msa_logits = self.masked_msa(outputs['msa']) + aux_out['masked_msa_logits'] = masked_msa_logits + + if self.config.experimentally_resolved.enabled: + exp_res_logits = self.experimentally_resolved(outputs['single']) + aux_out['experimentally_resolved_logits'] = exp_res_logits + + if self.config.pae.enabled: + pae_logits = self.pae(outputs['pair']) + aux_out['pae_logits'] = pae_logits + pae_logits = pae_logits.detach() + aux_out.update( + predicted_aligned_error( + pae_logits, + **self.config.pae, + )) + aux_out['ptm'] = predicted_tm_score( + pae_logits, interface=False, **self.config.pae) + + iptm_weight = self.config.pae.get('iptm_weight', 0.0) + if iptm_weight > 0.0: + aux_out['iptm'] = predicted_tm_score( + pae_logits, + interface=True, + asym_id=outputs['asym_id'], + **self.config.pae, + ) + aux_out['iptm+ptm'] = ( + iptm_weight * aux_out['iptm'] + # noqa W504 + (1.0 - iptm_weight) * aux_out['ptm']) + + return aux_out + + +class PredictedLDDTHead(nn.Module): + + def __init__(self, num_bins, d_in, d_hid): + super(PredictedLDDTHead, self).__init__() + + self.num_bins = num_bins + self.d_in = d_in + self.d_hid = d_hid + + self.layer_norm = LayerNorm(self.d_in) + + self.linear_1 = Linear(self.d_in, self.d_hid, init='relu') + self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_3 = Linear(self.d_hid, self.num_bins, init='final') + + def forward(self, s): + s = self.layer_norm(s) + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + return s + + +class EnhancedHeadBase(nn.Module): + + def __init__(self, d_in, d_out, disable_enhance_head): + super(EnhancedHeadBase, self).__init__() + if disable_enhance_head: + self.layer_norm = None + self.linear_in = None + else: + self.layer_norm = LayerNorm(d_in) + self.linear_in = Linear(d_in, d_in, init='relu') + self.act = nn.GELU() + self.linear = Linear(d_in, d_out, init='final') + + def apply_alphafold_original_mode(self): + self.layer_norm = None + self.linear_in = None + + def forward(self, x): + if self.layer_norm is not None: + x = self.layer_norm(x) + x = self.act(self.linear_in(x)) + logits = self.linear(x) + return logits + + +class DistogramHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(DistogramHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + def forward(self, x): + logits = super().forward(x) + logits = logits + logits.transpose(-2, -3) + return logits + + +class PredictedAlignedErrorHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(PredictedAlignedErrorHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + +class MaskedMSAHead(EnhancedHeadBase): + + def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs): + super(MaskedMSAHead, self).__init__( + d_in=d_msa, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) + + +class ExperimentallyResolvedHead(EnhancedHeadBase): + + def __init__(self, d_single, d_out, disable_enhance_head, **kwargs): + super(ExperimentallyResolvedHead, self).__init__( + d_in=d_single, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) diff --git a/modelscope/models/science/unifold/modules/common.py b/modelscope/models/science/unifold/modules/common.py new file mode 100644 index 00000000..186f2567 --- /dev/null +++ b/modelscope/models/science/unifold/modules/common.py @@ -0,0 +1,387 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from unicore.modules import LayerNorm +from unicore.utils import tensor_tree_map + + +class Linear(nn.Linear): + + def __init__( + self, + d_in: int, + d_out: int, + bias: bool = True, + init: str = 'default', + ): + super(Linear, self).__init__(d_in, d_out, bias=bias) + + self.use_bias = bias + + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init == 'default': + self._trunc_normal_init(1.0) + elif init == 'relu': + self._trunc_normal_init(2.0) + elif init == 'glorot': + self._glorot_uniform_init() + elif init == 'gating': + self._zero_init(self.use_bias) + elif init == 'normal': + self._normal_init() + elif init == 'final': + self._zero_init(False) + else: + raise ValueError('Invalid init method.') + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.weight.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.weight, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.weight, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.weight.fill_(0.0) + if use_bias: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear') + + +class Transition(nn.Module): + + def __init__(self, d_in, n): + + super(Transition, self).__init__() + + self.d_in = d_in + self.n = n + + self.layer_norm = LayerNorm(self.d_in) + self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.n * self.d_in, d_in, init='final') + + def _transition(self, x): + x = self.layer_norm(x) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {'x': x}, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + if chunk_size is not None: + x = self._chunk(x, chunk_size) + else: + x = self._transition(x=x) + + return x + + +class OuterProductMean(nn.Module): + + def __init__(self, d_msa, d_pair, d_hid, eps=1e-3): + super(OuterProductMean, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.d_hid = d_hid + self.eps = eps + + self.layer_norm = LayerNorm(d_msa) + self.linear_1 = Linear(d_msa, d_hid) + self.linear_2 = Linear(d_msa, d_hid) + self.linear_out = Linear(d_hid**2, d_pair, init='relu') + self.act = nn.GELU() + self.linear_z = Linear(self.d_pair, self.d_pair, init='final') + self.layer_norm_out = LayerNorm(self.d_pair) + + def _opm(self, a, b): + outer = torch.einsum('...bac,...dae->...bdce', a, b) + outer = outer.reshape(outer.shape[:-2] + (-1, )) + outer = self.linear_out(outer) + return outer + + @torch.jit.ignore + def _chunk(self, a: torch.Tensor, b: torch.Tensor, + chunk_size: int) -> torch.Tensor: + a = a.reshape((-1, ) + a.shape[-3:]) + b = b.reshape((-1, ) + b.shape[-3:]) + out = [] + # TODO: optimize this + for a_prime, b_prime in zip(a, b): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {'a': a_prime}, + chunk_size=chunk_size, + num_batch_dims=1, + ) + out.append(outer) + if len(out) == 1: + outer = out[0].unsqueeze(0) + else: + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def apply_alphafold_original_mode(self): + self.linear_z = None + self.layer_norm_out = None + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = self.layer_norm(m) + mask = mask.unsqueeze(-1) + if self.layer_norm_out is not None: + # for numerical stability + mask = mask * (mask.size(-2)**-0.5) + a = self.linear_1(m) + b = self.linear_2(m) + if self.training: + a = a * mask + b = b * mask + else: + a *= mask + b *= mask + + a = a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + z = self._chunk(a, b, chunk_size) + else: + z = self._opm(a, b) + + norm = torch.einsum('...abc,...adc->...bdc', mask, mask) + z /= self.eps + norm + if self.layer_norm_out is not None: + z = self.act(z) + z = self.layer_norm_out(z) + z = self.linear_z(z) + return z + + +def residual(residual, x, training): + if training: + return x + residual + else: + residual += x + return residual + + +@torch.jit.script +def fused_bias_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + dropmask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual + + +@torch.jit.script +def fused_bias_dropout_add_inference( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, +) -> torch.Tensor: + residual += bias + x + return residual + + +def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, + training): + bias = module.get_output_bias() + if training: + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_dropout_add(x, bias, residual, mask, prob) + else: + return fused_bias_dropout_add_inference(x, bias, residual) + + +@torch.jit.script +def fused_bias_gated_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + g: torch.Tensor, + g_bias: torch.Tensor, + residual: torch.Tensor, + dropout_mask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout( + dropout_mask, + p=prob, + training=True, + ) + residual + + +def tri_mul_residual( + module, + residual, + outputs, + dropout_shared_dim, + prob, + training, + block_size, +): + if training: + x, g = outputs + bias, g_bias = module.get_output_bias() + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_gated_dropout_add( + x, + bias, + g, + g_bias, + residual, + mask, + prob, + ) + elif block_size is None: + x, g = outputs + bias, g_bias = module.get_output_bias() + residual += (torch.sigmoid(g + g_bias) * (x + bias)) + return residual + else: + # gated is not used here + residual += outputs + return residual + + +class SimpleModuleList(nn.ModuleList): + + def __repr__(self): + return str(len(self)) + ' X ...\n' + self[0].__repr__() + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + num_batch_dims: int, +) -> Any: + # TODO: support inplace add to output + if not (len(inputs) > 0): + raise ValueError('Must provide at least one input') + + def _dict_get_shapes(input): + shapes = [] + if type(input) is torch.Tensor: + shapes.append(input.shape) + elif type(input) is dict: + for v in input.values(): + shapes.extend(_dict_get_shapes(v)) + elif isinstance(input, Iterable): + for v in input: + shapes.extend(_dict_get_shapes(v)) + else: + raise ValueError('Not supported') + + return shapes + + inputs = {k: v for k, v in inputs.items() if v is not None} + initial_dims = [ + shape[:num_batch_dims] for shape in _dict_get_shapes(inputs) + ] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size + + def _flat_inputs(t): + t = t.view(-1, *t.shape[num_batch_dims:]) + assert ( + t.shape[0] == flat_batch_dim or t.shape[0] == 1 + ), 'batch dimension must be 1 or equal to the flat batch dimension' + return t + + flat_inputs = tensor_tree_map(_flat_inputs, inputs) + + out = None + for i in range(num_chunks): + chunk_start = i * chunk_size + chunk_end = min((i + 1) * chunk_size, flat_batch_dim) + + def select_chunk(t): + if t.shape[0] == 1: + return t[0:1] + else: + return t[chunk_start:chunk_end] + + chunkes = tensor_tree_map(select_chunk, flat_inputs) + + output_chunk = layer(**chunkes) + + if out is None: + out = tensor_tree_map( + lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]), + output_chunk) + + out_type = type(output_chunk) + if out_type is tuple: + for x, y in zip(out, output_chunk): + x[chunk_start:chunk_end] = y + elif out_type is torch.Tensor: + out[chunk_start:chunk_end] = output_chunk + else: + raise ValueError('Not supported') + + # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + def reshape(t): + return t.view(orig_batch_dims + t.shape[1:]) + + out = tensor_tree_map(reshape, out) + + return out diff --git a/modelscope/models/science/unifold/modules/confidence.py b/modelscope/models/science/unifold/modules/confidence.py new file mode 100644 index 00000000..7574689c --- /dev/null +++ b/modelscope/models/science/unifold/modules/confidence.py @@ -0,0 +1,159 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Dict, Optional, Tuple + +import torch + + +def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor: + """Computes per-residue pLDDT from logits. + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + Returns: + plddt: [num_res] per-residue pLDDT. + """ + num_bins = plddt_logits.shape[-1] + bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1) + bin_width = 1.0 / num_bins + bounds = torch.arange( + start=0.5 * bin_width, + end=1.0, + step=bin_width, + device=plddt_logits.device) + plddt = torch.sum( + bin_probs + * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return plddt + + +def compute_bin_values(breaks: torch.Tensor): + """Gets the bin centers from the bin edges. + Args: + breaks: [num_bins - 1] the error bin edges. + Returns: + bin_centers: [num_bins] the error bin centers. + """ + step = breaks[1] - breaks[0] + bin_values = breaks + step / 2 + bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)], + dim=0) + return bin_values + + +def compute_predicted_aligned_error( + bin_edges: torch.Tensor, + bin_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates expected aligned distance errors for every pair of residues. + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_values = compute_bin_values(bin_edges) + return torch.sum(bin_probs * bin_values, dim=-1) + + +def predicted_aligned_error( + pae_logits: torch.Tensor, + max_bin: int = 31, + num_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1) + bin_edges = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + predicted_aligned_error = compute_predicted_aligned_error( + bin_edges=bin_edges, + bin_probs=bin_probs, + ) + + return { + 'aligned_error_probs_per_bin': bin_probs, + 'predicted_aligned_error': predicted_aligned_error, + } + + +def predicted_tm_score( + pae_logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + num_bins: int = 64, + eps: float = 1e-8, + asym_id: Optional[torch.Tensor] = None, + interface: bool = False, + **kwargs, +) -> torch.Tensor: + """Computes predicted TM alignment or predicted interface TM alignment score. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation, i.e. when interface=True. + interface: If True, interface predicted TM score is computed. + Returns: + ptm_score: The predicted TM alignment or the predicted iTM score. + """ + pae_logits = pae_logits.float() + if residue_weights is None: + residue_weights = pae_logits.new_ones(pae_logits.shape[:-2]) + + breaks = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + def tm_kernal(nres): + clipped_n = max(nres, 19) + d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8 + return lambda x: 1.0 / (1.0 + (x / d0)**2) + + def rmsd_kernal(eps): # leave for compute pRMS + return lambda x: 1. / (x + eps) + + bin_centers = compute_bin_values(breaks) + probs = torch.nn.functional.softmax(pae_logits, dim=-1) + tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers) + # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) + # rmsd_per_bin = rmsd_kernal()(bin_centers) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape) + if interface: + assert asym_id is not None, 'must provide asym_id for iptm calculation.' + pair_mask *= asym_id[..., :, None] != asym_id[..., None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / ( + eps + pair_residue_weights.sum(dim=-1, keepdim=True)) + + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + weighted = per_alignment * residue_weights + ret = per_alignment.gather( + dim=-1, index=weighted.max(dim=-1, + keepdim=True).indices).squeeze(dim=-1) + return ret diff --git a/modelscope/models/science/unifold/modules/embedders.py b/modelscope/models/science/unifold/modules/embedders.py new file mode 100644 index 00000000..84e87e2d --- /dev/null +++ b/modelscope/models/science/unifold/modules/embedders.py @@ -0,0 +1,290 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import one_hot + +from .common import Linear, SimpleModuleList, residual + + +class InputEmbedder(nn.Module): + + def __init__( + self, + tf_dim: int, + msa_dim: int, + d_pair: int, + d_msa: int, + relpos_k: int, + use_chain_relative: bool = False, + max_relative_chain: Optional[int] = None, + **kwargs, + ): + super(InputEmbedder, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.d_pair = d_pair + self.d_msa = d_msa + + self.linear_tf_z_i = Linear(tf_dim, d_pair) + self.linear_tf_z_j = Linear(tf_dim, d_pair) + self.linear_tf_m = Linear(tf_dim, d_msa) + self.linear_msa_m = Linear(msa_dim, d_msa) + + # RPE stuff + self.relpos_k = relpos_k + self.use_chain_relative = use_chain_relative + self.max_relative_chain = max_relative_chain + if not self.use_chain_relative: + self.num_bins = 2 * self.relpos_k + 1 + else: + self.num_bins = 2 * self.relpos_k + 2 + self.num_bins += 1 # entity id + self.num_bins += 2 * max_relative_chain + 2 + + self.linear_relpos = Linear(self.num_bins, d_pair) + + def _relpos_indices( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + ): + + max_rel_res = self.relpos_k + rp = res_id[..., None] - res_id[..., None, :] + rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res + if not self.use_chain_relative: + return rp + else: + asym_id_same = asym_id[..., :, None] == asym_id[..., None, :] + rp[~asym_id_same] = 2 * max_rel_res + 1 + entity_id_same = entity_id[..., :, None] == entity_id[..., None, :] + rp_entity_id = entity_id_same.type(rp.dtype)[..., None] + + rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :] + + max_rel_chain = self.max_relative_chain + + clipped_rel_chain = torch.clamp( + rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain) + + clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1 + return rp, rp_entity_id, clipped_rel_chain + + def relpos_emb( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + num_sym: Optional[torch.Tensor] = None, + ): + + dtype = self.linear_relpos.weight.dtype + if not self.use_chain_relative: + rp = self._relpos_indices(res_id=res_id) + return self.linear_relpos( + one_hot(rp, num_classes=self.num_bins, dtype=dtype)) + else: + rp, rp_entity_id, rp_rel_chain = self._relpos_indices( + res_id=res_id, + sym_id=sym_id, + asym_id=asym_id, + entity_id=entity_id) + rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype) + rp_entity_id = rp_entity_id.type(dtype) + rp_rel_chain = one_hot( + rp_rel_chain, + num_classes=(2 * self.max_relative_chain + 2), + dtype=dtype) + return self.linear_relpos( + torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)) + + def forward( + self, + tf: torch.Tensor, + msa: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # [*, N_res, d_pair] + if self.tf_dim == 21: + # multimer use 21 target dim + tf = tf[..., 1:] + # convert type if necessary + tf = tf.type(self.linear_tf_z_i.weight.dtype) + msa = msa.type(self.linear_tf_z_i.weight.dtype) + n_clust = msa.shape[-3] + + msa_emb = self.linear_msa_m(msa) + # target_feat (aatype) into msa representation + tf_m = ( + self.linear_tf_m(tf).unsqueeze(-3).expand( + ((-1, ) * len(tf.shape[:-2]) + # noqa W504 + (n_clust, -1, -1)))) + msa_emb += tf_m + + tf_emb_i = self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + + return msa_emb, pair_emb + + +class RecyclingEmbedder(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + min_bin: float, + max_bin: float, + num_bins: int, + inf: float = 1e8, + **kwargs, + ): + super(RecyclingEmbedder, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.min_bin = min_bin + self.max_bin = max_bin + self.num_bins = num_bins + self.inf = inf + + self.squared_bins = None + + self.linear = Linear(self.num_bins, self.d_pair) + self.layer_norm_m = LayerNorm(self.d_msa) + self.layer_norm_z = LayerNorm(self.d_pair) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + m_update = self.layer_norm_m(m) + z_update = self.layer_norm_z(z) + + return m_update, z_update + + def recyle_pos( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.squared_bins is None: + bins = torch.linspace( + self.min_bin, + self.max_bin, + self.num_bins, + dtype=torch.float if self.training else x.dtype, + device=x.device, + requires_grad=False, + ) + self.squared_bins = bins**2 + upper = torch.cat( + [self.squared_bins[1:], + self.squared_bins.new_tensor([self.inf])], + dim=-1) + if self.training: + x = x.float() + d = torch.sum( + (x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) + d = ((d > self.squared_bins) * # noqa W504 + (d < upper)).type(self.linear.weight.dtype) + d = self.linear(d) + return d + + +class TemplateAngleEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(TemplateAngleEmbedder, self).__init__() + + self.d_out = d_out + self.d_in = d_in + + self.linear_1 = Linear(self.d_in, self.d_out, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_out, self.d_out, init='relu') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x.type(self.linear_1.weight.dtype)) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TemplatePairEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + v2_d_in: list, + d_out: int, + d_pair: int, + v2_feature: bool = False, + **kwargs, + ): + super(TemplatePairEmbedder, self).__init__() + + self.d_out = d_out + self.v2_feature = v2_feature + if self.v2_feature: + self.d_in = v2_d_in + self.linear = SimpleModuleList() + for d_in in self.d_in: + self.linear.append(Linear(d_in, self.d_out, init='relu')) + self.z_layer_norm = LayerNorm(d_pair) + self.z_linear = Linear(d_pair, self.d_out, init='relu') + else: + self.d_in = d_in + self.linear = Linear(self.d_in, self.d_out, init='relu') + + def forward( + self, + x, + z, + ) -> torch.Tensor: + if not self.v2_feature: + x = self.linear(x.type(self.linear.weight.dtype)) + return x + else: + dtype = self.z_linear.weight.dtype + t = self.linear[0](x[0].type(dtype)) + for i, s in enumerate(x[1:]): + t = residual(t, self.linear[i + 1](s.type(dtype)), + self.training) + t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training) + return t + + +class ExtraMSAEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(ExtraMSAEmbedder, self).__init__() + + self.d_in = d_in + self.d_out = d_out + self.linear = Linear(self.d_in, self.d_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x.type(self.linear.weight.dtype)) diff --git a/modelscope/models/science/unifold/modules/evoformer.py b/modelscope/models/science/unifold/modules/evoformer.py new file mode 100644 index 00000000..b0834986 --- /dev/null +++ b/modelscope/models/science/unifold/modules/evoformer.py @@ -0,0 +1,362 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.utils import checkpoint_sequential + +from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention, + MSARowAttentionWithPairBias, TriangleAttentionEnding, + TriangleAttentionStarting) +from .common import (Linear, OuterProductMean, SimpleModuleList, Transition, + bias_dropout_residual, residual, tri_mul_residual) +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class EvoformerIteration(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + ): + super(EvoformerIteration, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + self.outer_product_mean_first = outer_product_mean_first + + self.msa_att_row = MSARowAttentionWithPairBias( + d_msa=d_msa, + d_pair=d_pair, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + ) + + if _is_extra_msa_stack: + self.msa_att_col = MSAColumnGlobalAttention( + d_in=d_msa, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + inf=inf, + eps=eps, + ) + else: + self.msa_att_col = MSAColumnAttention( + d_msa, + d_hid_msa_att, + num_heads_msa, + ) + + self.msa_transition = Transition( + d_in=d_msa, + n=transition_n, + ) + + self.outer_product_mean = OuterProductMean( + d_msa, + d_pair, + d_hid_opm, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_pair, + d_hid_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_pair, + d_hid_mul, + ) + + self.tri_att_start = TriangleAttentionStarting( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + self.tri_att_end = TriangleAttentionEnding( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + + self.pair_transition = Transition( + d_in=d_pair, + n=transition_n, + ) + + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + self.msa_dropout = msa_dropout + self.pair_dropout = pair_dropout + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: Optional[torch.Tensor], + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + m = bias_dropout_residual( + self.msa_att_row, + m, + self.msa_att_row( + m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.msa_dropout, + self.training, + ) + if self._is_extra_msa_stack: + m = residual( + m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), + self.training) + else: + m = bias_dropout_residual( + self.msa_att_col, + m, + self.msa_att_col( + m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.msa_dropout, + self.training, + ) + m = residual(m, self.msa_transition(m, chunk_size=chunk_size), + self.training) + if not self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + z = tri_mul_residual( + self.tri_mul_out, + z, + self.tri_mul_out(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = tri_mul_residual( + self.tri_mul_in, + z, + self.tri_mul_in(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = bias_dropout_residual( + self.tri_att_start, + z, + self.tri_att_start( + z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + ) + + z = bias_dropout_residual( + self.tri_att_end, + z, + self.tri_att_end( + z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.pair_dropout, + self.training, + ) + z = residual(z, self.pair_transition(z, chunk_size=chunk_size), + self.training) + return m, z + + +class EvoformerStack(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + d_single: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + **kwargs, + ): + super(EvoformerStack, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + + self.blocks = SimpleModuleList() + + for _ in range(num_blocks): + self.blocks.append( + EvoformerIteration( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=_is_extra_msa_stack, + )) + if not self._is_extra_msa_stack: + self.linear = Linear(d_msa, d_single) + else: + self.linear = None + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: int, + block_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) for b in self.blocks + ] + + m, z = checkpoint_sequential( + blocks, + input=(m, z), + ) + + s = None + if not self._is_extra_msa_stack: + seq_dim = -3 + index = torch.tensor([0], device=m.device) + s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) + s = s.squeeze(seq_dim) + + return m, z, s + + +class ExtraMSAStack(EvoformerStack): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + **kwargs, + ): + super(ExtraMSAStack, self).__init__( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + d_single=None, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + num_blocks=num_blocks, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=True, + ) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + msa_row_attn_mask: torch.Tensor = None, + msa_col_attn_mask: torch.Tensor = None, + tri_start_attn_mask: torch.Tensor = None, + tri_end_attn_mask: torch.Tensor = None, + chunk_size: int = None, + block_size: int = None, + ) -> torch.Tensor: + _, z, _ = super().forward( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) + return z diff --git a/modelscope/models/science/unifold/modules/featurization.py b/modelscope/models/science/unifold/modules/featurization.py new file mode 100644 index 00000000..b62adc9d --- /dev/null +++ b/modelscope/models/science/unifold/modules/featurization.py @@ -0,0 +1,195 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Dict + +import torch +import torch.nn as nn +from unicore.utils import batched_gather, one_hot + +from modelscope.models.science.unifold.data import residue_constants as rc +from .frame import Frame + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order['G'] + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch['residx_atom37_to_atom14'], + dim=-2, + num_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch['atom37_atom_exists'][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats, v2_feature=False): + template_aatype = template_feats['template_aatype'] + torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos'] + torsion_angles_mask = template_feats['template_torsion_angles_mask'] + if not v2_feature: + alt_torsion_angles_sin_cos = template_feats[ + 'template_alt_torsion_angles_sin_cos'] + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape( + *torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape( + *alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + template_angle_mask = torsion_angles_mask[..., 2] + else: + chi_mask = torsion_angles_mask[..., 3:] + chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask + chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + chi_angles_sin, + chi_angles_cos, + chi_mask, + ], + dim=-1, + ) + template_angle_mask = chi_mask[..., 0] + return template_angle_feat, template_angle_mask + + +def build_template_pair_feat( + batch, + min_bin, + max_bin, + num_bins, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_template_pair_feat_v2( + batch, + min_bin, + max_bin, + num_bins, + multichain_mask_2d=None, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + if multichain_mask_2d is not None: + template_mask_2d *= multichain_mask_2d + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + dgram *= template_mask_2d[..., None] + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']] + rigids = Frame.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][..., n, :], + ca_xyz=batch['template_all_atom_positions'][..., ca, :], + c_xyz=batch['template_all_atom_positions'][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch['template_all_atom_mask'] + backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., + c] + backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., + None, :] + if multichain_mask_2d is not None: + backbone_mask_2d *= multichain_mask_2d + + inv_distance_scalar = inv_distance_scalar * backbone_mask_2d + unit_vector_data = rigid_vec * inv_distance_scalar[..., None] + to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1)) + to_concat.append(backbone_mask_2d[..., None]) + + return to_concat + + +def build_extra_msa_feat(batch): + msa_1hot = one_hot(batch['extra_msa'], 23) + msa_feat = [ + msa_1hot, + batch['extra_msa_has_deletion'].unsqueeze(-1), + batch['extra_msa_deletion_value'].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) diff --git a/modelscope/models/science/unifold/modules/frame.py b/modelscope/models/science/unifold/modules/frame.py new file mode 100644 index 00000000..5a0e4d6a --- /dev/null +++ b/modelscope/models/science/unifold/modules/frame.py @@ -0,0 +1,562 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from __future__ import annotations # noqa +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple + +import numpy as np +import torch + + +def zero_translation( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, +) -> torch.Tensor: + trans = torch.zeros((*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad) + return trans + + +# pylint: disable=bad-whitespace +_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr +_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii +_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj +_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk + +_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij +_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik +_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk + +_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir +_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr +_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr + +_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) +_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], + [0, 0, 0, -1]] + +_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], + [0, 0, -1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], + [0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], + [1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] +_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) + + +class Rotation: + + def __init__( + self, + mat: torch.Tensor, + ): + if mat.shape[-2:] != (3, 3): + raise ValueError(f'incorrect rotation shape: {mat.shape}') + self._mat = mat + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Rotation: + mat = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad) + mat = mat.view(*((1, ) * len(shape)), 3, 3) + mat = mat.expand(*shape, -1, -1) + return Rotation(mat) + + @staticmethod + def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return (a.float() @ b.float()).type(a.dtype) + + @staticmethod + def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) + + def __getitem__(self, index: Any) -> Rotation: + if not isinstance(index, tuple): + index = (index, ) + return Rotation(mat=self._mat[index + (slice(None), slice(None))]) + + def __mul__(self, right: Any) -> Rotation: + if isinstance(right, (int, float)): + return Rotation(mat=self._mat * right) + elif isinstance(right, torch.Tensor): + return Rotation(mat=self._mat * right[..., None, None]) + else: + raise TypeError( + f'multiplicand must be a tensor or a number, got {type(right)}.' + ) + + def __rmul__(self, left: Any) -> Rotation: + return self.__mul__(left) + + def __matmul__(self, other: Rotation) -> Rotation: + new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) + return Rotation(mat=new_mat) + + @property + def _inv_mat(self): + return self._mat.transpose(-1, -2) + + @property + def rot_mat(self) -> torch.Tensor: + return self._mat + + def invert(self) -> Rotation: + return Rotation(mat=self._inv_mat) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._mat, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._inv_mat, pts) + + # inherit tensor behaviors + @property + def shape(self) -> torch.Size: + s = self._mat.shape[:-2] + return s + + @property + def dtype(self) -> torch.dtype: + return self._mat.dtype + + @property + def device(self) -> torch.device: + return self._mat.device + + @property + def requires_grad(self) -> bool: + return self._mat.requires_grad + + def unsqueeze(self, dim: int) -> Rotation: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + + rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(mat=rot_mats) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Rotation: + mat = self._mat.view(self._mat.shape[:-2] + (9, )) + mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) + mat = mat.view(mat.shape[:-1] + (3, 3)) + return Rotation(mat=mat) + + @staticmethod + def cat(rs: Sequence[Rotation], dim: int) -> Rotation: + rot_mats = [r.rot_mat for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(mat=rot_mats) + + def cuda(self) -> Rotation: + return Rotation(mat=self._mat.cuda()) + + def to(self, device: Optional[torch.device], + dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.to(device=device, dtype=dtype)) + + def type(self, dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.type(dtype)) + + def detach(self) -> Rotation: + return Rotation(mat=self._mat.detach()) + + +class Frame: + + def __init__( + self, + rotation: Optional[Rotation], + translation: Optional[torch.Tensor], + ): + if rotation is None and translation is None: + rotation = Rotation.identity((0, )) + translation = zero_translation((0, )) + elif translation is None: + translation = zero_translation(rotation.shape, rotation.dtype, + rotation.device, + rotation.requires_grad) + + elif rotation is None: + rotation = Rotation.identity( + translation.shape[:-1], + translation.dtype, + translation.device, + translation.requires_grad, + ) + + if (rotation.shape != translation.shape[:-1]) or (rotation.device + != # noqa W504 + translation.device): + raise ValueError('RotationMatrix and translation incompatible') + + self._r = rotation + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Frame: + return Frame( + Rotation.identity(shape, dtype, device, requires_grad), + zero_translation(shape, dtype, device, requires_grad), + ) + + def __getitem__( + self, + index: Any, + ) -> Frame: + if type(index) != tuple: + index = (index, ) + + return Frame( + self._r[index], + self._t[index + (slice(None), )], + ) + + def __mul__( + self, + right: torch.Tensor, + ) -> Frame: + if not (isinstance(right, torch.Tensor)): + raise TypeError('The other multiplicand must be a Tensor') + + new_rots = self._r * right + new_trans = self._t * right[..., None] + + return Frame(new_rots, new_trans) + + def __rmul__( + self, + left: torch.Tensor, + ) -> Frame: + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + s = self._t.shape[:-1] + return s + + @property + def device(self) -> torch.device: + return self._t.device + + def get_rots(self) -> Rotation: + return self._r + + def get_trans(self) -> torch.Tensor: + return self._t + + def compose( + self, + other: Frame, + ) -> Frame: + new_rot = self._r @ other._r + new_trans = self._r.apply(other._t) + self._t + return Frame(new_rot, new_trans) + + def apply( + self, + pts: torch.Tensor, + ) -> torch.Tensor: + rotated = self._r.apply(pts) + return rotated + self._t + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + pts = pts - self._t + return self._r.invert_apply(pts) + + def invert(self) -> Frame: + rot_inv = self._r.invert() + trn_inv = rot_inv.apply(self._t) + + return Frame(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + new_rots = self._r.map_tensor_fn(fn) + new_trans = torch.stack( + list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) + + return Frame(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + tensor = self._t.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._r.rot_mat + tensor[..., :3, 3] = self._t + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4(t: torch.Tensor) -> Frame: + if t.shape[-2:] != (4, 4): + raise ValueError('Incorrectly shaped input tensor') + + rots = Rotation(mat=t[..., :3, :3]) + trans = t[..., :3, 3] + + return Frame(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-8, + ) -> Frame: + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(mat=rots) + + return Frame(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze( + self, + dim: int, + ) -> Frame: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + rots = self._r.unsqueeze(dim) + trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + @staticmethod + def cat( + Ts: Sequence[Frame], + dim: int, + ) -> Frame: + rots = Rotation.cat([T._r for T in Ts], dim) + trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: + return Frame(fn(self._r), self._t) + + def apply_trans_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + return Frame(self._r, fn(self._t)) + + def scale_translation(self, trans_scale_factor: float) -> Frame: + # fn = lambda t: t * trans_scale_factor + def fn(t): + return t * trans_scale_factor + + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Frame: + # fn = lambda r: r.detach() + def fn(r): + return r.detach() + + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + input_dtype = ca_xyz.dtype + n_xyz = n_xyz.float() + ca_xyz = ca_xyz.float() + c_xyz = c_xyz.float() + n_xyz = n_xyz - ca_xyz + c_xyz = c_xyz - ca_xyz + + c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x**2 + c_y**2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2) + sin_c2 = d_pair / norm + cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) + n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = Rotation.mat_mul_mat(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + rot_obj = Rotation(mat=rots.type(input_dtype)) + + return Frame(rot_obj, ca_xyz.type(input_dtype)) + + def cuda(self) -> Frame: + return Frame(self._r.cuda(), self._t.cuda()) + + @property + def dtype(self) -> torch.dtype: + assert self._r.dtype == self._t.dtype + return self._r.dtype + + def type(self, dtype) -> Frame: + return Frame(self._r.type(dtype), self._t.type(dtype)) + + +class Quaternion: + + def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): + if quaternion.shape[-1] != 4: + raise ValueError(f'incorrect quaternion shape: {quaternion.shape}') + self._q = quaternion + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Quaternion: + trans = zero_translation(shape, dtype, device, requires_grad) + quats = torch.zeros((*shape, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad) + with torch.no_grad(): + quats[..., 0] = 1 + return Quaternion(quats, trans) + + def get_quats(self): + return self._q + + def get_trans(self): + return self._t + + def get_rot_mats(self): + quats = self.get_quats() + rot_mats = Quaternion.quat_to_rot(quats) + return rot_mats + + @staticmethod + def quat_to_rot(normalized_quat): + global _QUAT_TO_ROT_tensor + dtype = normalized_quat.dtype + normalized_quat = normalized_quat.float() + if _QUAT_TO_ROT_tensor.device != normalized_quat.device: + _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to( + normalized_quat.device) + rot_tensor = torch.sum( + _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] + * normalized_quat[..., None, :, None], + dim=(-3, -2), + ) + rot_tensor = rot_tensor.type(dtype) + rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) + return rot_tensor + + @staticmethod + def normalize_quat(quats): + dtype = quats.dtype + quats = quats.float() + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + quats = quats.type(dtype) + return quats + + @staticmethod + def quat_multiply_by_vec(quat, vec): + dtype = quat.dtype + quat = quat.float() + vec = vec.float() + global _QUAT_MULTIPLY_BY_VEC_tensor + if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: + _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to( + quat.device) + mat = _QUAT_MULTIPLY_BY_VEC_tensor + reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], + dim=(-3, -2), + ).type(dtype) + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + normalize_quats: bool = True) -> torch.Tensor: + quats = self.get_quats() + new_quats = quats + Quaternion.quat_multiply_by_vec( + quats, q_update_vec) + if normalize_quats: + new_quats = Quaternion.normalize_quat(new_quats) + return new_quats + + def compose_update_vec( + self, + update_vec: torch.Tensor, + pre_rot_mat: Rotation, + ) -> Quaternion: + q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] + new_quats = self.compose_q_update_vec(q_vec) + + trans_update = pre_rot_mat.apply(t_vec) + new_trans = self._t + trans_update + + return Quaternion(new_quats, new_trans) + + def stop_rot_gradient(self) -> Quaternion: + return Quaternion(self._q.detach(), self._t) diff --git a/modelscope/models/science/unifold/modules/structure_module.py b/modelscope/models/science/unifold/modules/structure_module.py new file mode 100644 index 00000000..4872d5c6 --- /dev/null +++ b/modelscope/models/science/unifold/modules/structure_module.py @@ -0,0 +1,592 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import dict_multimap, one_hot, permute_final_dims + +from modelscope.models.science.unifold.data.residue_constants import ( + restype_atom14_mask, restype_atom14_rigid_group_positions, + restype_atom14_to_rigid_group, restype_rigid_group_default_frame) +from .attentions import gen_attn_mask +from .common import Linear, SimpleModuleList, residual +from .frame import Frame, Quaternion, Rotation + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +def torsion_angles_to_frames( + frame: Frame, + alpha: torch.Tensor, + aatype: torch.Tensor, + default_frames: torch.Tensor, +): + default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...]) + + bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], + dim=-2) + + all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Frame(Rotation(mat=all_rots), None) + + all_frames = default_frame.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Frame.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = frame[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + frame: Frame, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + group_mask = group_idx[aatype, ...] + group_mask = one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + t_atoms_to_global = frame[..., None, :] * group_mask + t_atoms_to_global = t_atoms_to_global.map_tensor_fn( + lambda x: torch.sum(x, dim=-1)) + + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + lit_positions = lit_positions[aatype, ...] + pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions + + +class SideChainAngleResnetIteration(nn.Module): + + def __init__(self, d_hid): + super(SideChainAngleResnetIteration, self).__init__() + + self.d_hid = d_hid + + self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_hid, self.d_hid, init='final') + + def forward(self, s: torch.Tensor) -> torch.Tensor: + + x = self.act(s) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + + return residual(s, x, self.training) + + +class SidechainAngleResnet(nn.Module): + + def __init__(self, d_in, d_hid, num_blocks, num_angles): + super(SidechainAngleResnet, self).__init__() + + self.linear_in = Linear(d_in, d_hid) + self.act = nn.GELU() + self.linear_initial = Linear(d_in, d_hid) + + self.layers = SimpleModuleList() + for _ in range(num_blocks): + self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid)) + + self.linear_out = Linear(d_hid, num_angles * 2) + + def forward(self, s: torch.Tensor, + initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + initial_s = self.linear_initial(self.act(initial_s)) + s = self.linear_in(self.act(s)) + + s = s + initial_s + + for layer in self.layers: + s = layer(s) + + s = self.linear_out(self.act(s)) + + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s.float()**2, dim=-1, keepdim=True), + min=1e-12, + )) + s = s.float() / norm_denom + + return unnormalized_s, s.type(unnormalized_s.dtype) + + +class InvariantPointAttention(nn.Module): + + def __init__( + self, + d_single: int, + d_pair: int, + d_hid: int, + num_heads: int, + num_qk_points: int, + num_v_points: int, + separate_kv: bool = False, + bias: bool = True, + eps: float = 1e-8, + ): + super(InvariantPointAttention, self).__init__() + + self.d_hid = d_hid + self.num_heads = num_heads + self.num_qk_points = num_qk_points + self.num_v_points = num_v_points + self.eps = eps + + hc = self.d_hid * self.num_heads + self.linear_q = Linear(d_single, hc, bias=bias) + self.separate_kv = separate_kv + if self.separate_kv: + self.linear_k = Linear(d_single, hc, bias=bias) + self.linear_v = Linear(d_single, hc, bias=bias) + else: + self.linear_kv = Linear(d_single, 2 * hc, bias=bias) + + hpq = self.num_heads * self.num_qk_points * 3 + self.linear_q_points = Linear(d_single, hpq) + hpk = self.num_heads * self.num_qk_points * 3 + hpv = self.num_heads * self.num_v_points * 3 + if self.separate_kv: + self.linear_k_points = Linear(d_single, hpk) + self.linear_v_points = Linear(d_single, hpv) + else: + hpkv = hpk + hpv + self.linear_kv_points = Linear(d_single, hpkv) + + self.linear_b = Linear(d_pair, self.num_heads) + + self.head_weights = nn.Parameter(torch.zeros((num_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.num_heads * ( + d_pair + self.d_hid + self.num_v_points * 4) + self.linear_out = Linear(concat_out_dim, d_single, init='final') + + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + f: Frame, + square_mask: torch.Tensor, + ) -> torch.Tensor: + q = self.linear_q(s) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + if self.separate_kv: + k = self.linear_k(s) + v = self.linear_v(s) + k = k.view(k.shape[:-1] + (self.num_heads, -1)) + v = v.view(v.shape[:-1] + (self.num_heads, -1)) + else: + kv = self.linear_kv(s) + kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) + k, v = torch.split(kv, self.d_hid, dim=-1) + + q_pts = self.linear_q_points(s) + + def process_points(pts, no_points): + shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3) + if self.separate_kv: + # alphafold-multimer uses different layout + pts = pts.view(pts.shape[:-1] + + (self.num_heads, no_points * 3)) + pts = torch.split(pts, pts.shape[-1] // 3, dim=-1) + pts = torch.stack(pts, dim=-1).view(*shape) + pts = f[..., None].apply(pts) + + pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3)) + return pts + + q_pts = process_points(q_pts, self.num_qk_points) + + if self.separate_kv: + k_pts = self.linear_k_points(s) + v_pts = self.linear_v_points(s) + k_pts = process_points(k_pts, self.num_qk_points) + v_pts = process_points(v_pts, self.num_v_points) + else: + kv_pts = self.linear_kv_points(s) + + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = f[..., None].apply(kv_pts) + + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + k_pts, v_pts = torch.split( + kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + bias = self.linear_b(z) + + attn = torch.matmul( + permute_final_dims(q, (1, 0, 2)), + permute_final_dims(k, (1, 2, 0)), + ) + + if self.training: + attn = attn * math.sqrt(1.0 / (3 * self.d_hid)) + attn = attn + ( + math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att.float()**2 + else: + attn *= math.sqrt(1.0 / (3 * self.d_hid)) + attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att *= pt_att + + pt_att = pt_att.sum(dim=-1) + head_weights = self.softplus(self.head_weights).view( # noqa + *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) # noqa + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att *= head_weights * (-0.5) + + pt_att = torch.sum(pt_att, dim=-1) + + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + attn += square_mask + attn = softmax_dropout( + attn, 0, self.training, bias=pt_att.type(attn.dtype)) + del pt_att, q_pts, k_pts, bias + o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3) + o = o.contiguous().view(*o.shape[:-2], -1) + del q, k, v + o_pts = torch.sum( + (attn[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + o_pts = permute_final_dims(o_pts, (2, 0, 3, 1)) + o_pts = f[..., None, None].invert_apply(o_pts) + if self.training: + o_pts_norm = torch.sqrt( + torch.sum(o_pts.float()**2, dim=-1) + self.eps).type( + o_pts.dtype) + else: + o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1) + + self.eps).type(o_pts.dtype) + + o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1) + + o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3) + + o_pair = torch.matmul(attn.transpose(-2, -3), z) + + o_pair = o_pair.view(*o_pair.shape[:-2], -1) + + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), + dim=-1)) + + return s + + +class BackboneUpdate(nn.Module): + + def __init__(self, d_single): + super(BackboneUpdate, self).__init__() + self.linear = Linear(d_single, 6, init='final') + + def forward(self, s: torch.Tensor): + return self.linear(s) + + +class StructureModuleTransitionLayer(nn.Module): + + def __init__(self, c): + super(StructureModuleTransitionLayer, self).__init__() + + self.linear_1 = Linear(c, c, init='relu') + self.linear_2 = Linear(c, c, init='relu') + self.act = nn.GELU() + self.linear_3 = Linear(c, c, init='final') + + def forward(self, s): + s_old = s + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + + s = residual(s_old, s, self.training) + + return s + + +class StructureModuleTransition(nn.Module): + + def __init__(self, c, num_layers, dropout_rate): + super(StructureModuleTransition, self).__init__() + + self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = SimpleModuleList() + for _ in range(self.num_layers): + self.layers.append(StructureModuleTransitionLayer(c)) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(c) + + def forward(self, s): + for layer in self.layers: + s = layer(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + + def __init__( + self, + d_single, + d_pair, + d_ipa, + d_angle, + num_heads_ipa, + num_qk_points, + num_v_points, + dropout_rate, + num_blocks, + no_transition_layers, + num_resnet_blocks, + num_angles, + trans_scale_factor, + separate_kv, + ipa_bias, + epsilon, + inf, + **kwargs, + ): + super(StructureModule, self).__init__() + + self.num_blocks = num_blocks + self.trans_scale_factor = trans_scale_factor + self.default_frames = None + self.group_idx = None + self.atom_mask = None + self.lit_positions = None + self.inf = inf + + self.layer_norm_s = LayerNorm(d_single) + self.layer_norm_z = LayerNorm(d_pair) + + self.linear_in = Linear(d_single, d_single) + + self.ipa = InvariantPointAttention( + d_single, + d_pair, + d_ipa, + num_heads_ipa, + num_qk_points, + num_v_points, + separate_kv=separate_kv, + bias=ipa_bias, + eps=epsilon, + ) + + self.ipa_dropout = nn.Dropout(dropout_rate) + self.layer_norm_ipa = LayerNorm(d_single) + + self.transition = StructureModuleTransition( + d_single, + no_transition_layers, + dropout_rate, + ) + + self.bb_update = BackboneUpdate(d_single) + + self.angle_resnet = SidechainAngleResnet( + d_single, + d_angle, + num_resnet_blocks, + num_angles, + ) + + def forward( + self, + s, + z, + aatype, + mask=None, + ): + if mask is None: + mask = s.new_ones(s.shape[:-1]) + + # generate square mask + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3) + s = self.layer_norm_s(s) + z = self.layer_norm_z(z) + initial_s = s + s = self.linear_in(s) + + quat_encoder = Quaternion.identity( + s.shape[:-1], + s.dtype, + s.device, + requires_grad=False, + ) + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + outputs = [] + for i in range(self.num_blocks): + s = residual(s, self.ipa(s, z, backb_to_global, square_mask), + self.training) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # update quaternion encoder + # use backb_to_global to avoid quat-to-rot conversion + quat_encoder = quat_encoder.compose_update_vec( + self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()) + + # initial_s is always used to update the backbone + unnormalized_angles, angles = self.angle_resnet(s, initial_s) + + # convert quaternion to rotation matrix + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + if i == self.num_blocks - 1: + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global.scale_translation(self.trans_scale_factor), + angles, + aatype, + ) + + pred_positions = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + preds = { + 'frames': + backb_to_global.scale_translation( + self.trans_scale_factor).to_tensor_4x4(), + 'unnormalized_angles': + unnormalized_angles, + 'angles': + angles, + } + + outputs.append(preds) + if i < (self.num_blocks - 1): + # stop gradient in iteration + quat_encoder = quat_encoder.stop_rot_gradient() + backb_to_global = backb_to_global.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4() + outputs['positions'] = pred_positions + outputs['single'] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if self.default_frames is None: + self.default_frames = torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.group_idx is None: + self.group_idx = torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ) + if self.atom_mask is None: + self.atom_mask = torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.lit_positions is None: + self.lit_positions = torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + + def torsion_angles_to_frames(self, frame, alpha, aatype): + self._init_residue_constants(alpha.dtype, alpha.device) + return torsion_angles_to_frames(frame, alpha, aatype, + self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, frame, aatype): + self._init_residue_constants(frame.get_rots().dtype, + frame.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + frame, + aatype, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) diff --git a/modelscope/models/science/unifold/modules/template.py b/modelscope/models/science/unifold/modules/template.py new file mode 100644 index 00000000..49e5bec0 --- /dev/null +++ b/modelscope/models/science/unifold/modules/template.py @@ -0,0 +1,330 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import math +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import (checkpoint_sequential, permute_final_dims, + tensor_tree_map) + +from .attentions import (Attention, TriangleAttentionEnding, + TriangleAttentionStarting, gen_attn_mask) +from .common import (Linear, SimpleModuleList, Transition, + bias_dropout_residual, chunk_layer, residual, + tri_mul_residual) +from .featurization import build_template_pair_feat_v2 +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class TemplatePointwiseAttention(nn.Module): + + def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs): + super(TemplatePointwiseAttention, self).__init__() + + self.inf = inf + + self.mha = Attention( + d_pair, + d_template, + d_template, + d_hid, + num_heads, + gating=False, + ) + + def _chunk( + self, + z: torch.Tensor, + t: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + mha_inputs = { + 'q': z, + 'k': t, + 'v': t, + 'mask': mask, + } + return chunk_layer( + self.mha, + mha_inputs, + chunk_size=chunk_size, + num_batch_dims=len(z.shape[:-2]), + ) + + def forward( + self, + t: torch.Tensor, + z: torch.Tensor, + template_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if template_mask is None: + template_mask = t.new_ones(t.shape[:-3]) + + mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, + None, :] + z = z.unsqueeze(-2) + + t = permute_final_dims(t, (1, 2, 0, 3)) + + if chunk_size is not None: + z = self._chunk(z, t, mask, chunk_size) + else: + z = self.mha(z, t, t, mask=mask) + + z = z.squeeze(-2) + + return z + + +class TemplateProjection(nn.Module): + + def __init__(self, d_template, d_pair, **kwargs): + super(TemplateProjection, self).__init__() + + self.d_pair = d_pair + self.act = nn.ReLU() + self.output_linear = Linear(d_template, d_pair, init='relu') + + def forward(self, t, z) -> torch.Tensor: + if t is None: + # handle for non-template case + shape = z.shape + shape[-1] = self.d_pair + t = torch.zeros(shape, dtype=z.dtype, device=z.device) + t = self.act(t) + z_t = self.output_linear(t) + return z_t + + +class TemplatePairStackBlock(nn.Module): + + def __init__( + self, + d_template: int, + d_hid_tri_att: int, + d_hid_tri_mul: int, + num_heads: int, + pair_transition_n: int, + dropout_rate: float, + tri_attn_first: bool, + inf: float, + **kwargs, + ): + super(TemplatePairStackBlock, self).__init__() + + self.tri_att_start = TriangleAttentionStarting( + d_template, + d_hid_tri_att, + num_heads, + ) + self.tri_att_end = TriangleAttentionEnding( + d_template, + d_hid_tri_att, + num_heads, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_template, + d_hid_tri_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_template, + d_hid_tri_mul, + ) + + self.pair_transition = Transition( + d_template, + pair_transition_n, + ) + self.tri_attn_first = tri_attn_first + self.dropout = dropout_rate + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + + def forward( + self, + s: torch.Tensor, + mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ): + if self.tri_attn_first: + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + else: + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = residual(s, self.pair_transition( + s, + chunk_size=chunk_size, + ), self.training) + return s + + +class TemplatePairStack(nn.Module): + + def __init__( + self, + d_template, + d_hid_tri_att, + d_hid_tri_mul, + num_blocks, + num_heads, + pair_transition_n, + dropout_rate, + tri_attn_first, + inf=1e9, + **kwargs, + ): + super(TemplatePairStack, self).__init__() + + self.blocks = SimpleModuleList() + for _ in range(num_blocks): + self.blocks.append( + TemplatePairStackBlock( + d_template=d_template, + d_hid_tri_att=d_hid_tri_att, + d_hid_tri_mul=d_hid_tri_mul, + num_heads=num_heads, + pair_transition_n=pair_transition_n, + dropout_rate=dropout_rate, + inf=inf, + tri_attn_first=tri_attn_first, + )) + + self.layer_norm = LayerNorm(d_template) + + def forward( + self, + single_templates: Tuple[torch.Tensor], + mask: torch.tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + templ_dim: int, + chunk_size: int, + block_size: int, + return_mean: bool, + ): + + def one_template(i): + (s, ) = checkpoint_sequential( + functions=[ + partial( + b, + mask=mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size, + ) for b in self.blocks + ], + input=(single_templates[i], ), + ) + return s + + n_templ = len(single_templates) + if n_templ > 0: + new_single_templates = [one_template(0)] + if return_mean: + t = self.layer_norm(new_single_templates[0]) + for i in range(1, n_templ): + s = one_template(i) + if return_mean: + t = residual(t, self.layer_norm(s), self.training) + else: + new_single_templates.append(s) + + if return_mean: + if n_templ > 0: + t /= n_templ + else: + t = None + else: + t = torch.cat( + [s.unsqueeze(templ_dim) for s in new_single_templates], + dim=templ_dim) + t = self.layer_norm(t) + + return t diff --git a/modelscope/models/science/unifold/modules/triangle_multiplication.py b/modelscope/models/science/unifold/modules/triangle_multiplication.py new file mode 100644 index 00000000..c4094cd2 --- /dev/null +++ b/modelscope/models/science/unifold/modules/triangle_multiplication.py @@ -0,0 +1,158 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import permute_final_dims + +from .common import Linear + + +class TriangleMultiplication(nn.Module): + + def __init__(self, d_pair, d_hid, outgoing=True): + super(TriangleMultiplication, self).__init__() + self.outgoing = outgoing + + self.linear_ab_p = Linear(d_pair, d_hid * 2) + self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating') + + self.linear_g = Linear(d_pair, d_pair, init='gating') + self.linear_z = Linear(d_hid, d_pair, init='final') + + self.layer_norm_in = LayerNorm(d_pair) + self.layer_norm_out = LayerNorm(d_hid) + + self._alphafold_original_mode = False + + def _chunk_2d( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size: int = None, + ) -> torch.Tensor: + + # avoid too small chunk size + # block_size = max(block_size, 256) + new_z = z.new_zeros(z.shape) + dim1 = z.shape[-3] + + def _slice_linear(z, linear: Linear, a=True): + d_hid = linear.bias.shape[0] // 2 + index = 0 if a else d_hid + p = ( + nn.functional.linear(z, linear.weight[index:index + d_hid]) + + linear.bias[index:index + d_hid]) + return p + + def _chunk_projection(z, mask, a=True): + p = _slice_linear(z, self.linear_ab_p, a) * mask + p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a)) + return p + + num_chunk = (dim1 + block_size - 1) // block_size + for i in range(num_chunk): + chunk_start = i * block_size + chunk_end = min(chunk_start + block_size, dim1) + if self.outgoing: + a_chunk = _chunk_projection( + z[..., chunk_start:chunk_end, :, :], + mask[..., chunk_start:chunk_end, :, :], + a=True, + ) + a_chunk = permute_final_dims(a_chunk, (2, 0, 1)) + else: + a_chunk = _chunk_projection( + z[..., :, chunk_start:chunk_end, :], + mask[..., :, chunk_start:chunk_end, :], + a=True, + ) + a_chunk = a_chunk.transpose(-1, -3) + + for j in range(num_chunk): + j_chunk_start = j * block_size + j_chunk_end = min(j_chunk_start + block_size, dim1) + if self.outgoing: + b_chunk = _chunk_projection( + z[..., j_chunk_start:j_chunk_end, :, :], + mask[..., j_chunk_start:j_chunk_end, :, :], + a=False, + ) + b_chunk = b_chunk.transpose(-1, -3) + else: + b_chunk = _chunk_projection( + z[..., :, j_chunk_start:j_chunk_end, :], + mask[..., :, j_chunk_start:j_chunk_end, :], + a=False, + ) + b_chunk = permute_final_dims(b_chunk, (2, 0, 1)) + x_chunk = torch.matmul(a_chunk, b_chunk) + del b_chunk + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + x_chunk *= torch.sigmoid( + self.linear_g(z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :])) + new_z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :] = x_chunk + del x_chunk + del a_chunk + return new_z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size=None, + ) -> torch.Tensor: + + mask = mask.unsqueeze(-1) + if not self._alphafold_original_mode: + # divided by 1/sqrt(dim) for numerical stability + mask = mask * (mask.shape[-2]**-0.5) + + z = self.layer_norm_in(z) + if not self.training and block_size is not None: + return self._chunk_2d(z, mask, block_size=block_size) + + g = nn.functional.linear(z, self.linear_g.weight) + if self.training: + ab = self.linear_ab_p(z) * mask * torch.sigmoid( + self.linear_ab_g(z)) + else: + ab = self.linear_ab_p(z) + ab *= mask + ab *= torch.sigmoid(self.linear_ab_g(z)) + a, b = torch.chunk(ab, 2, dim=-1) + del z, ab + + if self.outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = b.transpose(-1, -3) + else: + b = permute_final_dims(b, (2, 0, 1)) + a = a.transpose(-1, -3) + x = torch.matmul(a, b) + del a, b + + x = permute_final_dims(x, (1, 2, 0)) + + x = self.layer_norm_out(x) + x = nn.functional.linear(x, self.linear_z.weight) + return x, g + + def get_output_bias(self): + return self.linear_z.bias, self.linear_g.bias + + +class TriangleMultiplicationOutgoing(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True) + + +class TriangleMultiplicationIncoming(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False) diff --git a/modelscope/models/science/unifold/msa/__init__.py b/modelscope/models/science/unifold/msa/__init__.py new file mode 100644 index 00000000..2121062c --- /dev/null +++ b/modelscope/models/science/unifold/msa/__init__.py @@ -0,0 +1 @@ +""" Scripts for MSA & template searching. """ diff --git a/modelscope/models/science/unifold/msa/mmcif.py b/modelscope/models/science/unifold/msa/mmcif.py new file mode 100644 index 00000000..cf67239f --- /dev/null +++ b/modelscope/models/science/unifold/msa/mmcif.py @@ -0,0 +1,483 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import functools +import io +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData +from Bio.PDB.MMCIFParser import MMCIFParser + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + mmcif_to_author_chain_id: Mapping[ChainId, ChainId] + valid_chains: Mapping[ChainId, str] + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list(prefix: str, + parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([ + len(xs) == len(data[0]) for xs in data + ]), ('mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict( + prefix: str, + index: str, + parsed_info: MmCIFDict, +) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +@functools.lru_cache(16, typed=False) +def fast_parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = MMCIFParser(QUIET=True) + # handle = io.StringIO(mmcif_string) + # full_structure = parser.get_structure('', handle) + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + + mmcif_to_author_chain_id = {} + # seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=None, + chain_to_seqres=None, + seqres_to_structure=None, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +@functools.lru_cache(16, typed=False) +def parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = { + chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items() + } + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). + mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. + if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition( + chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code, + ) + seq_idx = int( + atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get( + atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition( + position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag, + ) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition( + position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join( + [experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. + if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', + parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ( + '_refine.ls_d_res_high', + '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high', + ): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.debug('Invalid resolution format: %s', + parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [ + AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + ) + ] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, + Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer( + id=entity_poly_seq['_entity_poly_seq.mon_id'], + num=int(entity_poly_seq['_entity_poly_seq.num']), + )) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', + parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. + if any([ + 'peptide' in chem_comps[monomer.id]['_chem_comp.type'] + for monomer in seq_info + ]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/modelscope/models/science/unifold/msa/msa_identifiers.py b/modelscope/models/science/unifold/msa/msa_identifiers.py new file mode 100644 index 00000000..366239db --- /dev/null +++ b/modelscope/models/science/unifold/msa/msa_identifiers.py @@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re +from typing import Optional + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. + (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. + (?P([A-Za-z0-9]){1,5}) + # Small BFD uses a final value after an underscore, which we ignore. + (?:_\d+)? + $ + """, + re.VERBOSE, +) + + +@dataclasses.dataclass(frozen=True) +class Identifiers: + species_id: str = '' + + +def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: + """Gets accession id and species from an msa sequence identifier. + + The sequence identifier has the format specified by + _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. + An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` + + Args: + msa_sequence_identifier: a sequence identifier. + + Returns: + An `Identifiers` instance with a species_id. These + can be empty in the case where no identifier was found. + """ + matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) + if matches: + return Identifiers(species_id=matches.group('SpeciesIdentifier')) + return Identifiers() + + +def _extract_sequence_identifier(description: str) -> Optional[str]: + """Extracts sequence identifier from description. Returns None if no match.""" + split_description = description.split() + if split_description: + return split_description[0].partition('/')[0] + else: + return None + + +def get_identifiers(description: str) -> Identifiers: + """Computes extra MSA features from the description.""" + sequence_identifier = _extract_sequence_identifier(description) + if sequence_identifier is None: + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) diff --git a/modelscope/models/science/unifold/msa/parsers.py b/modelscope/models/science/unifold/msa/parsers.py new file mode 100644 index 00000000..bf36c816 --- /dev/null +++ b/modelscope/models/science/unifold/msa/parsers.py @@ -0,0 +1,627 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import itertools +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +DeletionMatrix = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class Msa: + """Class representing a parsed MSA file.""" + + sequences: Sequence[str] + deletion_matrix: DeletionMatrix + descriptions: Sequence[str] + + def __post_init__(self): + if not (len(self.sequences) == len(self.deletion_matrix) == len( + self.descriptions)): + raise ValueError( + 'All fields for an MSA must have the same length. ' + f'Got {len(self.sequences)} sequences, ' + f'{len(self.deletion_matrix)} rows in the deletion matrix and ' + f'{len(self.descriptions)} descriptions.') + + def __len__(self): + return len(self.sequences) + + def truncate(self, max_seqs: int): + return Msa( + sequences=self.sequences[:max_seqs], + deletion_matrix=self.deletion_matrix[:max_seqs], + descriptions=self.descriptions[:max_seqs], + ) + + +@dataclasses.dataclass(frozen=True) +class TemplateHit: + """Class representing a template hit.""" + + index: int + name: str + aligned_cols: int + sum_probs: Optional[float] + query: str + hit_sequence: str + indices_query: List[int] + indices_hit: List[int] + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def parse_stockholm(stockholm_string: str) -> Msa: + """Parses sequences and deletion matrix from stockholm format alignment. + + Args: + stockholm_string: The string contents of a stockholm file. The first + sequence in the file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * The names of the targets matched, including the jackhmmer subsequence + suffix. + """ + name_to_sequence = collections.OrderedDict() + for line in stockholm_string.splitlines(): + line = line.strip() + if not line or line.startswith(('#', '//')): + continue + name, sequence = line.split() + if name not in name_to_sequence: + name_to_sequence[name] = '' + name_to_sequence[name] += sequence + + msa = [] + deletion_matrix = [] + + query = '' + keep_columns = [] + for seq_index, sequence in enumerate(name_to_sequence.values()): + if seq_index == 0: + # Gather the columns with gaps from the query + query = sequence + keep_columns = [i for i, res in enumerate(query) if res != '-'] + + # Remove the columns with gaps in the query from all sequences. + aligned_sequence = ''.join([sequence[c] for c in keep_columns]) + + msa.append(aligned_sequence) + + # Count the number of deletions w.r.t. query. + deletion_vec = [] + deletion_count = 0 + for seq_res, query_res in zip(sequence, query): + if seq_res != '-' or query_res != '-': + if query_res == '-': + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + return Msa( + sequences=msa, + deletion_matrix=deletion_matrix, + descriptions=list(name_to_sequence.keys()), + ) + + +def parse_a3m(a3m_string: str) -> Msa: + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * A list of descriptions, one per sequence, from the a3m file. + """ + sequences, descriptions = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. + deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return Msa( + sequences=aligned_sequences, + deletion_matrix=deletion_matrix, + descriptions=descriptions, + ) + + +def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool], + sto_seq: str) -> Iterable[str]: + for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): + if is_query_res_non_gap: + yield sequence_res + elif sequence_res != '-': + yield sequence_res.lower() + + +def convert_stockholm_to_a3m( + stockholm_format: str, + max_sequences: Optional[int] = None, + remove_first_row_gaps: bool = True, +) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + for line in stockholm_format.splitlines(): + reached_max_sequences = max_sequences and len( + sequences) >= max_sequences + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + for line in stockholm_format.splitlines(): + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + if remove_first_row_gaps: + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + # Dots are optional in a3m format and are commonly removed. + out_sequence = sto_sequence.replace('.', '') + if remove_first_row_gaps: + out_sequence = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) + a3m_sequences[seqname] = out_sequence + + fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences) + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. + + +def _keep_line(line: str, seqnames: Set[str]) -> bool: + """Function to decide which lines to keep.""" + if not line.strip(): + return True + if line.strip() == '//': # End tag + return True + if line.startswith('# STOCKHOLM'): # Start tag + return True + if line.startswith('#=GC RF'): # Reference Annotation Line + return True + if line[:4] == '#=GS': # Description lines - keep if sequence in list. + _, seqname, _ = line.split(maxsplit=2) + return seqname in seqnames + elif line.startswith('#'): # Other markup - filter out + return False + else: # Alignment data - keep if sequence in list. + seqname = line.partition(' ')[0] + return seqname in seqnames + + +def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: + """Truncates a stockholm file to a maximum number of sequences.""" + seqnames = set() + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname = line.partition(' ')[0] + seqnames.add(seqname) + if len(seqnames) >= max_sequences: + break + + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: + """Removes empty columns (dashes-only) from a Stockholm MSA.""" + processed_lines = {} + unprocessed_lines = {} + for i, line in enumerate(stockholm_msa.splitlines()): + if line.startswith('#=GC RF'): + reference_annotation_i = i + reference_annotation_line = line + # Reached the end of this chunk of the alignment. Process chunk. + _, _, first_alignment = line.rpartition(' ') + mask = [] + for j in range(len(first_alignment)): + for _, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + if alignment[j] != '-': + mask.append(True) + break + else: # Every row contained a hyphen - empty column. + mask.append(False) + # Add reference annotation for processing with mask. + unprocessed_lines[ + reference_annotation_i] = reference_annotation_line + + if not any( + mask + ): # All columns were empty. Output empty lines for chunk. + for line_index in unprocessed_lines: + processed_lines[line_index] = '' + else: + for line_index, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + masked_alignment = ''.join( + itertools.compress(alignment, mask)) + processed_lines[ + line_index] = f'{prefix} {masked_alignment}' + + # Clear raw_alignments. + unprocessed_lines = {} + elif line.strip() and not line.startswith(('#', '//')): + unprocessed_lines[i] = line + else: + processed_lines[i] = line + return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) + + +def deduplicate_stockholm_msa(stockholm_msa: str) -> str: + """Remove duplicate sequences (ignoring insertions wrt query).""" + sequence_dict = collections.defaultdict(str) + + # First we must extract all sequences from the MSA. + for line in stockholm_msa.splitlines(): + # Only consider the alignments - ignore reference annotation, empty lines, + # descriptions or markup. + if line.strip() and not line.startswith(('#', '//')): + line = line.strip() + seqname, alignment = line.split() + sequence_dict[seqname] += alignment + + seen_sequences = set() + seqnames = set() + # First alignment is the query. + query_align = next(iter(sequence_dict.values())) + mask = [c != '-' for c in query_align] # Mask is False for insertions. + for seqname, alignment in sequence_dict.items(): + # Apply mask to remove all insertions from the string. + masked_alignment = ''.join(itertools.compress(alignment, mask)) + if masked_alignment in seen_sequences: + continue + else: + seen_sequences.add(masked_alignment) + seqnames.add(seqname) + + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def _get_hhr_line_regex_groups(regex_pattern: str, + line: str) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def _update_hhr_residue_indices_list(sequence: str, start_index: int, + indices_list: List[int]): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' + % (detailed_lines, detailed_lines[2])) + (_, _, _, aligned_cols, _, _, sum_probs, + _) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = '' + hit_sequence = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') + and not line.startswith('Q ss_pred') + and not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit sequence. + if (not line.startswith('T ss_dssp') + and not line.startswith('T ss_pred') + and not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list(delta_hit_sequence, start, + indices_hit) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [ + i for i, line in enumerate(lines) if line.startswith('No ') + ] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append( + _parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: + """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" + e_values = {'query': 0} + lines = [line for line in tblout.splitlines() if line[0] != '#'] + # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are + # space-delimited. Relevant fields are (1) target name: and + # (5) E-value (full sequence) (numbering from 1). + for line in lines: + fields = line.split() + e_value = fields[4] + target_name = fields[0] + e_values[target_name] = float(e_value) + return e_values + + +def _get_indices(sequence: str, start: int) -> List[int]: + """Returns indices for non-gap/insert residues starting at the given index.""" + indices = [] + counter = start + for symbol in sequence: + # Skip gaps but add a placeholder so that the alignment is preserved. + if symbol == '-': + indices.append(-1) + # Skip deleted residues, but increase the counter. + elif symbol.islower(): + counter += 1 + # Normal aligned residue. Increase the counter and append to indices. + else: + indices.append(counter) + counter += 1 + return indices + + +@dataclasses.dataclass(frozen=True) +class HitMetadata: + pdb_id: str + chain: str + start: int + end: int + length: int + text: str + + +def _parse_hmmsearch_description(description: str) -> HitMetadata: + """Parses the hmmsearch A3M sequence description line.""" + # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 + match = re.match( + r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', + description.strip(), + ) + + if not match: + raise ValueError(f'Could not parse description: "{description}".') + + return HitMetadata( + pdb_id=match[1], + chain=match[2], + start=int(match[3]), + end=int(match[4]), + length=int(match[5]), + text=match[6], + ) + + +def parse_hmmsearch_a3m(query_sequence: str, + a3m_string: str, + skip_first: bool = True) -> Sequence[TemplateHit]: + """Parses an a3m string produced by hmmsearch. + + Args: + query_sequence: The query sequence. + a3m_string: The a3m string produced by hmmsearch. + skip_first: Whether to skip the first sequence in the a3m string. + + Returns: + A sequence of `TemplateHit` results. + """ + # Zip the descriptions and MSAs together, skip the first query sequence. + parsed_a3m = list(zip(*parse_fasta(a3m_string))) + if skip_first: + parsed_a3m = parsed_a3m[1:] + + indices_query = _get_indices(query_sequence, start=0) + + hits = [] + for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): + if 'mol:protein' not in hit_description: + continue # Skip non-protein chains. + metadata = _parse_hmmsearch_description(hit_description) + # Aligned columns are only the match states. + aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) + indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) + + hit = TemplateHit( + index=i, + name=f'{metadata.pdb_id}_{metadata.chain}', + aligned_cols=aligned_cols, + sum_probs=None, + query=query_sequence, + hit_sequence=hit_sequence.upper(), + indices_query=indices_query, + indices_hit=indices_hit, + ) + hits.append(hit) + + return hits diff --git a/modelscope/models/science/unifold/msa/pipeline.py b/modelscope/models/science/unifold/msa/pipeline.py new file mode 100644 index 00000000..b7889bff --- /dev/null +++ b/modelscope/models/science/unifold/msa/pipeline.py @@ -0,0 +1,282 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions for building the input features for the unifold model.""" + +import os +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import (msa_identifiers, parsers, + templates) +from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch, + hmmsearch, jackhmmer) + +FeatureDict = MutableMapping[str, np.ndarray] +TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] + + +def make_sequence_features(sequence: str, description: str, + num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + features['between_segment_residues'] = np.zeros((num_res, ), + dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], + dtype=np.object_) + return features + + +def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + species_ids = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError( + f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa.sequences): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(msa.deletion_matrix[sequence_index]) + identifiers = msa_identifiers.get_identifiers( + msa.descriptions[sequence_index]) + species_ids.append(identifiers.species_id.encode('utf-8')) + + num_res = len(msas[0].sequences[0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + features['msa_species_identifiers'] = np.array( + species_ids, dtype=np.object_) + return features + + +def run_msa_tool( + msa_runner, + input_fasta_path: str, + msa_out_path: str, + msa_format: str, + use_precomputed_msas: bool, +) -> Mapping[str, Any]: + """Runs an MSA tool, checking if output already exists first.""" + if not use_precomputed_msas or not os.path.exists(msa_out_path): + result = msa_runner.query(input_fasta_path)[0] + with open(msa_out_path, 'w') as f: + f.write(result[msa_format]) + else: + logging.warning('Reading MSA from file %s', msa_out_path) + with open(msa_out_path, 'r') as f: + result = {msa_format: f.read()} + return result + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__( + self, + jackhmmer_binary_path: str, + hhblits_binary_path: str, + uniref90_database_path: str, + mgnify_database_path: str, + bfd_database_path: Optional[str], + uniclust30_database_path: Optional[str], + small_bfd_database_path: Optional[str], + uniprot_database_path: Optional[str], + template_searcher: TemplateSearcher, + template_featurizer: templates.TemplateHitFeaturizer, + use_small_bfd: bool, + mgnify_max_hits: int = 501, + uniref_max_hits: int = 10000, + use_precomputed_msas: bool = False, + ): + """Initializes the data pipeline.""" + self._use_small_bfd = use_small_bfd + self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniref90_database_path) + if use_small_bfd: + self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=small_bfd_database_path) + else: + self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( + binary_path=hhblits_binary_path, + databases=[bfd_database_path, uniclust30_database_path], + ) + self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=mgnify_database_path) + self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self.template_searcher = template_searcher + self.template_featurizer = template_featurizer + self.mgnify_max_hits = mgnify_max_hits + self.uniref_max_hits = uniref_max_hits + self.use_precomputed_msas = use_precomputed_msas + + def process(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + """Runs alignment tools on the input sequence and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError( + f'More than one input sequence found in {input_fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + num_res = len(input_sequence) + + uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') + jackhmmer_uniref90_result = run_msa_tool( + self.jackhmmer_uniref90_runner, + input_fasta_path, + uniref90_out_path, + 'sto', + self.use_precomputed_msas, + ) + mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') + jackhmmer_mgnify_result = run_msa_tool( + self.jackhmmer_mgnify_runner, + input_fasta_path, + mgnify_out_path, + 'sto', + self.use_precomputed_msas, + ) + + msa_for_templates = jackhmmer_uniref90_result['sto'] + msa_for_templates = parsers.truncate_stockholm_msa( + msa_for_templates, max_sequences=self.uniref_max_hits) + msa_for_templates = parsers.deduplicate_stockholm_msa( + msa_for_templates) + msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( + msa_for_templates) + + if self.template_searcher.input_format == 'sto': + pdb_templates_result = self.template_searcher.query( + msa_for_templates) + elif self.template_searcher.input_format == 'a3m': + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( + msa_for_templates) + pdb_templates_result = self.template_searcher.query( + uniref90_msa_as_a3m) + else: + raise ValueError('Unrecognized template input format: ' + f'{self.template_searcher.input_format}') + + pdb_hits_out_path = os.path.join( + msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') + with open(pdb_hits_out_path, 'w') as f: + f.write(pdb_templates_result) + + uniref90_msa = parsers.parse_stockholm( + jackhmmer_uniref90_result['sto']) + uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) + mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) + + pdb_template_hits = self.template_searcher.get_template_hits( + output_string=pdb_templates_result, input_sequence=input_sequence) + + if self._use_small_bfd: + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') + jackhmmer_small_bfd_result = run_msa_tool( + self.jackhmmer_small_bfd_runner, + input_fasta_path, + bfd_out_path, + 'sto', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_stockholm( + jackhmmer_small_bfd_result['sto']) + else: + bfd_out_path = os.path.join(msa_output_dir, + 'bfd_uniclust_hits.a3m') + hhblits_bfd_uniclust_result = run_msa_tool( + self.hhblits_bfd_uniclust_runner, + input_fasta_path, + bfd_out_path, + 'a3m', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + + templates_result = self.template_featurizer.get_templates( + query_sequence=input_sequence, hits=pdb_template_hits) + + sequence_features = make_sequence_features( + sequence=input_sequence, + description=input_description, + num_res=num_res) + + msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) + + logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) + logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) + logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) + logging.info( + 'Final (deduplicated) MSA size: %d sequences.', + msa_features['num_alignments'][0], + ) + logging.info( + 'Total number of templates (NB: this can include bad ' + 'templates and is later filtered to top 4): %d.', + templates_result.features['template_domain_names'].shape[0], + ) + + return { + **sequence_features, + **msa_features, + **templates_result.features + } + + def process_uniprot(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + uniprot_result = run_msa_tool( + self.jackhmmer_uniprot_runner, + input_fasta_path, + uniprot_path, + 'sto', + self.use_precomputed_msas, + ) + msa = parsers.parse_stockholm(uniprot_result['sto']) + msa = msa.truncate(max_seqs=50000) + all_seq_dict = make_msa_features([msa]) + return all_seq_dict diff --git a/modelscope/models/science/unifold/msa/templates.py b/modelscope/models/science/unifold/msa/templates.py new file mode 100644 index 00000000..fe3bcef9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/templates.py @@ -0,0 +1,1110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions for getting templates and calculating template features.""" +import abc +import dataclasses +import datetime +import functools +import glob +import os +import re +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import mmcif, parsers +from modelscope.models.science.unifold.msa.tools import kalign + + +class Error(Exception): + """Base class for exceptions.""" + + +class NoChainsError(Error): + """An error indicating that template mmCIF didn't have any chains.""" + + +class SequenceNotInTemplateError(Error): + """An error indicating that template mmCIF didn't contain the sequence.""" + + +class NoAtomDataInTemplateError(Error): + """An error indicating that template mmCIF didn't contain atom positions.""" + + +class TemplateAtomMaskAllZerosError(Error): + """An error indicating that template mmCIF had all atom positions masked.""" + + +class QueryToTemplateAlignError(Error): + """An error indicating that the query can't be aligned to the template.""" + + +class CaDistanceError(Error): + """An error indicating that a CA atom distance exceeds a threshold.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" + + +# Prefilter exceptions. +class PrefilterError(Exception): + """A base class for template prefilter exceptions.""" + + +class DateError(PrefilterError): + """An error indicating that the hit date was after the max allowed date.""" + + +class AlignRatioError(PrefilterError): + """An error indicating that the hit align ratio to the query was too small.""" + + +class DuplicateError(PrefilterError): + """An error indicating that the hit was an exact subsequence of the query.""" + + +class LengthError(PrefilterError): + """An error indicating that the hit was too short.""" + + +TEMPLATE_FEATURES = { + 'template_aatype': np.float32, + 'template_all_atom_mask': np.float32, + 'template_all_atom_positions': np.float32, + 'template_domain_names': np.object_, + 'template_sequence': np.object_, + 'template_sum_probs': np.float32, +} + + +def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: + """Returns PDB id and chain id for an HHSearch Hit.""" + # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. + id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name) + if not id_match: + raise ValueError( + f'hit.name did not start with PDBID_chain: {hit.name}') + pdb_id, chain_id = id_match.group(0).split('_') + return pdb_id.lower(), chain_id + + +def _is_after_cutoff( + pdb_id: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: Optional[datetime.datetime], +) -> bool: + """Checks if the template date is after the release date cutoff. + + Args: + pdb_id: 4 letter pdb code. + release_dates: Dictionary mapping PDB ids to their structure release dates. + release_date_cutoff: Max release date that is valid for this query. + + Returns: + True if the template release date is after the cutoff, False otherwise. + """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + else: + # Since this is just a quick prefilter to reduce the number of mmCIF files + # we need to parse, we don't have to worry about returning True here. + return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]: + """Parses the data file from PDB that lists which pdb_ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # Format: Date From To + # 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare + # 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common + # 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare + if line.startswith('OBSLTE'): + if len(line) > 30: + # Replaced by at least one structure. + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + elif len(line) == 24: + # Removed. + from_id = line[20:24].lower() + result[from_id] = None + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about + # 90x faster than strptime. However, splitting the string manually is + # about 10x faster than strptime. + release_dates[pdb_id.strip()] = datetime.datetime( + year=int(date[:4]), + month=int(date[5:7]), + day=int(date[8:10])) + return release_dates + else: + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit: parsers.TemplateHit, + hit_pdb_code: str, + query_sequence: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: datetime.datetime, + max_subsequence_ratio: float = 0.95, + min_align_ratio: float = 0.1, +) -> bool: + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. + duplicate = ( + template_sequence in query_sequence + and length_ratio > max_subsequence_ratio) + + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError( + f'Date ({release_dates[hit_pdb_code]}) > max template date ' + f'({release_date_cutoff}).') + + if align_ratio <= min_align_ratio: + raise AlignRatioError( + 'Proportion of residues aligned to query too small. ' + f'Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError( + 'Template is an exact subsequence of query with large ' + f'coverage. Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError( + f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb( + template_chain_id: str, template_sequence: str, + mmcif_object: mmcif.MmcifObject) -> Tuple[str, str, int]: + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found an exact template match %s_%s.', pdb_id, + template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found a sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, + mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence: str, + template_chain_id: str, + mmcif_object: mmcif.MmcifObject, + old_mapping: Mapping[int, int], + kalign_binary_path: str, +) -> Tuple[str, Mapping[int, int]]: + """Aligns template from the mmcif_object to the query. + + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = kalign.Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get( + template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info( + 'Could not find %s in %s, but there is only 1 sequence, so ' + 'using that one.', + template_chain_id, + mmcif_object.file_id, + ) + new_template_sequence = list( + mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + parsed_a3m = parsers.parse_a3m( + aligner.align([old_template_sequence, new_template_sequence])) + old_aligned_template, new_aligned_template = parsed_a3m.sequences + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). Error: %s' + % ( + old_template_sequence, + new_template_sequence, + mmcif_object.file_id, + template_chain_id, + str(e), + )) + + logging.info( + 'Old aligned template: %s\nNew aligned template: %s', + old_aligned_template, + new_aligned_template, + ) + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip(old_aligned_template, + new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[ + old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if (float(num_same) + / min(len(old_template_sequence), len(new_template_sequence)) + < # noqa W504 + 0.9): + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' + % ( + old_template_sequence, + mmcif_object.file_id, + template_chain_id, + new_template_sequence, + )) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[ + query_index] = old_to_new_template_mapping.get( + old_template_index, -1) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError( + 'The distance between residues %d and %d is %f > limit %f.' + % (i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object: mmcif.MmcifObject, auth_chain_id: str, + max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [ + c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id + ] + if len(relevant_chains) != 1: + raise MultipleChainsError( + f'Expected exactly one chain in structure with id {auth_chain_id}.' + ) + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], + dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][ + res_index] + if not res_at_position.is_missing: + res = chain[( + res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code, + )] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + # Fix naming errors in arginine residues where NH2 is incorrectly + # assigned to be closer to CD than NH1. + cd = residue_constants.atom_order['CD'] + nh1 = residue_constants.atom_order['NH1'] + nh2 = residue_constants.atom_order['NH2'] + if (res.get_resname() == 'ARG' + and all(mask[atom_index] for atom_index in (cd, nh1, nh2)) + and (np.linalg.norm(pos[nh1] - pos[cd]) > # noqa W504 + np.linalg.norm(pos[nh2] - pos[cd]))): + pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy() + mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy() + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances(all_positions, all_positions_mask, + max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object: mmcif.MmcifObject, + pdb_id: str, + mapping: Mapping[int, int], + template_sequence: str, + query_sequence: str, + template_chain_id: str, + kalign_binary_path: str, +) -> Tuple[Dict[str, Any], Optional[str]]: + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. + template_chain_id: String ID describing which chain in the structure proto + should be used. + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. + NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % + (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object, + ) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = ( + f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.' + ) + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path, + ) + logging.info( + 'Sequence in %s_%s: %s successfully realigned to %s', + pdb_id, + chain_id, + template_sequence, + seqres, + ) + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. + all_atom_positions, all_atom_mask = _get_atom_positions( + mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError('Could not get atom data (%s_%s): %s' % + (pdb_id, chain_id, str(ex))) from ex + + all_atom_positions = np.split(all_atom_positions, + all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append( + np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append( + np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError( + 'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' + % ( + pdb_id, + chain_id, + min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset, + )) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + { + 'template_all_atom_positions': + np.array(templates_all_atom_positions), + 'template_all_atom_mask': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + }, + warning, + ) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str, +) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. + """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find( + hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [ + x - min_idx if x > -1 else -1 for x in indices_query + ] + + # Zip the corrected indices, ignore case where both seqs have gap characters. + mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len( + original_query_sequence): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +@functools.lru_cache(16, typed=False) +def _read_file(path): + with open(path, 'r') as f: + file_data = f.read() + return file_data + + +def _process_single_hit( + query_sequence: str, + hit: parsers.TemplateHit, + mmcif_dir: str, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, Optional[str]], + kalign_binary_path: str, + strict_error_check: bool = False, +) -> SingleHitResult: + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + # This hit has been removed (obsoleted) from PDB, skip it. + if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None: + return SingleHitResult( + features=None, + error=None, + warning=f'Hit {hit_pdb_code} is obsolete.') + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + release_dates=release_dates, + release_date_cutoff=max_template_date, + ) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info(msg) + if strict_error_check and isinstance(e, (DateError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping(hit.query, hit.hit_sequence, + hit.indices_hit, + hit.indices_query, + query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + logging.debug( + 'Reading PDB entry from %s. Query: %s, template: %s', + cif_path, + query_sequence, + template_sequence, + ) + # Fail if we can't find the mmCIF file. + cif_string = _read_file(cif_path) + + parsing_result = mmcif.parse(file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime( + parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = 'Template %s date (%s) > max template date (%s).' % ( + hit_pdb_code, + hit_release_date, + max_template_date, + ) + if strict_error_check: + return SingleHitResult( + features=None, error=error, warning=None) + else: + logging.debug(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + kalign_binary_path=kalign_binary_path, + ) + if hit.sum_probs is None: + features['template_sum_probs'] = [0] + else: + features['template_sum_probs'] = [hit.sum_probs] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult( + features=features, error=None, warning=realign_warning) + except ( + NoChainsError, + NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError, + ) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ( + '%s_%s (sum_probs: %s, rank: %s): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + else: + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ( + '%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer(abc.ABC): + """An abstract base class for turning template hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False, + ): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. + release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. + * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime( + max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError( + 'max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', + release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', + obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + @abc.abstractmethod + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence.""" + + +class HhsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hhsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + # Increment the hit counter, since we got features out of this hit. + num_hits += 1 + for k in template_features: + template_features[k].append(result.features[k]) + + for name in template_features: + if num_hits > 0: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + # Make sure the feature has correct dtype even if empty. + template_features[name] = np.array( + [], dtype=TEMPLATE_FEATURES[name]) + + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) + + +class HmmsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hmmsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + already_seen = set() + errors = [] + warnings = [] + + if not hits or hits[0].sum_probs is None: + sorted_hits = hits + else: + sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True) + + for hit in sorted_hits: + # We got all the templates we wanted, stop processing hits. + if len(already_seen) >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.debug( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + already_seen_key = result.features['template_sequence'] + if already_seen_key in already_seen: + continue + # Increment the hit counter, since we got features out of this hit. + already_seen.add(already_seen_key) + for k in template_features: + template_features[k].append(result.features[k]) + + if already_seen: + for name in template_features: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + num_res = len(query_sequence) + # Construct a default template with all zeros. + template_features = { + 'template_aatype': + np.zeros( + (1, num_res, len( + residue_constants.restypes_with_x_and_gap)), + np.float32, + ), + 'template_all_atom_mask': + np.zeros((1, num_res, residue_constants.atom_type_num), + np.float32), + 'template_all_atom_positions': + np.zeros((1, num_res, residue_constants.atom_type_num, 3), + np.float32), + 'template_domain_names': + np.array([''.encode()], dtype=np.object), + 'template_sequence': + np.array([''.encode()], dtype=np.object), + 'template_sum_probs': + np.array([0], dtype=np.float32), + } + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) diff --git a/modelscope/models/science/unifold/msa/tools/__init__.py b/modelscope/models/science/unifold/msa/tools/__init__.py new file mode 100644 index 00000000..903d0979 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python wrappers for third party tools.""" diff --git a/modelscope/models/science/unifold/msa/tools/hhblits.py b/modelscope/models/science/unifold/msa/tools/hhblits.py new file mode 100644 index 00000000..ee442e39 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhblits.py @@ -0,0 +1,170 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHblits from Python.""" + +import glob +import os +import subprocess +from typing import Any, List, Mapping, Optional, Sequence + +from absl import logging + +from . import utils + +_HHBLITS_DEFAULT_P = 20 +_HHBLITS_DEFAULT_Z = 500 + + +class HHBlits: + """Python wrapper of the HHblits binary.""" + + def __init__( + self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 4, + n_iter: int = 3, + e_value: float = 0.001, + maxseq: int = 1_000_000, + realign_max: int = 100_000, + maxfilt: int = 100_000, + min_prefilter_hits: int = 1000, + all_seqs: bool = False, + alt: Optional[int] = None, + p: int = _HHBLITS_DEFAULT_P, + z: int = _HHBLITS_DEFAULT_Z, + ): + """Initializes the Python HHblits wrapper. + + Args: + binary_path: The path to the HHblits executable. + databases: A sequence of HHblits database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to give HHblits. + n_iter: The number of HHblits iterations. + e_value: The E-value, see HHblits docs for more details. + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. + maxfilt: Max number of hits allowed to pass the 2nd prefilter. + HHblits default: 20000. + min_prefilter_hits: Min number of hits to pass prefilter. + HHblits default: 100. + all_seqs: Return all sequences in the MSA / Do not filter the result MSA. + HHblits default: False. + alt: Show up to this many alternative alignments. + p: Minimum Prob for a hit to be included in the output hhr file. + HHblits default: 20. + z: Hard cap on number of hits reported in the hhr file. + HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + + Raises: + RuntimeError: If HHblits binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHBlits database %s', + database_path) + raise ValueError( + f'Could not find HHBlits database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.maxseq = maxseq + self.realign_max = realign_max + self.maxfilt = maxfilt + self.min_prefilter_hits = min_prefilter_hits + self.all_seqs = all_seqs + self.alt = alt + self.p = p + self.z = z + + def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: + """Queries the database using HHblits.""" + with utils.tmpdir_manager() as query_tmp_dir: + a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-cpu', + str(self.n_cpu), + '-oa3m', + a3m_path, + '-o', + '/dev/null', + '-n', + str(self.n_iter), + '-e', + str(self.e_value), + '-maxseq', + str(self.maxseq), + '-realign_max', + str(self.realign_max), + '-maxfilt', + str(self.maxfilt), + '-min_prefilter_hits', + str(self.min_prefilter_hits), + ] + if self.all_seqs: + cmd += ['-all'] + if self.alt: + cmd += ['-alt', str(self.alt)] + if self.p != _HHBLITS_DEFAULT_P: + cmd += ['-p', str(self.p)] + if self.z != _HHBLITS_DEFAULT_Z: + cmd += ['-Z', str(self.z)] + cmd += db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('HHblits query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Logs have a 15k character limit, so log HHblits error line by line. + logging.error('HHblits failed. HHblits stderr begin:') + for error_line in stderr.decode('utf-8').splitlines(): + if error_line.strip(): + logging.error(error_line.strip()) + logging.error('HHblits stderr end') + raise RuntimeError( + 'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) + + with open(a3m_path) as f: + a3m = f.read() + + raw_output = dict( + a3m=a3m, + output=stdout, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value, + ) + return [raw_output] diff --git a/modelscope/models/science/unifold/msa/tools/hhsearch.py b/modelscope/models/science/unifold/msa/tools/hhsearch.py new file mode 100644 index 00000000..ac7f3b55 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhsearch.py @@ -0,0 +1,111 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHsearch from Python.""" + +import glob +import os +import subprocess +from typing import Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . import utils + + +class HHSearch: + """Python wrapper of the HHsearch binary.""" + + def __init__(self, + *, + binary_path: str, + databases: Sequence[str], + maxseq: int = 1_000_000): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHsearch database %s', + database_path) + raise ValueError( + f'Could not find HHsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'hhr' + + @property + def input_format(self) -> str: + return 'a3m' + + def query(self, a3m: str) -> str: + """Queries the database using HHsearch using a given a3m.""" + with utils.tmpdir_manager() as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, 'query.a3m') + hhr_path = os.path.join(query_tmp_dir, 'output.hhr') + with open(input_path, 'w') as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_path, + '-o', + hhr_path, + '-maxseq', + str(self.maxseq), + ] + db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing('HHsearch query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. + raise RuntimeError( + 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) + + with open(hhr_path) as f: + hhr = f.read() + return hhr + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + del input_sequence # Used by hmmseach but not needed for hhsearch. + return parsers.parse_hhr(output_string) diff --git a/modelscope/models/science/unifold/msa/tools/hmmbuild.py b/modelscope/models/science/unifold/msa/tools/hmmbuild.py new file mode 100644 index 00000000..84f205d6 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmbuild.py @@ -0,0 +1,143 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" + +import os +import re +import subprocess + +from absl import logging + +from . import utils + + +class Hmmbuild(object): + """Python wrapper of the hmmbuild binary.""" + + def __init__(self, *, binary_path: str, singlemx: bool = False): + """Initializes the Python hmmbuild wrapper. + + Args: + binary_path: The path to the hmmbuild executable. + singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to + just use a common substitution score matrix. + + Raises: + RuntimeError: If hmmbuild binary not found within the path. + """ + self.binary_path = binary_path + self.singlemx = singlemx + + def build_profile_from_sto(self, + sto: str, + model_construction='fast') -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + sto: A string with the aligned sequences in the Stockholm format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + return self._build_profile(sto, model_construction=model_construction) + + def build_profile_from_a3m(self, a3m: str) -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + a3m: A string with the aligned sequences in the A3M format. + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + lines = [] + for line in a3m.splitlines(): + if not line.startswith('>'): + line = re.sub('[a-z]+', '', line) # Remove inserted residues. + lines.append(line + '\n') + msa = ''.join(lines) + return self._build_profile(msa, model_construction='fast') + + def _build_profile(self, + msa: str, + model_construction: str = 'fast') -> str: + """Builds a HMM for the aligned sequences given as an MSA string. + + Args: + msa: A string with the aligned sequences, in A3M or STO format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + ValueError: If unspecified arguments are provided. + """ + if model_construction not in {'hand', 'fast'}: + raise ValueError( + f'Invalid model_construction {model_construction} - only' + 'hand and fast supported.') + + with utils.tmpdir_manager() as query_tmp_dir: + input_query = os.path.join(query_tmp_dir, 'query.msa') + output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') + + with open(input_query, 'w') as f: + f.write(msa) + + cmd = [self.binary_path] + # If adding flags, we have to do so before the output and input: + + if model_construction == 'hand': + cmd.append(f'--{model_construction}') + if self.singlemx: + cmd.append('--singlemx') + cmd.extend([ + '--amino', + output_hmm_path, + input_query, + ]) + + logging.info('Launching subprocess %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('hmmbuild query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'hmmbuild stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_hmm_path, encoding='utf-8') as f: + hmm = f.read() + + return hmm diff --git a/modelscope/models/science/unifold/msa/tools/hmmsearch.py b/modelscope/models/science/unifold/msa/tools/hmmsearch.py new file mode 100644 index 00000000..445970ca --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmsearch.py @@ -0,0 +1,146 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import subprocess +from typing import Optional, Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . import hmmbuild, utils + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + flags: Optional[Sequence[str]] = None, + ): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. + database_path: The path to the hmmsearch database (FASTA format). + flags: List of flags to be used by hmmsearch. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. + """ + self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild( + binary_path=hmmbuild_binary_path) + self.database_path = database_path + if flags is None: + # Default hmmsearch run settings. + flags = [ + '--F1', + '0.1', + '--F2', + '0.1', + '--F3', + '0.1', + '--incE', + '100', + '-E', + '100', + '--domE', + '100', + '--incdomE', + '100', + ] + self.flags = flags + + if not os.path.exists(self.database_path): + logging.error('Could not find hmmsearch database %s', + database_path) + raise ValueError( + f'Could not find hmmsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'sto' + + @property + def input_format(self) -> str: + return 'sto' + + def query(self, msa_sto: str) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto( + msa_sto, model_construction='hand') + return self.query_with_hmm(hmm) + + def query_with_hmm(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with utils.tmpdir_manager() as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + out_path = os.path.join(query_tmp_dir, 'output.sto') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. + '--cpu', + '8', + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + '-A', + out_path, + hmm_input_path, + self.database_path, + ]) + + logging.info('Launching sub-process %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'hmmsearch ({os.path.basename(self.database_path)}) query' + ): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(out_path) as f: + out_msa = f.read() + + return out_msa + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + a3m_string = parsers.convert_stockholm_to_a3m( + output_string, remove_first_row_gaps=False) + template_hits = parsers.parse_hmmsearch_a3m( + query_sequence=input_sequence, + a3m_string=a3m_string, + skip_first=False) + return template_hits diff --git a/modelscope/models/science/unifold/msa/tools/jackhmmer.py b/modelscope/models/science/unifold/msa/tools/jackhmmer.py new file mode 100644 index 00000000..3e29eec9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/jackhmmer.py @@ -0,0 +1,224 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run Jackhmmer from Python.""" + +import glob +import os +import subprocess +from concurrent import futures +from typing import Any, Callable, Mapping, Optional, Sequence +from urllib import request + +from absl import logging + +from . import utils + + +class Jackhmmer: + """Python wrapper of the Jackhmmer binary.""" + + def __init__( + self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 1, + e_value: float = 0.0001, + z_value: Optional[int] = None, + get_tblout: bool = False, + filter_f1: float = 0.0005, + filter_f2: float = 0.00005, + filter_f3: float = 0.0000005, + incdom_e: Optional[float] = None, + dom_e: Optional[float] = None, + num_streamed_chunks: Optional[int] = None, + streaming_callback: Optional[Callable[[int], None]] = None, + ): + """Initializes the Python Jackhmmer wrapper. + + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value, see Jackhmmer docs for more details. + get_tblout: Whether to save tblout string. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + dom_e: Domain e-value criteria for inclusion in tblout. + num_streamed_chunks: Number of database chunks to stream over. + streaming_callback: Callback function run after each chunk iteration with + the iteration number as argument. + """ + self.binary_path = binary_path + self.database_path = database_path + self.num_streamed_chunks = num_streamed_chunks + + if not os.path.exists( + self.database_path) and num_streamed_chunks is None: + logging.error('Could not find Jackhmmer database %s', + database_path) + raise ValueError( + f'Could not find Jackhmmer database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + self.incdom_e = incdom_e + self.dom_e = dom_e + self.get_tblout = get_tblout + self.streaming_callback = streaming_callback + + def _query_chunk(self, input_fasta_path: str, + database_path: str) -> Mapping[str, Any]: + """Queries the database chunk using Jackhmmer.""" + with utils.tmpdir_manager() as query_tmp_dir: + sto_path = os.path.join(query_tmp_dir, 'output.sto') + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + '-o', + '/dev/null', + '-A', + sto_path, + '--noali', + '--F1', + str(self.filter_f1), + '--F2', + str(self.filter_f2), + '--F3', + str(self.filter_f3), + '--incE', + str(self.e_value), + # Report only sequences with E-values <= x in per-sequence output. + '-E', + str(self.e_value), + '--cpu', + str(self.n_cpu), + '-N', + str(self.n_iter), + ] + if self.get_tblout: + tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') + cmd_flags.extend(['--tblout', tblout_path]) + + if self.z_value: + cmd_flags.extend(['-Z', str(self.z_value)]) + + if self.dom_e is not None: + cmd_flags.extend(['--domE', str(self.dom_e)]) + + if self.incdom_e is not None: + cmd_flags.extend(['--incdomE', str(self.incdom_e)]) + + cmd = [self.binary_path + ] + cmd_flags + [input_fasta_path, database_path] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'Jackhmmer ({os.path.basename(database_path)}) query'): + _, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' + % stderr.decode('utf-8')) + + # Get e-values for each target name + tbl = '' + if self.get_tblout: + with open(tblout_path) as f: + tbl = f.read() + + with open(sto_path) as f: + sto = f.read() + + raw_output = dict( + sto=sto, + tbl=tbl, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value) + + return raw_output + + def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: + """Queries the database using Jackhmmer.""" + if self.num_streamed_chunks is None: + return [self._query_chunk(input_fasta_path, self.database_path)] + + db_basename = os.path.basename(self.database_path) + + def db_remote_chunk(db_idx): + return f'{self.database_path}.{db_idx}' + + def db_local_chunk(db_idx): + return f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' + # db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # Remove existing files to prevent OOM + for f in glob.glob(db_local_chunk('[0-9]*')): + try: + os.remove(f) + except OSError: + print(f'OSError while deleting {f}') + + # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk + with futures.ThreadPoolExecutor(max_workers=2) as executor: + chunked_output = [] + for i in range(1, self.num_streamed_chunks + 1): + # Copy the chunk locally + if i == 1: + future = executor.submit(request.urlretrieve, + db_remote_chunk(i), + db_local_chunk(i)) + if i < self.num_streamed_chunks: + next_future = executor.submit( + request.urlretrieve, + db_remote_chunk(i + 1), + db_local_chunk(i + 1), + ) + + # Run Jackhmmer with the chunk + future.result() + chunked_output.append( + self._query_chunk(input_fasta_path, db_local_chunk(i))) + + # Remove the local copy of the chunk + os.remove(db_local_chunk(i)) + # Do not set next_future for the last chunk so that this works even for + # databases with only 1 chunk. + if i < self.num_streamed_chunks: + future = next_future + if self.streaming_callback: + self.streaming_callback(i) + return chunked_output diff --git a/modelscope/models/science/unifold/msa/tools/kalign.py b/modelscope/models/science/unifold/msa/tools/kalign.py new file mode 100644 index 00000000..1ea997fa --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/kalign.py @@ -0,0 +1,110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for Kalign.""" +import os +import subprocess +from typing import Sequence + +from absl import logging + +from . import utils + + +def _to_a3m(sequences: Sequence[str]) -> str: + """Converts sequences to an a3m file.""" + names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append('>' + name + '\n') + a3m.append(sequence + '\n') + return ''.join(a3m) + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + + Args: + binary_path: The path to the Kalign binary. + + Raises: + RuntimeError: If Kalign binary not found within the path. + """ + self.binary_path = binary_path + + def align(self, sequences: Sequence[str]) -> str: + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info('Aligning %d sequences', len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError( + 'Kalign requires all sequences to be at least 6 ' + 'residues long. Got %s (%d residues).' % (s, len(s))) + + with utils.tmpdir_manager() as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with open(input_fasta_path, 'w') as f: + f.write(_to_a3m(sequences)) + + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-o', + output_a3m_path, + '-format', + 'fasta', + ] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('Kalign query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'Kalign stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/modelscope/models/science/unifold/msa/tools/utils.py b/modelscope/models/science/unifold/msa/tools/utils.py new file mode 100644 index 00000000..1c2af936 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/utils.py @@ -0,0 +1,40 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for data pipeline tools.""" +import contextlib +import shutil +import tempfile +import time +from typing import Optional + +from absl import logging + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info('Started %s', msg) + tic = time.time() + yield + toc = time.time() + logging.info('Finished %s in %.3f seconds', msg, toc - tic) diff --git a/modelscope/models/science/unifold/msa/utils.py b/modelscope/models/science/unifold/msa/utils.py new file mode 100644 index 00000000..50e380d4 --- /dev/null +++ b/modelscope/models/science/unifold/msa/utils.py @@ -0,0 +1,89 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import os +from typing import Mapping, Sequence + +import json +from absl import logging + +from modelscope.models.science.unifold.data import protein + + +def get_chain_id_map( + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Makes a mapping from PDB-format chain ID to sequence and description, + and parses the order of multi-chains + """ + unique_seqs = [] + for seq in sequences: + if seq not in unique_seqs: + unique_seqs.append(seq) + + chain_id_map = { + chain_id: { + 'descriptions': [], + 'sequence': seq + } + for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) + } + chain_order = [] + + for seq, des in zip(sequences, descriptions): + chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] + chain_id_map[chain_id]['descriptions'].append(des) + chain_order.append(chain_id) + + return chain_id_map, chain_order + + +def divide_multi_chains( + fasta_name: str, + output_dir_base: str, + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Divides the multi-chains fasta into several single fasta files and + records multi-chains mapping information. + """ + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError( + 'Cannot process more chains than the PDB format supports. ' + f'Got {len(sequences)} chains.') + + chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) + + output_dir = os.path.join(output_dir_base, fasta_name) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + json.dump(chain_id_map, f, indent=4, sort_keys=True) + + chain_order_path = os.path.join(output_dir, 'chains.txt') + with open(chain_order_path, 'w') as f: + f.write(' '.join(chain_order)) + + logging.info('Mapping multi-chains fasta with chain order: %s', + ' '.join(chain_order)) + + temp_names = [] + temp_paths = [] + for chain_id in chain_id_map.keys(): + temp_name = fasta_name + '_{}'.format(chain_id) + temp_path = os.path.join(output_dir, temp_name + '.fasta') + des = 'chain_{}'.format(chain_id) + seq = chain_id_map[chain_id]['sequence'] + with open(temp_path, 'w') as f: + f.write('>' + des + '\n' + seq) + temp_names.append(temp_name) + temp_paths.append(temp_path) + return temp_names, temp_paths diff --git a/modelscope/msdatasets/__init__.py b/modelscope/msdatasets/__init__.py new file mode 100644 index 00000000..073f9396 --- /dev/null +++ b/modelscope/msdatasets/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from . import cv +from .ms_dataset import MsDataset diff --git a/modelscope/msdatasets/cv/__init__.py b/modelscope/msdatasets/cv/__init__.py new file mode 100644 index 00000000..fad91bcf --- /dev/null +++ b/modelscope/msdatasets/cv/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from . import (image_classification, image_semantic_segmentation, + object_detection) diff --git a/modelscope/msdatasets/cv/easycv_base.py b/modelscope/msdatasets/cv/easycv_base.py new file mode 100644 index 00000000..7b6df6e0 --- /dev/null +++ b/modelscope/msdatasets/cv/easycv_base.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + + +class EasyCVBaseDataset(object): + """Adapt to MSDataset. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + DATA_ROOT_PATTERN = '${data_root}' + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + args=(), + kwargs={}) -> None: + self.split_config = split_config + self.preprocessor = preprocessor + self.mode = mode + if self.split_config is not None: + self._update_data_source(kwargs['data_source']) + + def _update_data_root(self, input_dict, data_root): + for k, v in input_dict.items(): + if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: + input_dict.update( + {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) + elif isinstance(v, dict): + self._update_data_root(v, data_root) + + def _update_data_source(self, data_source): + data_root = next(iter(self.split_config.values())) + data_root = data_root.rstrip(osp.sep) + + self._update_data_root(data_source, data_root) diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/__init__.py b/modelscope/msdatasets/cv/face_2d_keypoins/__init__.py new file mode 100644 index 00000000..e9d76b7e --- /dev/null +++ b/modelscope/msdatasets/cv/face_2d_keypoins/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .face_2d_keypoints_dataset import FaceKeypointDataset + +else: + _import_structure = {'face_2d_keypoints_dataset': ['FaceKeypointDataset']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py b/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py new file mode 100644 index 00000000..2f2e03ef --- /dev/null +++ b/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.face_2d_keypoints, + module_name=Datasets.Face2dKeypointsDataset) +class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset): + """EasyCV dataset for face 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _FaceKeypointDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py b/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py new file mode 100644 index 00000000..5c1c72c1 --- /dev/null +++ b/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_2d_keypoints_dataset import Hand2DKeypointDataset + +else: + _import_structure = { + 'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py b/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py new file mode 100644 index 00000000..89ee0bb8 --- /dev/null +++ b/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.pose import \ + HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.hand_2d_keypoints, + module_name=Datasets.HandCocoWholeBodyDataset) +class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset): + """EasyCV dataset for human hand 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _HandCocoWholeBodyDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..472ed2d8 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint_dataset import WholeBodyCocoTopDownDataset + +else: + _import_structure = { + 'human_wholebody_keypoint_dataset': ['WholeBodyCocoTopDownDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py new file mode 100644 index 00000000..fc9469f2 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.pose import \ + WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Datasets.HumanWholeBodyKeypointDataset) +class WholeBodyCocoTopDownDataset(EasyCVBaseDataset, + _WholeBodyCocoTopDownDataset): + """EasyCV dataset for human whole body 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _WholeBodyCocoTopDownDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/image_classification/__init__.py b/modelscope/msdatasets/cv/image_classification/__init__.py new file mode 100644 index 00000000..95e8d7a1 --- /dev/null +++ b/modelscope/msdatasets/cv/image_classification/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .classification_dataset import ClsDataset + +else: + _import_structure = {'classification_dataset': ['ClsDataset']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/image_classification/classification_dataset.py b/modelscope/msdatasets/cv/image_classification/classification_dataset.py new file mode 100644 index 00000000..ba73e472 --- /dev/null +++ b/modelscope/msdatasets/cv/image_classification/classification_dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.classification import ClsDataset as _ClsDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.image_classification, module_name=Datasets.ClsDataset) +class ClsDataset(_ClsDataset): + """EasyCV dataset for classification. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _ClsDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py b/modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py new file mode 100644 index 00000000..26121bdb --- /dev/null +++ b/modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .segmentation_dataset import SegDataset + +else: + _import_structure = {'easycv_segmentation': ['SegDataset']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py b/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py new file mode 100644 index 00000000..b1316e2e --- /dev/null +++ b/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.segmentation import SegDataset as _SegDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset) +class SegDataset(EasyCVBaseDataset, _SegDataset): + """EasyCV dataset for Sementic segmentation. + For more details, please refer to : + https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/segmentation/raw.py . + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + data_source: Data source config to parse input data. + pipeline: Sequence of transform object or config dict to be composed. + ignore_index (int): Label index to be ignored. + profiling: If set True, will print transform time. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _SegDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/cv/object_detection/__init__.py b/modelscope/msdatasets/cv/object_detection/__init__.py new file mode 100644 index 00000000..30af2d9b --- /dev/null +++ b/modelscope/msdatasets/cv/object_detection/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .easycv_detection import DetDataset, DetImagesMixDataset + +else: + _import_structure = { + 'easycv_detection': ['DetDataset', 'DetImagesMixDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/object_detection/detection_dataset.py b/modelscope/msdatasets/cv/object_detection/detection_dataset.py new file mode 100644 index 00000000..2f6ad7d3 --- /dev/null +++ b/modelscope/msdatasets/cv/object_detection/detection_dataset.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +from easycv.datasets.detection import DetDataset as _DetDataset +from easycv.datasets.detection import \ + DetImagesMixDataset as _DetImagesMixDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset) +class DetDataset(EasyCVBaseDataset, _DetDataset): + """EasyCV dataset for object detection. + For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py . + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + data_source: Data source config to parse input data. + pipeline: Transform config list + profiling: If set True, will print pipeline time + classes: A list of class names, used in evaluation for result and groundtruth visualization + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _DetDataset.__init__(self, *args, **kwargs) + + +@TASK_DATASETS.register_module( + group_key=Tasks.image_object_detection, + module_name=Datasets.DetImagesMixDataset) +class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset): + """EasyCV dataset for object detection, a wrapper of multiple images mixed dataset. + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. For the augmentation pipeline of mixed image data, + the `get_indexes` method needs to be provided to obtain the image + indexes, and you can set `skip_flags` to change the pipeline running + process. At the same time, we provide the `dynamic_scale` parameter + to dynamically change the output image size. + output boxes format: cx, cy, w, h + + For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/mix.py . + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + data_source (:obj:`DetSourceCoco`): Data source config to parse input data. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + dynamic_scale (tuple[int], optional): The image scale can be changed + dynamically. Default to None. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + label_padding: out labeling padding [N, 120, 5] + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _DetImagesMixDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py new file mode 100644 index 00000000..5c8ea59f --- /dev/null +++ b/modelscope/msdatasets/ms_dataset.py @@ -0,0 +1,731 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, + Sequence, Union) + +import json +import numpy as np +import torch +from datasets import Dataset, DatasetDict +from datasets import load_dataset as hf_load_dataset +from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE +from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES +from datasets.utils.download_manager import DownloadConfig +from datasets.utils.file_utils import (is_relative_path, + relative_to_absolute_path) + +from modelscope.hub.repository import DatasetRepository +from modelscope.msdatasets.task_datasets.builder import build_task_dataset +from modelscope.msdatasets.utils.dataset_builder import ExternalDataset +from modelscope.msdatasets.utils.dataset_utils import ( + get_dataset_files, get_target_dataset_structure, load_dataset_builder) +from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager +from modelscope.msdatasets.utils.download_utils import DatasetDownloadManager +from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager +from modelscope.utils.config import ConfigDict +from modelscope.utils.config_ds import MS_DATASETS_CACHE +from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, + DEFAULT_DATASET_REVISION, + DatasetFormations, DownloadMode, Hubs, + UploadMode) +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def format_list(para) -> List: + if para is None: + para = [] + elif isinstance(para, str): + para = [para] + elif len(set(para)) < len(para): + raise ValueError(f'List columns contains duplicates: {para}') + return para + + +class MsMapDataset(torch.utils.data.Dataset): + + def __init__(self, dataset: Iterable, preprocessor_list, retained_columns, + columns, to_tensor): + super(MsDataset).__init__() + self.dataset = dataset + self.preprocessor_list = preprocessor_list + self.to_tensor = to_tensor + self.retained_columns = retained_columns + self.columns = columns + + def __len__(self): + return len(self.dataset) + + def type_converter(self, x): + if self.to_tensor: + return torch.tensor(x) + else: + return x + + def __getitem__(self, index): + item_dict = self.dataset[index] + res = { + k: self.type_converter(item_dict[k]) + for k in self.columns + if (not self.to_tensor) or k in self.retained_columns + } + for preprocessor in self.preprocessor_list: + res.update({ + k: self.type_converter(v) + for k, v in preprocessor(item_dict).items() + if (not self.to_tensor) or k in self.retained_columns + }) + return res + + +class MsDataset: + """ + ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to + provide efficient data access and local storage managements. On top of + that, MsDataset supports the data integration and interactions with multiple + remote hubs, particularly, ModelScope's own Dataset-hub. MsDataset also + abstracts away data-access details with other remote storage, including both + general external web-hosted data and cloud storage such as OSS. + """ + # the underlying huggingface Dataset + _hf_ds = None + + def __init__(self, hf_ds: Dataset, target: Optional[str] = None): + self._hf_ds = hf_ds + if target is not None and target not in self._hf_ds.features: + raise TypeError( + f'"target" must be a column of the dataset({list(self._hf_ds.features.keys())}, but got {target}' + ) + self.target = target + + def __iter__(self): + for item in self._hf_ds: + if self.target is not None: + yield item[self.target] + else: + yield item + + def __getitem__(self, key): + return self._hf_ds[key] + + def __len__(self): + return len(self._hf_ds) + + @property + def config_kwargs(self): + if isinstance(self._hf_ds, ExternalDataset): + return self._hf_ds.config_kwargs + else: + return None + + @classmethod + def from_hf_dataset(cls, + hf_ds: Union[Dataset, DatasetDict, ExternalDataset], + target: str = None) -> Union[dict, 'MsDataset']: + if isinstance(hf_ds, Dataset): + return cls(hf_ds, target) + elif isinstance(hf_ds, DatasetDict): + if len(hf_ds.keys()) == 1: + return cls(next(iter(hf_ds.values())), target) + return {k: cls(v, target) for k, v in hf_ds.items()} + elif isinstance(hf_ds, ExternalDataset): + return cls(hf_ds) + else: + raise TypeError( + f'"hf_ds" must be a Dataset or DatasetDict, but got {type(hf_ds)}' + ) + + @staticmethod + def load( + dataset_name: Union[str, list], + namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, + target: Optional[str] = None, + version: Optional[str] = DEFAULT_DATASET_REVISION, + hub: Optional[Hubs] = Hubs.modelscope, + subset_name: Optional[str] = None, + split: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], + Mapping[str, Union[str, + Sequence[str]]]]] = None, + download_mode: Optional[DownloadMode] = DownloadMode. + REUSE_DATASET_IF_EXISTS, + **config_kwargs, + ) -> Union[dict, 'MsDataset']: + """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. + Args: + + dataset_name (str): Path or name of the dataset. + namespace(str, optional): Namespace of the dataset. It should not be None if you load a remote dataset + from Hubs.modelscope, + target (str, optional): Name of the column to output. + version (str, optional): Version of the dataset script to load: + subset_name (str, optional): Defining the subset_name of the dataset. + data_dir (str, optional): Defining the data_dir of the dataset configuration. I + data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s). + split (str, optional): Which split of the data to load. + hub (Hubs or str, optional): When loading from a remote hub, where it is from. default Hubs.modelscope + download_mode (DownloadMode or str, optional): How to treat existing datasets. default + DownloadMode.REUSE_DATASET_IF_EXISTS + **config_kwargs (additional keyword arguments): Keyword arguments to be passed + + Returns: + MsDataset (obj:`MsDataset`): MsDataset object for a certain dataset. + """ + download_mode = DownloadMode(download_mode + or DownloadMode.REUSE_DATASET_IF_EXISTS) + hub = Hubs(hub or Hubs.modelscope) + if hub == Hubs.huggingface: + dataset = hf_load_dataset( + dataset_name, + name=subset_name, + revision=version, + split=split, + data_dir=data_dir, + data_files=data_files, + download_mode=download_mode.value, + **config_kwargs) + return MsDataset.from_hf_dataset(dataset, target=target) + elif hub == Hubs.modelscope: + return MsDataset._load_ms_dataset( + dataset_name, + namespace=namespace, + target=target, + subset_name=subset_name, + version=version, + split=split, + data_dir=data_dir, + data_files=data_files, + download_mode=download_mode, + **config_kwargs) + + @staticmethod + def _load_ms_dataset(dataset_name: Union[str, list], + namespace: Optional[str] = None, + target: Optional[str] = None, + version: Optional[str] = DEFAULT_DATASET_REVISION, + subset_name: Optional[str] = None, + split: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[ + str, Sequence[str], + Mapping[str, Union[str, Sequence[str]]]]] = None, + download_mode: Optional[DownloadMode] = None, + **config_kwargs) -> Union[dict, 'MsDataset']: + from modelscope.hub.api import HubApi + api = HubApi() + download_dataset = '' + if isinstance(dataset_name, str): + dataset_formation = DatasetFormations.native + if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir( + dataset_name): + dataset_formation = DatasetFormations.hf_compatible + elif os.path.isfile(dataset_name) and dataset_name.endswith('.py'): + dataset_formation = DatasetFormations.hf_compatible + file_name = os.path.basename(dataset_name) + download_dataset = os.path.splitext(file_name)[0] + elif is_relative_path(dataset_name) and dataset_name.count( + '/') == 0: + download_dataset = dataset_name + dataset_scripts, dataset_formation, download_dir = api.fetch_dataset_scripts( + dataset_name, namespace, download_mode, version) + # dataset organized to be compatible with hf format + if dataset_formation == DatasetFormations.hf_compatible: + dataset_name = dataset_scripts['.py'][0] + else: + raise FileNotFoundError( + f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} " + f'or any data file in the same directory.') + + if dataset_formation == DatasetFormations.hf_compatible: + dataset = hf_load_dataset( + dataset_name, + name=subset_name, + revision=version, + split=split, + data_dir=data_dir, + data_files=data_files, + cache_dir=MS_DATASETS_CACHE, + download_mode=download_mode.value, + **config_kwargs) + else: + dataset = MsDataset._load_from_ms( + dataset_name, + dataset_scripts, + download_dir, + namespace=namespace, + version=version, + subset_name=subset_name, + split=split, + download_mode=download_mode, + **config_kwargs) + elif isinstance(dataset_name, list): + if target is None: + target = 'target' + dataset = Dataset.from_dict({target: dataset_name}) + else: + raise TypeError('path must be a str or a list, but got' + f' {type(dataset_name)}') + + is_ci_test = os.getenv('CI_TEST') == 'True' + if download_dataset and not is_ci_test: + try: + api.on_dataset_download( + dataset_name=download_dataset, namespace=namespace) + api.dataset_download_uv( + dataset_name=download_dataset, namespace=namespace) + except Exception as e: + logger.error(e) + + return MsDataset.from_hf_dataset(dataset, target=target) + + @staticmethod + def _load_from_ms(dataset_name: str, + dataset_files: dict, + download_dir: str, + namespace: Optional[str] = None, + version: Optional[str] = DEFAULT_DATASET_REVISION, + subset_name: Optional[str] = None, + split: Optional[str] = None, + download_mode: Optional[DownloadMode] = None, + **config_kwargs) -> Union[Dataset, DatasetDict]: + for json_path in dataset_files['.json']: + if json_path.endswith(f'{dataset_name}.json'): + with open(json_path, encoding='utf-8') as dataset_json_file: + dataset_json = json.load(dataset_json_file) + break + target_subset_name, target_dataset_structure = get_target_dataset_structure( + dataset_json, subset_name, split) + meta_map, file_map, args_map = get_dataset_files( + target_dataset_structure, dataset_name, namespace, version) + builder = load_dataset_builder( + dataset_name, + subset_name, + namespace, + meta_data_files=meta_map, + zip_data_files=file_map, + args_map=args_map, + cache_dir=MS_DATASETS_CACHE, + version=version, + split=list(target_dataset_structure.keys()), + **config_kwargs) + + download_config = DownloadConfig( + cache_dir=download_dir, + force_download=bool( + download_mode == DownloadMode.FORCE_REDOWNLOAD), + force_extract=bool(download_mode == DownloadMode.FORCE_REDOWNLOAD), + use_etag=False, + ) + + dl_manager = DatasetDownloadManager( + dataset_name=dataset_name, + namespace=namespace, + version=version, + download_config=download_config, + data_dir=download_dir, + ) + builder.download_and_prepare( + dl_manager=dl_manager, + download_mode=download_mode.value, + try_from_hf_gcs=False) + + ds = builder.as_dataset() + return ds + + def to_torch_dataset_with_processors( + self, + preprocessors: Union[Callable, List[Callable]], + columns: Union[str, List[str]] = None, + to_tensor: bool = True, + ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + columns = format_list(columns) + + columns = [ + key for key in self._hf_ds.features.keys() if key in columns + ] + retained_columns = [] + if to_tensor: + sample = next(iter(self._hf_ds)) + + sample_res = {k: np.array(sample[k]) for k in columns} + for processor in preprocessor_list: + sample_res.update( + {k: np.array(v) + for k, v in processor(sample).items()}) + + def is_numpy_number(value): + return np.issubdtype(value.dtype, np.integer) or np.issubdtype( + value.dtype, np.floating) + + for k in sample_res.keys(): + if not is_numpy_number(sample_res[k]): + logger.warning( + f'Data of column {k} is non-numeric, will be removed') + continue + retained_columns.append(k) + + return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns, + columns, to_tensor) + + def to_torch_dataset( + self, + columns: Union[str, List[str]] = None, + preprocessors: Union[Callable, List[Callable]] = None, + task_name: str = None, + task_data_config: ConfigDict = None, + to_tensor: bool = True, + **format_kwargs, + ): + """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to + torch.utils.data.DataLoader. + + Args: + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict + will be used as a field of torch.utils.data.Dataset. + columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if + `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column. + If the `preprocessors` is not None, the output fields of processors will also be added. + task_name (str, default None): task name, refer to :obj:`Tasks` for more details + task_data_config (ConfigDict, default None): config dict for model object. + to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not. + format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`. + + Returns: + :class:`tf.data.Dataset` + + """ + if not TORCH_AVAILABLE: + raise ImportError( + 'The function to_torch_dataset requires pytorch to be installed' + ) + if isinstance(self._hf_ds, ExternalDataset): + task_data_config.update({'preprocessor': preprocessors}) + task_data_config.update(self._hf_ds.config_kwargs) + return build_task_dataset(task_data_config, task_name) + if preprocessors is not None: + return self.to_torch_dataset_with_processors( + preprocessors, columns=columns, to_tensor=to_tensor) + else: + self._hf_ds.reset_format() + self._hf_ds.set_format( + type='torch', columns=columns, format_kwargs=format_kwargs) + return self._hf_ds + + def to_tf_dataset_with_processors( + self, + batch_size: int, + shuffle: bool, + preprocessors: Union[Callable, List[Callable]], + drop_remainder: bool = None, + prefetch: bool = True, + label_cols: Union[str, List[str]] = None, + columns: Union[str, List[str]] = None, + ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + label_cols = format_list(label_cols) + columns = format_list(columns) + cols_to_retain = list(set(label_cols + columns)) + retained_columns = [ + key for key in self._hf_ds.features.keys() if key in cols_to_retain + ] + import tensorflow as tf + tf_dataset = tf.data.Dataset.from_tensor_slices( + np.arange(len(self._hf_ds), dtype=np.int64)) + if shuffle: + tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds)) + + def func(i, return_dict=False): + i = int(i) + res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns} + for preprocessor in preprocessor_list: + # TODO preprocessor output may have the same key + res.update({ + k: np.array(v) + for k, v in preprocessor(self._hf_ds[i]).items() + }) + if return_dict: + return res + return tuple(list(res.values())) + + sample_res = func(0, True) + + @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)]) + def fetch_function(i): + output = tf.numpy_function( + func, + inp=[i], + Tout=[ + tf.dtypes.as_dtype(val.dtype) + for val in sample_res.values() + ], + ) + return {key: output[i] for i, key in enumerate(sample_res)} + + tf_dataset = tf_dataset.map( + fetch_function, num_parallel_calls=tf.data.AUTOTUNE) + if label_cols: + + def split_features_and_labels(input_batch): + labels = { + key: tensor + for key, tensor in input_batch.items() if key in label_cols + } + if len(input_batch) == 1: + input_batch = next(iter(input_batch.values())) + if len(labels) == 1: + labels = next(iter(labels.values())) + return input_batch, labels + + tf_dataset = tf_dataset.map(split_features_and_labels) + + elif len(columns) == 1: + tf_dataset = tf_dataset.map(lambda x: next(iter(x.values()))) + if batch_size > 1: + tf_dataset = tf_dataset.batch( + batch_size, drop_remainder=drop_remainder) + + if prefetch: + tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) + return tf_dataset + + def to_tf_dataset( + self, + batch_size: int, + shuffle: bool, + preprocessors: Union[Callable, List[Callable]] = None, + columns: Union[str, List[str]] = None, + collate_fn: Callable = None, + drop_remainder: bool = None, + collate_fn_args: Dict[str, Any] = None, + label_cols: Union[str, List[str]] = None, + prefetch: bool = True, + ): + """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like + model.fit() or model.predict(). + + Args: + batch_size (int): Number of samples in a single batch. + shuffle(bool): Shuffle the dataset order. + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each field of the dict will be + used as a field of the tf.data. Dataset. If the `preprocessors` is None, the `collate_fn` + shouldn't be None. + columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None, + the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of + processors will also be added. + collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If + the `preprocessors` is None, the `collate_fn` shouldn't be None. + drop_remainder(bool, default None): Drop the last incomplete batch when loading. + collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`. + label_cols (str or List[str], defalut None): Dataset column(s) to load as labels. + prefetch (bool, default True): Prefetch data. + + Returns: + :class:`tf.data.Dataset` + + """ + if not TF_AVAILABLE: + raise ImportError( + 'The function to_tf_dataset requires Tensorflow to be installed.' + ) + if preprocessors is not None: + return self.to_tf_dataset_with_processors( + batch_size, + shuffle, + preprocessors, + drop_remainder=drop_remainder, + prefetch=prefetch, + label_cols=label_cols, + columns=columns) + + if collate_fn is None: + logger.error( + 'The `preprocessors` and the `collate_fn` should`t be both None.' + ) + return None + self._hf_ds.reset_format() + return self._hf_ds.to_tf_dataset( + columns, + batch_size, + shuffle, + collate_fn, + drop_remainder=drop_remainder, + collate_fn_args=collate_fn_args, + label_cols=label_cols, + prefetch=prefetch) + + def to_hf_dataset(self) -> Dataset: + self._hf_ds.reset_format() + return self._hf_ds + + def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset: + """ + Rename columns and return the underlying hf dataset directly + TODO: support native MsDataset column rename. + Args: + column_mapping: the mapping of the original and new column names + Returns: + underlying hf dataset + """ + self._hf_ds.reset_format() + return self._hf_ds.rename_columns(column_mapping) + + @staticmethod + def upload( + object_name: str, + local_file_path: str, + dataset_name: str, + namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, + version: Optional[str] = DEFAULT_DATASET_REVISION, + num_processes: Optional[int] = None, + chunksize: Optional[int] = 1, + filter_hidden_files: Optional[bool] = True, + upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None: + """Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first. + + Args: + object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name + local_file_path (str): Local file or directory to upload + dataset_name (str): Name of the dataset + namespace(str, optional): Namespace of the dataset + version: Optional[str]: Version of the dataset + num_processes: Optional[int]: The number of processes used for multiprocess uploading. + This is only applicable when local_file_path is a directory, and we are uploading mutliple-files + insided the directory. When None provided, the number returned by os.cpu_count() is used as default. + chunksize: Optional[int]: The chunksize of objects to upload. + For very long iterables using a large value for chunksize can make the job complete much faster than + using the default value of 1. Available if local_file_path is a directory. + filter_hidden_files: Optional[bool]: Whether to filter hidden files. + Available if local_file_path is a directory. + upload_mode: Optional[UploadMode]: How to upload objects from local. Default: UploadMode.OVERWRITE, upload + all objects from local, existing remote objects may be overwritten. + + Returns: + None + + """ + if not object_name: + raise ValueError('object_name cannot be empty!') + + _upload_manager = DatasetUploadManager( + dataset_name=dataset_name, namespace=namespace, version=version) + + upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE) + + if os.path.isfile(local_file_path): + _upload_manager.upload( + object_name=object_name, + local_file_path=local_file_path, + upload_mode=upload_mode) + elif os.path.isdir(local_file_path): + _upload_manager.upload_dir( + object_dir_name=object_name, + local_dir_path=local_file_path, + num_processes=num_processes, + chunksize=chunksize, + filter_hidden_files=filter_hidden_files, + upload_mode=upload_mode) + else: + raise ValueError( + f'{local_file_path} is not a valid file path or directory') + + @staticmethod + def clone_meta(dataset_work_dir: str, + dataset_id: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + auth_token: Optional[str] = None, + git_path: Optional[str] = None) -> None: + """Clone meta-file of dataset from the ModelScope Hub. + Args: + dataset_work_dir (str): Current git working directory. + dataset_id (str): Dataset id, in the form of your-namespace/your-dataset-name . + revision(`Optional[str]`): + revision of the model you want to clone from. Can be any of a branch, tag or commit hash + auth_token(`Optional[str]`): + token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter + as the token is already saved when you login the first time, if None, we will use saved token. + git_path:(`Optional[str]`): + The git command line path, if None, we use 'git' + Returns: + None + """ + + _repo = DatasetRepository( + repo_work_dir=dataset_work_dir, + dataset_id=dataset_id, + revision=revision, + auth_token=auth_token, + git_path=git_path) + clone_work_dir = _repo.clone() + if clone_work_dir: + logger.info('Already cloned repo to: {}'.format(clone_work_dir)) + else: + logger.warning( + 'Repo dir already exists: {}'.format(clone_work_dir)) + + @staticmethod + def upload_meta(dataset_work_dir: str, + commit_message: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION, + auth_token: Optional[str] = None, + git_path: Optional[str] = None, + force: bool = False) -> None: + """Upload meta-file of dataset to the ModelScope Hub. Please clone the meta-data from the ModelScope Hub first. + + Args: + dataset_work_dir (str): Current working directory. + commit_message (str): Commit message. + revision(`Optional[str]`): + revision of the model you want to clone from. Can be any of a branch, tag or commit hash + auth_token(`Optional[str]`): + token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter + as the token is already saved when you log in the first time, if None, we will use saved token. + git_path:(`Optional[str]`): + The git command line path, if None, we use 'git' + force (Optional[bool]): whether to use forced-push. + + Returns: + None + + """ + _repo = DatasetRepository( + repo_work_dir=dataset_work_dir, + dataset_id='', + revision=revision, + auth_token=auth_token, + git_path=git_path) + _repo.push(commit_message=commit_message, branch=revision, force=force) + + @staticmethod + def delete(object_name: str, + dataset_name: str, + namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE, + version: Optional[str] = DEFAULT_DATASET_REVISION) -> str: + """ Delete object of dataset. Please log in first and make sure you have permission to manage the dataset. + + Args: + object_name (str): The object name of dataset to be deleted. Could be a name of file or directory. If it's + directory, then ends with `/`. + For example: your-data-name.zip, train/001/img_001.png, train/, ... + dataset_name (str): Path or name of the dataset. + namespace(str, optional): Namespace of the dataset. + version (str, optional): Version of the dataset. + + Returns: + res_msg (str): Response message. + + """ + _delete_manager = DatasetDeleteManager( + dataset_name=dataset_name, namespace=namespace, version=version) + resp_msg = _delete_manager.delete(object_name=object_name) + logger.info(f'Object {object_name} successfully removed!') + return resp_msg diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py new file mode 100644 index 00000000..043010bf --- /dev/null +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule, is_torch_available + +if TYPE_CHECKING: + from .base import TaskDataset + from .builder import TASK_DATASETS, build_task_dataset + from .torch_base_dataset import TorchTaskDataset + from .veco_dataset import VecoDataset + from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset + from .movie_scene_segmentation import MovieSceneSegmentationDataset + from .video_summarization_dataset import VideoSummarizationDataset + from .image_inpainting import ImageInpaintingDataset + from .text_ranking_dataset import TextRankingDataset + from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset + +else: + _import_structure = { + 'base': ['TaskDataset'], + 'builder': ['TASK_DATASETS', 'build_task_dataset'], + 'torch_base_dataset': ['TorchTaskDataset'], + 'text_ranking_dataset': ['TextRankingDataset'], + 'veco_dataset': ['VecoDataset'], + 'image_instance_segmentation_coco_dataset': + ['ImageInstanceSegmentationCocoDataset'], + 'video_summarization_dataset': ['VideoSummarizationDataset'], + 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], + 'image_inpainting': ['ImageInpaintingDataset'], + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], + 'image_portrait_enhancement_dataset': + ['ImagePortraitEnhancementDataset'], + 'referring_video_object_segmentation': + ['ReferringVideoObjectSegmentationDataset'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/audio/__init__.py b/modelscope/msdatasets/task_datasets/audio/__init__.py new file mode 100644 index 00000000..c62a8d9c --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .kws_farfield_dataset import KWSDataset, KWSDataLoader + +else: + _import_structure = { + 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py new file mode 100644 index 00000000..8c518ec9 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py @@ -0,0 +1,280 @@ +""" +Used to prepare simulated data. +""" +import math +import os.path +import queue +import threading +import time + +import numpy as np +import torch + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +BLOCK_DEC = 2 +BLOCK_CAT = 3 +FBANK_SIZE = 40 +LABEL_SIZE = 1 +LABEL_GAIN = 100.0 + + +class KWSDataset: + """ + dataset for keyword spotting and vad + conf_basetrain: basetrain configure file path + conf_finetune: finetune configure file path, null allowed + numworkers: no. of workers + basetrainratio: basetrain workers ratio + numclasses: no. of nn output classes, 2 classes to generate vad label + blockdec: block decimation + blockcat: block concatenation + """ + + def __init__(self, + conf_basetrain, + conf_finetune, + numworkers, + basetrainratio, + numclasses, + blockdec=BLOCK_CAT, + blockcat=BLOCK_CAT): + super().__init__() + self.numclasses = numclasses + self.blockdec = blockdec + self.blockcat = blockcat + self.sims_base = [] + self.sims_senior = [] + self.setup_sims(conf_basetrain, conf_finetune, numworkers, + basetrainratio) + + def release(self): + for sim in self.sims_base: + del sim + for sim in self.sims_senior: + del sim + del self.base_conf + del self.senior_conf + logger.info('KWSDataset: Released.') + + def setup_sims(self, conf_basetrain, conf_finetune, numworkers, + basetrainratio): + if not os.path.exists(conf_basetrain): + raise ValueError(f'{conf_basetrain} does not exist!') + if not os.path.exists(conf_finetune): + raise ValueError(f'{conf_finetune} does not exist!') + import py_sound_connect + logger.info('KWSDataset init SoundConnect...') + num_base = math.ceil(numworkers * basetrainratio) + num_senior = numworkers - num_base + # hold by fields to avoid python releasing conf object + self.base_conf = py_sound_connect.ConfigFile(conf_basetrain) + self.senior_conf = py_sound_connect.ConfigFile(conf_finetune) + for i in range(num_base): + fs = py_sound_connect.FeatSimuKWS(self.base_conf.params) + self.sims_base.append(fs) + for i in range(num_senior): + self.sims_senior.append( + py_sound_connect.FeatSimuKWS(self.senior_conf.params)) + logger.info('KWSDataset init SoundConnect finished.') + + def getBatch(self, id): + """ + Generate a data batch + + Args: + id: worker id + + Return: time x channel x feature, label + """ + fs = self.get_sim(id) + fs.processBatch() + # get multi-channel feature vector size + featsize = fs.featSize() + # get label vector size + labelsize = fs.labelSize() + # get minibatch size (time dimension) + # batchsize = fs.featBatchSize() + # no. of fe output channels + numchs = featsize // FBANK_SIZE + # get raw data + fs_feat = fs.feat() + data = np.frombuffer(fs_feat, dtype='float32') + data = data.reshape((-1, featsize + labelsize)) + + # convert float label to int + label = data[:, FBANK_SIZE * numchs:] + + if self.numclasses == 2: + # generate vad label + label[label > 0.0] = 1.0 + else: + # generate kws label + label = np.round(label * LABEL_GAIN) + label[label > self.numclasses - 1] = 0.0 + + # decimated size + size1 = int(np.ceil( + label.shape[0] / self.blockdec)) - self.blockcat + 1 + + # label decimation + label1 = np.zeros((size1, LABEL_SIZE), dtype='float32') + for tau in range(size1): + label1[tau, :] = label[(tau + self.blockcat // 2) + * self.blockdec, :] + + # feature decimation and concatenation + # time x channel x feature + featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat), + dtype='float32') + for n in range(numchs): + feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)] + + for tau in range(size1): + for i in range(self.blockcat): + featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \ + feat[(tau + i) * self.blockdec, :] + + return torch.from_numpy(featall), torch.from_numpy(label1).long() + + def get_sim(self, id): + num_base = len(self.sims_base) + if id < num_base: + fs = self.sims_base[id] + else: + fs = self.sims_senior[id - num_base] + return fs + + +class Worker(threading.Thread): + """ + id: worker id + dataset: the dataset + pool: queue as the global data buffer + """ + + def __init__(self, id, dataset, pool): + threading.Thread.__init__(self) + + self.id = id + self.dataset = dataset + self.pool = pool + self.isrun = True + self.nn = 0 + + def run(self): + while self.isrun: + self.nn += 1 + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1') + # get simulated minibatch + if self.isrun: + data = self.dataset.getBatch(self.id) + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2') + + # put data into buffer + if self.isrun: + self.pool.put(data) + logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3') + + logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id)) + + def stopWorker(self): + """ + stop the worker thread + """ + self.isrun = False + + +class KWSDataLoader: + """ + dataset: the dataset reference + batchsize: data batch size + numworkers: no. of workers + prefetch: prefetch factor + """ + + def __init__(self, dataset, batchsize, numworkers, prefetch=2): + self.dataset = dataset + self.batchsize = batchsize + self.datamap = {} + self.isrun = True + + # data queue + self.pool = queue.Queue(batchsize * prefetch) + + # initialize workers + self.workerlist = [] + for id in range(numworkers): + w = Worker(id, dataset, self.pool) + self.workerlist.append(w) + + def __iter__(self): + return self + + def __next__(self): + while self.isrun: + # get data from common data pool + data = self.pool.get() + self.pool.task_done() + + # group minibatches with the same shape + key = str(data[0].shape) + + batchl = self.datamap.get(key) + if batchl is None: + batchl = [] + self.datamap.update({key: batchl}) + + batchl.append(data) + + # a full data batch collected + if len(batchl) >= self.batchsize: + featbatch = [] + labelbatch = [] + + for feat, label in batchl: + featbatch.append(feat) + labelbatch.append(label) + + batchl.clear() + + feattensor = torch.stack(featbatch, dim=0) + labeltensor = torch.stack(labelbatch, dim=0) + + if feattensor.shape[-2] == 1: + logger.debug('KWSDataLoader: Basetrain batch.') + else: + logger.debug('KWSDataLoader: Finetune batch.') + + return feattensor, labeltensor + + return None, None + + def start(self): + """ + start multi-thread data loader + """ + for w in self.workerlist: + w.start() + + def stop(self): + """ + stop data loader + """ + logger.info('KWSDataLoader: Stopping...') + self.isrun = False + + for w in self.workerlist: + w.stopWorker() + + while not self.pool.empty(): + self.pool.get(block=True, timeout=0.001) + + # wait workers terminated + for w in self.workerlist: + while not self.pool.empty(): + self.pool.get(block=True, timeout=0.001) + w.join() + logger.info('KWSDataLoader: All worker stopped.') diff --git a/modelscope/msdatasets/task_datasets/base.py b/modelscope/msdatasets/task_datasets/base.py new file mode 100644 index 00000000..39b791b1 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/base.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from abc import ABC, abstractmethod +from typing import Any, List, Tuple, Union + + +class TaskDataset(ABC): + """The task dataset base class for all the task specific dataset processors. + """ + + def __init__(self, + datasets: Union[Any, List[Any]], + mode, + preprocessor=None, + **kwargs): + super().__init__() + self.mode = mode + self.preprocessor = preprocessor + self._inner_dataset = self.prepare_dataset(datasets) + + @abstractmethod + def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: + """Prepare a dataset. + + User can process the input datasets in a whole dataset perspective. + This method also helps to merge several datasets to one. + + Args: + datasets: The original dataset(s) + + Returns: A single dataset, which may be created after merging. + + """ + pass + + @abstractmethod + def prepare_sample(self, data): + """Preprocess the data fetched from the inner_dataset. + + If the preprocessor is None, the original data will be returned, else the preprocessor will be called. + User can override this method to implement custom logics. + + Args: + data: The data fetched from the dataset. + + Returns: The processed data. + + """ + pass diff --git a/modelscope/msdatasets/task_datasets/builder.py b/modelscope/msdatasets/task_datasets/builder.py new file mode 100644 index 00000000..683bec8f --- /dev/null +++ b/modelscope/msdatasets/task_datasets/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import ConfigDict +from modelscope.utils.registry import Registry, build_from_cfg + +TASK_DATASETS = Registry('task_datasets') + + +def build_task_dataset(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ Build task specific dataset processor given model config dict and the task name. + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, TASK_DATASETS, group_key=task_name, default_args=default_args) diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py new file mode 100644 index 00000000..732a1bd7 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .image_inpainting_dataset import ImageInpaintingDataset diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/aug.py b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py new file mode 100644 index 00000000..445bb9b4 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/aug.py @@ -0,0 +1,100 @@ +""" +The implementation is borrowed from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import imgaug.augmenters as iaa +from albumentations import DualIAATransform, to_tuple + + +class IAAAffine2(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__( + self, + scale=(0.7, 1.3), + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=(-0.1, 0.1), + order=1, + cval=0, + mode='reflect', + always_apply=False, + p=0.5, + ): + super(IAAAffine2, self).__init__(always_apply, p) + self.scale = dict(x=scale, y=scale) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = dict(x=shear, y=shear) + self.order = order + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ('scale', 'translate_percent', 'translate_px', 'rotate', + 'shear', 'order', 'cval', 'mode') + + +class IAAPerspective2(DualIAATransform): + """Perform a random four point perspective transform of the input. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__(self, + scale=(0.05, 0.1), + keep_size=True, + always_apply=False, + p=0.5, + order=1, + cval=0, + mode='replicate'): + super(IAAPerspective2, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.PerspectiveTransform( + self.scale, + keep_size=self.keep_size, + mode=self.mode, + cval=self.cval) + + def get_transform_init_args_names(self): + return ('scale', 'keep_size') diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py new file mode 100644 index 00000000..057b8f88 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py @@ -0,0 +1,337 @@ +""" +Part of the implementation is borrowed and modified from LaMa, +publicly available at https://github.com/saic-mdal/lama +""" +import glob +import os +import os.path as osp +from enum import Enum + +import albumentations as A +import cv2 +import json +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from .aug import IAAAffine2, IAAPerspective2 + +LOGGER = get_logger() + + +class LinearRamp: + + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class DrawMethod(Enum): + LINE = 'line' + CIRCLE = 'circle' + SQUARE = 'square' + + +def make_random_superres_mask(shape, + min_step=2, + max_step=4, + min_width=1, + max_width=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + step_x = np.random.randint(min_step, max_step + 1) + width_x = np.random.randint(min_width, min(step_x, max_width + 1)) + offset_x = np.random.randint(0, step_x) + + step_y = np.random.randint(min_step, max_step + 1) + width_y = np.random.randint(min_width, min(step_y, max_width + 1)) + offset_y = np.random.randint(0, step_y) + + for dy in range(width_y): + mask[offset_y + dy::step_y] = 1 + for dx in range(width_x): + mask[:, offset_x + dx::step_x] = 1 + return mask[None, ...] + + +class RandomSuperresMaskGenerator: + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, img, iter_i=None): + return make_random_superres_mask(img.shape[1:], **self.kwargs) + + +def make_random_rectangle_mask(shape, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return mask[None, ...] + + +class RandomRectangleMaskGenerator: + + def __init__(self, + margin=10, + bbox_min_size=30, + bbox_max_size=100, + min_times=0, + max_times=3, + ramp_kwargs=None): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_bbox_max_size = int(self.bbox_min_size + 1 + + (self.bbox_max_size - self.bbox_min_size) + * coef) + cur_max_times = int(self.min_times + + (self.max_times - self.min_times) * coef) + return make_random_rectangle_mask( + img.shape[1:], + margin=self.margin, + bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, + min_times=self.min_times, + max_times=cur_max_times) + + +def make_random_irregular_mask(shape, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + draw_method=DrawMethod.LINE): + draw_method = DrawMethod(draw_method) + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip( + (start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip( + (start_y + length * np.cos(angle)).astype(np.int32), 0, height) + if draw_method == DrawMethod.LINE: + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, + brush_w) + elif draw_method == DrawMethod.CIRCLE: + cv2.circle( + mask, (start_x, start_y), + radius=brush_w, + color=1., + thickness=-1) + elif draw_method == DrawMethod.SQUARE: + radius = brush_w // 2 + mask[start_y - radius:start_y + radius, + start_x - radius:start_x + radius] = 1 + start_x, start_y = end_x, end_y + return mask[None, ...] + + +class RandomIrregularMaskGenerator: + + def __init__(self, + max_angle=4, + max_len=60, + max_width=20, + min_times=0, + max_times=10, + ramp_kwargs=None, + draw_method=DrawMethod.LINE): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + self.draw_method = draw_method + self.ramp = LinearRamp( + **ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and ( + iter_i is not None) else 1 + cur_max_len = int(max(1, self.max_len * coef)) + cur_max_width = int(max(1, self.max_width * coef)) + cur_max_times = int(self.min_times + 1 + + (self.max_times - self.min_times) * coef) + return make_random_irregular_mask( + img.shape[1:], + max_angle=self.max_angle, + max_len=cur_max_len, + max_width=cur_max_width, + min_times=self.min_times, + max_times=cur_max_times, + draw_method=self.draw_method) + + +class MixedMaskGenerator: + + def __init__(self, + irregular_proba=1 / 3, + irregular_kwargs=None, + box_proba=1 / 3, + box_kwargs=None, + segm_proba=1 / 3, + segm_kwargs=None, + squares_proba=0, + squares_kwargs=None, + superres_proba=0, + superres_kwargs=None, + outpainting_proba=0, + outpainting_kwargs=None, + invert_proba=0): + self.probas = [] + self.gens = [] + + if irregular_proba > 0: + self.probas.append(irregular_proba) + if irregular_kwargs is None: + irregular_kwargs = {} + else: + irregular_kwargs = dict(irregular_kwargs) + irregular_kwargs['draw_method'] = DrawMethod.LINE + self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) + + if box_proba > 0: + self.probas.append(box_proba) + if box_kwargs is None: + box_kwargs = {} + self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) + + if squares_proba > 0: + self.probas.append(squares_proba) + if squares_kwargs is None: + squares_kwargs = {} + else: + squares_kwargs = dict(squares_kwargs) + squares_kwargs['draw_method'] = DrawMethod.SQUARE + self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) + + if superres_proba > 0: + self.probas.append(superres_proba) + if superres_kwargs is None: + superres_kwargs = {} + self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) + + self.probas = np.array(self.probas, dtype='float32') + self.probas /= self.probas.sum() + self.invert_proba = invert_proba + + def __call__(self, img, iter_i=None, raw_image=None): + kind = np.random.choice(len(self.probas), p=self.probas) + gen = self.gens[kind] + result = gen(img, iter_i=iter_i, raw_image=raw_image) + if self.invert_proba > 0 and random.random() < self.invert_proba: + result = 1 - result + return result + + +def get_transforms(test_mode, out_size): + if not test_mode: + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast( + brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue( + hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + else: + transform = A.Compose([ + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.CenterCrop(height=out_size, width=out_size), + A.ToFloat() + ]) + return transform + + +@TASK_DATASETS.register_module( + Tasks.image_inpainting, module_name=Models.image_inpainting) +class ImageInpaintingDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + mode = kwargs.get('test_mode', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + mask_gen_kwargs = kwargs.get('mask_gen_kwargs', {}) + out_size = kwargs.get('out_size', 256) + self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs) + self.transform = get_transforms(mode, out_size) + self.in_files = sorted( + list( + glob.glob( + osp.join(self.data_root, '**', '*.jpg'), recursive=True)) + + list( + glob.glob( + osp.join(self.data_root, '**', '*.png'), recursive=True))) + self.iter_i = 0 + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, index): + path = self.in_files[index] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks + mask = self.mask_generator(img, iter_i=self.iter_i) + self.iter_i += 1 + return dict(image=img, mask=mask) diff --git a/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py b/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py new file mode 100644 index 00000000..1c7bc249 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py @@ -0,0 +1,346 @@ +# Part of the implementation is borrowed and modified from MMDetection, publicly available at +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py +import os.path as osp + +import numpy as np +from pycocotools.coco import COCO + +from modelscope.metainfo import Models +from modelscope.utils.constant import Tasks +from .builder import TASK_DATASETS +from .torch_base_dataset import TorchTaskDataset + +DATASET_STRUCTURE = { + 'train': { + 'annotation': 'annotations/instances_train.json', + 'images': 'images/train' + }, + 'validation': { + 'annotation': 'annotations/instances_val.json', + 'images': 'images/val' + } +} + + +@TASK_DATASETS.register_module( + module_name=Models.cascade_mask_rcnn_swin, + group_key=Tasks.image_segmentation) +class ImageInstanceSegmentationCocoDataset(TorchTaskDataset): + """Coco-style dataset for image instance segmentation. + + Args: + split_config (dict): Annotation file path. {"train":"xxxxx"} + classes (Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + data_root (str, optional): Data root for ``ann_file``, + ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. + test_mode (bool, optional): If set True, annotation will not be loaded. + filter_empty_gt (bool, optional): If set true, images without bounding + boxes of the dataset's classes will be filtered out. This option + only works when `test_mode=False`, i.e., we never filter images + during tests. + """ + + CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + + def __init__(self, + split_config: dict, + preprocessor=None, + classes=None, + seg_prefix=None, + folder_name=None, + ann_file=None, + img_prefix=None, + test_mode=False, + filter_empty_gt=True, + **kwargs): + data_root = next(iter(split_config.values())) + self.data_root = osp.join(data_root, + folder_name) if folder_name else data_root + self.split = next(iter(split_config.keys())) + self.preprocessor = preprocessor + + self.ann_file = osp.join(self.data_root, ann_file) + + self.img_prefix = osp.join(self.data_root, img_prefix) + self.seg_prefix = seg_prefix + self.test_mode = test_mode + self.filter_empty_gt = filter_empty_gt + self.CLASSES = self.get_classes(classes) + + # load annotations + self.data_infos = self.load_annotations(self.ann_file) + + # filter images too small and containing no annotations + if not test_mode: + valid_inds = self._filter_imgs() + self.data_infos = [self.data_infos[i] for i in valid_inds] + # set group flag for the sampler + self._set_group_flag() + + def __len__(self): + """Total number of samples of data.""" + return len(self.data_infos) + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + + Returns: + list[dict]: Annotation info from COCO api. + """ + + self.coco = COCO(ann_file) + # The order of returned `cat_ids` will not + # change with the order of the CLASSES + self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES) + + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.getImgIds() + data_infos = [] + total_ann_ids = [] + for i in self.img_ids: + info = self.coco.loadImgs([i])[0] + info['filename'] = info['file_name'] + info['ann_file'] = ann_file + info['classes'] = self.CLASSES + data_infos.append(info) + ann_ids = self.coco.getAnnIds(imgIds=[i]) + total_ann_ids.extend(ann_ids) + assert len(set(total_ann_ids)) == len( + total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!" + return data_infos + + def get_ann_info(self, idx): + """Get COCO annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + ann_info = self.coco.loadAnns(ann_ids) + return self._parse_ann_info(self.data_infos[idx], ann_info) + + def get_cat_ids(self, idx): + """Get COCO category ids by index. + + Args: + idx (int): Index of data. + + Returns: + list[int]: All categories in the image of specified index. + """ + + img_id = self.data_infos[idx]['id'] + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + ann_info = self.coco.loadAnns(ann_ids) + return [ann['category_id'] for ann in ann_info] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + results['seg_prefix'] = self.seg_prefix + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + + def _filter_imgs(self, min_size=32): + """Filter images too small or without ground truths.""" + valid_inds = [] + # obtain images that contain annotation + ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.coco.catToImgs[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_img_ids = [] + for i, img_info in enumerate(self.data_infos): + img_id = self.img_ids[i] + if self.filter_empty_gt and img_id not in ids_in_cat: + continue + if min(img_info['width'], img_info['height']) >= min_size: + valid_inds.append(i) + valid_img_ids.append(img_id) + self.img_ids = valid_img_ids + return valid_inds + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox and mask annotation. + + Args: + ann_info (list[dict]): Annotation info of an image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore,\ + labels, masks, seg_map. "masks" are raw annotations and not \ + decoded into binary masks. + """ + gt_bboxes = [] + gt_labels = [] + gt_bboxes_ignore = [] + gt_masks_ann = [] + for i, ann in enumerate(ann_info): + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(bbox) + else: + gt_bboxes.append(bbox) + gt_labels.append(self.cat2label[ann['category_id']]) + gt_masks_ann.append(ann.get('segmentation', None)) + + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + seg_map = img_info['filename'].replace('jpg', 'png') + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks=gt_masks_ann, + seg_map=seg_map) + + return ann + + def _set_group_flag(self): + """Set flag according to image aspect ratio. + + Images with aspect ratio greater than 1 will be set as group 1, + otherwise group 0. + """ + self.flag = np.zeros(len(self), dtype=np.uint8) + for i in range(len(self)): + img_info = self.data_infos[i] + if img_info['width'] / img_info['height'] > 1: + self.flag[i] = 1 + + def _rand_another(self, idx): + """Get another random index from the same group as the given index.""" + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set \ + True). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + while True: + data = self.prepare_train_img(idx) + if data is None: + idx = self._rand_another(idx) + continue + return data + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys \ + introduced by pipeline. + """ + + img_info = self.data_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + if self.preprocessor is None: + return results + self.preprocessor.train() + return self.preprocessor(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by \ + pipeline. + """ + + img_info = self.data_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + if self.preprocessor is None: + return results + self.preprocessor.eval() + results = self.preprocessor(results) + return results + + @classmethod + def get_classes(cls, classes=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is + a tuple or list, override the CLASSES defined by the dataset. + + Returns: + tuple[str] or list[str]: Names of categories of the dataset. + """ + if classes is None: + return cls.CLASSES + + if isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + return class_names diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py new file mode 100644 index 00000000..4df24fae --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset + +else: + _import_structure = { + 'image_portrait_enhancement_dataset': + ['ImagePortraitEnhancementDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py new file mode 100644 index 00000000..1133d3c2 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py new file mode 100644 index 00000000..58d40778 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import Datasets, Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from .data_utils import img2tensor + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0 + + +@TASK_DATASETS.register_module( + Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset) +class ImagePortraitEnhancementDataset(TorchTaskDataset): + """Paired image dataset for image portrait enhancement. + """ + + def __init__(self, dataset, is_train): + self.dataset = dataset + self.gt_size = 256 + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + item_dict = self.dataset[index] + gt_path = item_dict['hq:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['lq:FILE'] + img_lq = default_loader(lq_path) + + gt_size = self.gt_size + img_gt = cv2.resize(img_gt, (gt_size, gt_size)) + img_lq = cv2.resize(img_lq, (gt_size, gt_size)) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5} diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..b1bc40f8 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py new file mode 100644 index 00000000..68cbf918 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py @@ -0,0 +1,172 @@ +# The implementation here is modified based on BaSSL, +# originally Apache 2.0 License and publicly available at https://github.com/kakaobrain/bassl +import copy +import os +import os.path as osp +import random + +import json +import torch +from torchvision.datasets.folder import pil_loader + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from . import sampler + +DATASET_STRUCTURE = { + 'train': { + 'annotation': 'anno/train.json', + 'images': 'keyf_240p', + 'feat': 'feat' + }, + 'test': { + 'annotation': 'anno/test.json', + 'images': 'keyf_240p', + 'feat': 'feat' + } +} + + +@TASK_DATASETS.register_module( + Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) +class MovieSceneSegmentationDataset(TorchTaskDataset): + """dataset for movie scene segmentation. + + Args: + split_config (dict): Annotation file path. {"train":"xxxxx"} + data_root (str, optional): Data root for ``ann_file``, + ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. + test_mode (bool, optional): If set True, annotation will not be loaded. + """ + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + + self.split = next(iter(split_config.keys())) + self.preprocessor = kwargs['preprocessor'] + + self.ann_file = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['annotation']) + self.img_prefix = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['images']) + self.feat_prefix = osp.join(self.data_root, + DATASET_STRUCTURE[self.split]['feat']) + + self.test_mode = kwargs['test_mode'] + if self.test_mode: + self.preprocessor.eval() + else: + self.preprocessor.train() + + self.cfg = kwargs.pop('cfg', None) + + self.num_keyframe = self.cfg.num_keyframe if self.cfg is not None else 3 + self.use_single_keyframe = self.cfg.use_single_keyframe if self.cfg is not None else False + + self.load_data() + self.init_sampler(self.cfg) + + def __len__(self): + """Total number of samples of data.""" + return len(self.anno_data) + + def __getitem__(self, idx: int): + data = self.anno_data[ + idx] # {"video_id", "shot_id", "num_shot", "boundary_label"} + vid, sid = data['video_id'], data['shot_id'] + num_shot = data['num_shot'] + + shot_idx = self.shot_sampler(int(sid), num_shot) + + video = self.load_shot_list(vid, shot_idx) + if self.preprocessor is None: + video = torch.stack(video, dim=0) + video = video.view(-1, self.num_keyframe, 3, 224, 224) + else: + video = self.preprocessor(video) + + payload = { + 'idx': idx, + 'vid': vid, + 'sid': sid, + 'video': video, + 'label': abs(data['boundary_label']), # ignore -1 label. + } + return payload + + def load_data(self): + self.tmpl = '{}/shot_{}_img_{}.jpg' # video_id, shot_id, shot_num + + if not self.test_mode: + with open(self.ann_file) as f: + self.anno_data = json.load(f) + self.vidsid2label = { + f"{it['video_id']}_{it['shot_id']}": it['boundary_label'] + for it in self.anno_data + } + else: + with open(self.ann_file) as f: + self.anno_data = json.load(f) + + def init_sampler(self, cfg): + # shot sampler + if cfg is not None: + self.sampling_method = cfg.sampling_method.name + sampler_args = copy.deepcopy( + cfg.sampling_method.params.get(self.sampling_method, {})) + if self.sampling_method == 'instance': + self.shot_sampler = sampler.InstanceShotSampler() + elif self.sampling_method == 'temporal': + self.shot_sampler = sampler.TemporalShotSampler(**sampler_args) + elif self.sampling_method == 'shotcol': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'bassl': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'bassl+shotcol': + self.shot_sampler = sampler.SequenceShotSampler(**sampler_args) + elif self.sampling_method == 'sbd': + self.shot_sampler = sampler.NeighborShotSampler(**sampler_args) + else: + raise NotImplementedError + else: + self.shot_sampler = sampler.NeighborShotSampler() + + def load_shot_list(self, vid, shot_idx): + shot_list = [] + cache = {} + for sidx in shot_idx: + vidsid = f'{vid}_{sidx:04d}' + if vidsid in cache: + shot = cache[vidsid] + else: + shot_path = os.path.join( + self.img_prefix, self.tmpl.format(vid, f'{sidx:04d}', + '{}')) + shot = self.load_shot_keyframes(shot_path) + cache[vidsid] = shot + shot_list.extend(shot) + return shot_list + + def load_shot_keyframes(self, path): + shot = None + if not self.test_mode and self.use_single_keyframe: + # load one randomly sampled keyframe + shot = [ + pil_loader( + path.format(random.randint(0, self.num_keyframe - 1))) + ] + else: + # load all keyframes + shot = [ + pil_loader(path.format(i)) for i in range(self.num_keyframe) + ] + assert shot is not None + return shot diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py new file mode 100644 index 00000000..0fc2fe0f --- /dev/null +++ b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------------------ +# BaSSL +# Copyright (c) 2021 KakaoBrain. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# Github: https://github.com/kakaobrain/bassl +# ------------------------------------------------------------------------------------ + +import random + +import numpy as np + + +class InstanceShotSampler: + """ This is for instance at pre-training stage """ + + def __call__(self, center_sid: int, *args, **kwargs): + return center_sid + + +class TemporalShotSampler: + """ This is for temporal at pre-training stage """ + + def __init__(self, neighbor_size: int): + self.N = neighbor_size + + def __call__(self, center_sid: int, total_num_shot: int): + """ we randomly sample one shot from neighbor shots within local temporal window + """ + shot_idx = center_sid + np.arange( + -self.N, self.N + 1 + ) # total number of neighbor shots = 2N+1 (query (1) + neighbors (2*N)) + shot_idx = np.clip(shot_idx, 0, + total_num_shot) # deal with out-of-boundary indices + shot_idx = random.choice( + np.unique(np.delete(shot_idx, np.where(shot_idx == center_sid)))) + return shot_idx + + +class SequenceShotSampler: + """ This is for bassl or shotcol at pre-training stage """ + + def __init__(self, neighbor_size: int, neighbor_interval: int): + self.interval = neighbor_interval + self.window_size = neighbor_size * self.interval # temporal coverage + + def __call__(self, + center_sid: int, + total_num_shot: int, + sparse_method: str = 'edge'): + """ + Args: + center_sid: index of center shot + total_num_shot: last index of shot for given video + sparse_stride: stride to sample sparse ones from dense sequence + for curriculum learning + """ + + dense_shot_idx = center_sid + np.arange( + -self.window_size, self.window_size + 1, + self.interval) # total number of shots = 2*neighbor_size+1 + + if dense_shot_idx[0] < 0: + # if center_sid is near left-side of video, we shift window rightward + # so that the leftmost index is 0 + dense_shot_idx -= dense_shot_idx[0] + elif dense_shot_idx[-1] > (total_num_shot - 1): + # if center_sid is near right-side of video, we shift window leftward + # so that the rightmost index is total_num_shot - 1 + dense_shot_idx -= dense_shot_idx[-1] - (total_num_shot - 1) + + # to deal with videos that have smaller number of shots than window size + dense_shot_idx = np.clip(dense_shot_idx, 0, total_num_shot) + + if sparse_method == 'edge': + # in this case, we use two edge shots as sparse sequence + sparse_stride = len(dense_shot_idx) - 1 + sparse_idx_to_dense = np.arange(0, len(dense_shot_idx), + sparse_stride) + elif sparse_method == 'edge+center': + # in this case, we use two edge shots + center shot as sparse sequence + sparse_idx_to_dense = np.array( + [0, len(dense_shot_idx) - 1, + len(dense_shot_idx) // 2]) + + shot_idx = [sparse_idx_to_dense, dense_shot_idx] + return shot_idx + + +class NeighborShotSampler: + """ This is for scene boundary detection (sbd), i.e., fine-tuning stage """ + + def __init__(self, neighbor_size: int = 8): + self.neighbor_size = neighbor_size + + def __call__(self, center_sid: int, total_num_shot: int): + # total number of shots = 2 * neighbor_size + 1 + shot_idx = center_sid + np.arange(-self.neighbor_size, + self.neighbor_size + 1) + shot_idx = np.clip(shot_idx, 0, + total_num_shot) # for out-of-boundary indices + + return shot_idx diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py new file mode 100644 index 00000000..7c1b724e --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .referring_video_object_segmentation_dataset import \ + ReferringVideoObjectSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py new file mode 100644 index 00000000..c90351e9 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py @@ -0,0 +1,361 @@ +# Part of the implementation is borrowed and modified from MTTR, +# publicly available at https://github.com/mttr2021/MTTR + +from glob import glob +from os import path as osp + +import h5py +import json +import numpy as np +import pandas +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from pycocotools.mask import area, encode +from torchvision.io import read_video +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + nested_tensor_from_videos_list +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from . import transformers as T + +LOGGER = get_logger() + + +def get_image_id(video_id, frame_idx, ref_instance_a2d_id): + image_id = f'v_{video_id}_f_{frame_idx}_i_{ref_instance_a2d_id}' + return image_id + + +@TASK_DATASETS.register_module( + Tasks.referring_video_object_segmentation, + module_name=Models.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + LOGGER.info(kwargs) + data_cfg = kwargs.get('cfg').data_kwargs + trans_cfg = kwargs.get('cfg').transformers_kwargs + distributed = data_cfg.get('distributed', False) + + self.data_root = next(iter(split_config.values())) + if not osp.exists(self.data_root): + self.data_root = osp.dirname(self.data_root) + assert osp.exists(self.data_root) + + self.window_size = data_cfg.get('window_size', 8) + self.mask_annotations_dir = osp.join( + self.data_root, 'text_annotations/annotation_with_instances') + self.videos_dir = osp.join(self.data_root, 'Release/CLIPS320') + self.subset_type = next(iter(split_config.keys())) + self.text_annotations = self.get_text_annotations( + self.data_root, self.subset_type, distributed) + self.transforms = A2dSentencesTransforms(self.subset_type, **trans_cfg) + self.collator = Collator() + self.ann_file = osp.join( + self.data_root, + data_cfg.get('ann_file', + 'a2d_sentences_test_annotations_in_coco_format.json')) + + # create ground-truth test annotations for the evaluation process if necessary: + if self.subset_type == 'test' and not osp.exists(self.ann_file): + if (distributed and dist.get_rank() == 0) or not distributed: + create_a2d_sentences_ground_truth_test_annotations( + self.data_root, self.subset_type, + self.mask_annotations_dir, self.ann_file) + if distributed: + dist.barrier() + + def __len__(self): + return len(self.text_annotations) + + def __getitem__(self, idx): + text_query, video_id, frame_idx, instance_id = self.text_annotations[ + idx] + + text_query = ' '.join( + text_query.lower().split()) # clean up the text query + + # read the source window frames: + video_frames, _, _ = read_video( + osp.join(self.videos_dir, f'{video_id}.mp4'), + pts_unit='sec') # (T, H, W, C) + # get a window of window_size frames with frame frame_idx in the middle. + # note that the original a2d dataset is 1 indexed, so we have to subtract 1 from frame_idx + start_idx, end_idx = frame_idx - 1 - self.window_size // 2, frame_idx - 1 + ( + self.window_size + 1) // 2 + + # extract the window source frames: + source_frames = [] + for i in range(start_idx, end_idx): + i = min(max(i, 0), + len(video_frames) + - 1) # pad out of range indices with edge frames + source_frames.append( + F.to_pil_image(video_frames[i].permute(2, 0, 1))) + + # read the instance mask: + frame_annot_path = osp.join(self.mask_annotations_dir, video_id, + f'{frame_idx:05d}.h5') + f = h5py.File(frame_annot_path, 'r') + instances = list(f['instance']) + instance_idx = instances.index( + instance_id) # existence was already validated during init + + instance_masks = np.array(f['reMask']) + if len(instances) == 1: + instance_masks = instance_masks[np.newaxis, ...] + instance_masks = torch.tensor(instance_masks).transpose(1, 2) + mask_rles = [encode(mask) for mask in instance_masks.numpy()] + mask_areas = area(mask_rles).astype(np.float) + f.close() + + # create the target dict for the center frame: + target = { + 'masks': instance_masks, + 'orig_size': instance_masks. + shape[-2:], # original frame shape without any augmentations + # size with augmentations, will be changed inside transforms if necessary + 'size': instance_masks.shape[-2:], + 'referred_instance_idx': torch.tensor( + instance_idx), # idx in 'masks' of the text referred instance + 'area': torch.tensor(mask_areas), + 'iscrowd': + torch.zeros(len(instance_masks) + ), # for compatibility with DETR COCO transforms + 'image_id': get_image_id(video_id, frame_idx, instance_id) + } + + # create dummy targets for adjacent frames: + targets = self.window_size * [None] + center_frame_idx = self.window_size // 2 + targets[center_frame_idx] = target + source_frames, targets, text_query = self.transforms( + source_frames, targets, text_query) + return source_frames, targets, text_query + + @staticmethod + def get_text_annotations(root_path, subset, distributed): + saved_annotations_file_path = osp.join( + root_path, f'sentences_single_frame_{subset}_annotations.json') + if osp.exists(saved_annotations_file_path): + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + elif (distributed and dist.get_rank() == 0) or not distributed: + print(f'building a2d sentences {subset} text annotations...') + # without 'header == None' pandas will ignore the first sample... + a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = [ + 'vid', '', '', '', '', '', '', '', 'subset' + ] + with open( + osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[ + ~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[ + text_annotations.video_id.isin(used_videos_ids)] + # remove a single dataset annotation mistake in video: T6bNPuKV-wY + used_text_annotations = used_text_annotations[ + used_text_annotations['instance_id'] != '1 (copy)'] + # convert data-frame to list of tuples: + used_text_annotations = list( + used_text_annotations.to_records(index=False)) + text_annotations_by_frame = [] + mask_annotations_dir = osp.join( + root_path, 'text_annotations/annotation_with_instances') + for video_id, instance_id, text_query in tqdm( + used_text_annotations): + frame_annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + instance_id = int(instance_id) + for p in frame_annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + if instance_id in instances: + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + frame_idx = int(p.split('/')[-1].split('.')[0]) + text_query = text_query.lower( + ) # lower the text query prior to augmentation & tokenization + text_annotations_by_frame.append( + (text_query, video_id, frame_idx, instance_id)) + with open(saved_annotations_file_path, 'w') as f: + json.dump(text_annotations_by_frame, f) + if distributed: + dist.barrier() + with open(saved_annotations_file_path, 'r') as f: + text_annotations_by_frame = [tuple(a) for a in json.load(f)] + return text_annotations_by_frame + + +class A2dSentencesTransforms: + + def __init__(self, subset_type, horizontal_flip_augmentations, + resize_and_crop_augmentations, train_short_size, + train_max_size, eval_short_size, eval_max_size, **kwargs): + self.h_flip_augmentation = subset_type == 'train' and horizontal_flip_augmentations + normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + scales = [ + train_short_size + ] # no more scales for now due to GPU memory constraints. might be changed later + transforms = [] + if resize_and_crop_augmentations: + if subset_type == 'train': + transforms.append( + T.RandomResize(scales, max_size=train_max_size)) + elif subset_type == 'test': + transforms.append( + T.RandomResize([eval_short_size], max_size=eval_max_size)), + transforms.extend([T.ToTensor(), normalize]) + self.size_transforms = T.Compose(transforms) + + def __call__(self, source_frames, targets, text_query): + if self.h_flip_augmentation and torch.rand(1) > 0.5: + source_frames = [F.hflip(f) for f in source_frames] + targets[len(targets) // 2]['masks'] = F.hflip( + targets[len(targets) // 2]['masks']) + # Note - is it possible for both 'right' and 'left' to appear together in the same query. hence this fix: + text_query = text_query.replace('left', '@').replace( + 'right', 'left').replace('@', 'right') + source_frames, targets = list( + zip(*[ + self.size_transforms(f, t) + for f, t in zip(source_frames, targets) + ])) + source_frames = torch.stack(source_frames) # [T, 3, H, W] + return source_frames, targets, text_query + + +class Collator: + + def __call__(self, batch): + samples, targets, text_queries = list(zip(*batch)) + samples = nested_tensor_from_videos_list(samples) # [T, B, C, H, W] + # convert targets to a list of tuples. outer list - time steps, inner tuples - time step batch + targets = list(zip(*targets)) + batch_dict = { + 'samples': samples, + 'targets': targets, + 'text_queries': text_queries + } + return batch_dict + + +def get_text_annotations_gt(root_path, subset): + # without 'header == None' pandas will ignore the first sample... + a2d_data_info = pandas.read_csv( + osp.join(root_path, 'Release/videoset.csv'), header=None) + # 'vid', 'label', 'start_time', 'end_time', 'height', 'width', 'total_frames', 'annotated_frames', 'subset' + a2d_data_info.columns = ['vid', '', '', '', '', '', '', '', 'subset'] + with open(osp.join(root_path, 'text_annotations/missed_videos.txt'), + 'r') as f: + unused_videos = f.read().splitlines() + subsets = {'train': 0, 'test': 1} + # filter unused videos and videos which do not belong to our train/test subset: + used_videos = a2d_data_info[~a2d_data_info.vid.isin(unused_videos) + & (a2d_data_info.subset == subsets[subset])] + used_videos_ids = list(used_videos['vid']) + text_annotations = pandas.read_csv( + osp.join(root_path, 'text_annotations/annotation.txt')) + # filter the text annotations based on the used videos: + used_text_annotations = text_annotations[text_annotations.video_id.isin( + used_videos_ids)] + # convert data-frame to list of tuples: + used_text_annotations = list(used_text_annotations.to_records(index=False)) + return used_text_annotations + + +def create_a2d_sentences_ground_truth_test_annotations(dataset_path, + subset_type, + mask_annotations_dir, + output_path): + text_annotations = get_text_annotations_gt(dataset_path, subset_type) + + # Note - it is very important to start counting the instance and category ids from 1 (not 0). This is implicitly + # expected by pycocotools as it is the convention of the original coco dataset annotations. + + categories_dict = [{ + 'id': 1, + 'name': 'dummy_class' + }] # dummy class, as categories are not used/predicted in RVOS + + images_dict = [] + annotations_dict = [] + images_set = set() + instance_id_counter = 1 + for annot in tqdm(text_annotations): + video_id, instance_id, text_query = annot + annot_paths = sorted( + glob(osp.join(mask_annotations_dir, video_id, '*.h5'))) + for p in annot_paths: + f = h5py.File(p) + instances = list(f['instance']) + try: + instance_idx = instances.index(int(instance_id)) + # in case this instance does not appear in this frame it has no ground-truth mask, and thus this + # frame-instance pair is ignored in evaluation, same as SOTA method: CMPC-V. check out: + # https://github.com/spyflying/CMPC-Refseg/blob/094639b8bf00cc169ea7b49cdf9c87fdfc70d963/CMPC_video/build_A2D_batches.py#L98 + except ValueError: + continue # instance_id does not appear in current frame + mask = f['reMask'][instance_idx] if len( + instances) > 1 else np.array(f['reMask']) + mask = mask.transpose() + + frame_idx = int(p.split('/')[-1].split('.')[0]) + image_id = get_image_id(video_id, frame_idx, instance_id) + assert image_id not in images_set, f'error: image id: {image_id} appeared twice' + images_set.add(image_id) + images_dict.append({ + 'id': image_id, + 'height': mask.shape[0], + 'width': mask.shape[1] + }) + + mask_rle = encode(mask) + mask_rle['counts'] = mask_rle['counts'].decode('ascii') + mask_area = float(area(mask_rle)) + bbox = f['reBBox'][:, instance_idx] if len( + instances) > 1 else np.array( + f['reBBox']).squeeze() # x1y1x2y2 form + bbox_xywh = [ + bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] + ] + instance_annot = { + 'id': instance_id_counter, + 'image_id': image_id, + 'category_id': + 1, # dummy class, as categories are not used/predicted in ref-vos + 'segmentation': mask_rle, + 'area': mask_area, + 'bbox': bbox_xywh, + 'iscrowd': 0, + } + annotations_dict.append(instance_annot) + instance_id_counter += 1 + dataset_dict = { + 'categories': categories_dict, + 'images': images_dict, + 'annotations': annotations_dict + } + with open(output_path, 'w') as f: + json.dump(dataset_dict, f) diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py new file mode 100644 index 00000000..a5067b1b --- /dev/null +++ b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py @@ -0,0 +1,294 @@ +# The implementation is adopted from MTTR, +# made publicly available under the Apache 2.0 License at https://github.com/mttr2021/MTTR +# Modified from DETR https://github.com/facebookresearch/detr + +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F + +from modelscope.models.cv.referring_video_object_segmentation.utils import \ + interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target['size'] = torch.tensor([h, w]) + + fields = ['labels', 'area', 'iscrowd'] + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target['boxes'] = boxes + + if 'masks' in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int( + round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) + for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + h, w = size + target['size'] = torch.tensor([h, w]) + + if 'masks' in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode='nearest')[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target['size'] = torch.tensor(padded_image.size[::-1]) + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop(object): + + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, + (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad(object): + + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if 'boxes' in target: + boxes = target['boxes'] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target['boxes'] = boxes + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py new file mode 100644 index 00000000..5376cd7c --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sidd_image_denoising_dataset import SiddImageDenoisingDataset + +else: + _import_structure = { + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py new file mode 100644 index 00000000..33fce4c8 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def padding(img_lq, img_gt, gt_size): + h, w, _ = img_lq.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lq, img_gt + + img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + return img_lq, img_gt diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py new file mode 100644 index 00000000..3f0cdae0 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from .data_utils import img2tensor, padding +from .transforms import augment, paired_random_crop + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 + + +@TASK_DATASETS.register_module( + Tasks.image_denoising, module_name=Models.nafnet) +class SiddImageDenoisingDataset(TorchTaskDataset): + """Paired image dataset for image restoration. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + item_dict = self.dataset[index] + gt_path = item_dict['Clean Image:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['Noisy Image:FILE'] + img_lq = default_loader(lq_path) + + # augmentation for training + if self.is_train: + gt_size = self.opt.gt_size + # padding + img_gt, img_lq = padding(img_gt, img_lq, gt_size) + + # random crop + img_gt, img_lq = paired_random_crop( + img_gt, img_lq, gt_size, scale=1) + + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, + self.opt.use_rot) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': img_lq, 'target': img_gt} diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py new file mode 100644 index 00000000..c5ad12f6 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py @@ -0,0 +1,96 @@ +# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/data/transforms.py + +import random + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + Args: + img (ndarray): Input image. + scale (int): Scale factor. + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale): + """Paired random crop. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray): GT images. + img_lqs (list[ndarray] | ndarray): LQ images. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + h_lq, w_lq, _ = img_lqs[0].shape + h_gt, w_gt, _ = img_gts[0].shape + lq_patch_size = gt_patch_size // scale + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + img_lqs = [ + v[top:top + lq_patch_size, left:left + lq_patch_size, ...] + for v in img_lqs + ] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + img_gts = [ + v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] + for v in img_gts + ] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, vflip=False): + """Augment: horizontal flips | rotate + + All the images in the list use the same augmentation. + """ + hflip = hflip and random.random() < 0.5 + if vflip or rotation: + vflip = random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + img = img[:, ::-1, :].copy() + if vflip: # vertical + img = img[::-1, :, :].copy() + if rot90: + img = img.transpose(1, 0, 2) + return img + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + return imgs diff --git a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py new file mode 100644 index 00000000..54276843 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py @@ -0,0 +1,150 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +import torch +from datasets import Dataset, IterableDataset, concatenate_datasets +from torch.utils.data import ConcatDataset +from transformers import DataCollatorWithPadding + +from modelscope.metainfo import Models +from modelscope.utils.constant import ModeKeys, Tasks +from .base import TaskDataset +from .builder import TASK_DATASETS +from .torch_base_dataset import TorchTaskDataset + + +@TASK_DATASETS.register_module( + group_key=Tasks.text_ranking, module_name=Models.bert) +class TextRankingDataset(TorchTaskDataset): + + def __init__(self, + datasets: Union[Any, List[Any]], + mode, + preprocessor=None, + *args, + **kwargs): + self.seed = kwargs.get('seed', 42) + self.permutation = None + self.datasets = None + self.dataset_config = kwargs + self.query_sequence = self.dataset_config.get('query_sequence', + 'query') + self.pos_sequence = self.dataset_config.get('pos_sequence', + 'positive_passages') + self.neg_sequence = self.dataset_config.get('neg_sequence', + 'negative_passages') + self.text_fileds = self.dataset_config.get('text_fileds', + ['title', 'text']) + self.qid_field = self.dataset_config.get('qid_field', 'query_id') + if mode == ModeKeys.TRAIN: + self.neg_samples = self.dataset_config.get('neg_sample', 4) + + super().__init__(datasets, mode, preprocessor, **kwargs) + + def __getitem__(self, index) -> Any: + if self.mode == ModeKeys.TRAIN: + return self.__get_train_item__(index) + else: + return self.__get_test_item__(index) + + def __get_test_item__(self, index): + group = self._inner_dataset[index] + labels = [] + + qry = group[self.query_sequence] + + pos_sequences = group[self.pos_sequence] + pos_sequences = [ + ' '.join([ele[key] for key in self.text_fileds]) + for ele in pos_sequences + ] + labels.extend([1] * len(pos_sequences)) + + neg_sequences = group[self.neg_sequence] + neg_sequences = [ + ' '.join([ele[key] for key in self.text_fileds]) + for ele in neg_sequences + ] + + labels.extend([0] * len(neg_sequences)) + qid = group[self.qid_field] + + examples = pos_sequences + neg_sequences + sample = { + 'qid': torch.LongTensor([int(qid)] * len(labels)), + self.preprocessor.first_sequence: qry, + self.preprocessor.second_sequence: examples, + 'labels': torch.LongTensor(labels) + } + return self.prepare_sample(sample) + + def __get_train_item__(self, index): + group = self._inner_dataset[index] + + qry = group[self.query_sequence] + + pos_sequences = group[self.pos_sequence] + pos_sequences = [ + ' '.join([ele[key] for key in self.text_fileds]) + for ele in pos_sequences + ] + + neg_sequences = group[self.neg_sequence] + neg_sequences = [ + ' '.join([ele[key] for key in self.text_fileds]) + for ele in neg_sequences + ] + + pos_psg = random.choice(pos_sequences) + + if len(neg_sequences) < self.neg_samples: + negs = random.choices(neg_sequences, k=self.neg_samples) + else: + negs = random.sample(neg_sequences, k=self.neg_samples) + examples = [pos_psg] + negs + sample = { + self.preprocessor.first_sequence: qry, + self.preprocessor.second_sequence: examples, + } + return self.prepare_sample(sample) + + def __len__(self): + return len(self._inner_dataset) + + def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: + """Prepare a dataset. + + User can process the input datasets in a whole dataset perspective. + This method gives a default implementation of datasets merging, user can override this + method to write custom logics. + + Args: + datasets: The original dataset(s) + + Returns: A single dataset, which may be created after merging. + + """ + if isinstance(datasets, List): + if len(datasets) == 1: + return datasets[0] + elif len(datasets) > 1: + return ConcatDataset(datasets) + else: + return datasets + + def prepare_sample(self, data): + """Preprocess the data fetched from the inner_dataset. + + If the preprocessor is None, the original data will be returned, else the preprocessor will be called. + User can override this method to implement custom logics. + + Args: + data: The data fetched from the dataset. + + Returns: The processed data. + + """ + return self.preprocessor( + data) if self.preprocessor is not None else data diff --git a/modelscope/msdatasets/task_datasets/torch_base_dataset.py b/modelscope/msdatasets/task_datasets/torch_base_dataset.py new file mode 100644 index 00000000..4d82b741 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/torch_base_dataset.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, List, Tuple, Union + +from torch.utils.data import ConcatDataset, Dataset + +from .base import TaskDataset + + +class TorchTaskDataset(TaskDataset, Dataset): + """The task dataset base class for all the torch-based task processors. + + This base class is enough for most cases, except there are procedures which can not be executed in + preprocessors and Datasets like dataset merging. + """ + + def __init__(self, + datasets: Union[Any, List[Any]], + mode, + preprocessor=None, + **kwargs): + TaskDataset.__init__(self, datasets, mode, preprocessor, **kwargs) + self.trainer = None + + def __getitem__(self, index) -> Any: + return self.prepare_sample(self._inner_dataset[index]) + + def __len__(self): + return len(self._inner_dataset) + + def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: + """Prepare a dataset. + + User can process the input datasets in a whole dataset perspective. + This method gives a default implementation of datasets merging, user can override this + method to write custom logics. + + Args: + datasets: The original dataset(s) + + Returns: A single dataset, which may be created after merging. + + """ + if isinstance(datasets, List): + if len(datasets) == 1: + return datasets[0] + elif len(datasets) > 1: + return ConcatDataset(datasets) + else: + return datasets + + def prepare_sample(self, data): + """Preprocess the data fetched from the inner_dataset. + + If the preprocessor is None, the original data will be returned, else the preprocessor will be called. + User can override this method to implement custom logics. + + Args: + data: The data fetched from the dataset. + + Returns: The processed data. + + """ + return self.preprocessor( + data) if self.preprocessor is not None else data diff --git a/modelscope/msdatasets/task_datasets/veco_dataset.py b/modelscope/msdatasets/task_datasets/veco_dataset.py new file mode 100644 index 00000000..df7c6483 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/veco_dataset.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, List, Union + +import numpy as np +from datasets import Dataset, IterableDataset, concatenate_datasets + +from modelscope.metainfo import Models +from modelscope.utils.constant import Tasks +from .builder import TASK_DATASETS +from .torch_base_dataset import TorchTaskDataset + + +@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli) +class VecoDataset(TorchTaskDataset): + + def __init__(self, + datasets: Union[Any, List[Any]], + mode, + preprocessor=None, + **kwargs): + self.seed = kwargs.get('seed', 42) + self.permutation = None + self.datasets = None + super().__init__(datasets, mode, preprocessor, **kwargs) + + def switch_dataset(self, idx): + """Switch dataset in evaluation. + + Veco evaluates dataset one by one. + + Args: + idx: The index of the dataset + """ + if self.mode == 'train': + raise ValueError( + 'Only support switch dataset in the evaluation loop') + if idx >= len(self.datasets): + raise ValueError( + 'Index is bigger than the number of the datasets.') + self._inner_dataset = self.datasets[idx] + + def __getitem__(self, item): + if self.permutation is not None: + item = self.permutation[item] + return super().__getitem__(item) + + def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: + """Compose all the datasets. + + If the mode is 'train', all datasets will be mixed together, if the mode is 'eval', + the datasets will be kept and returns the first one. + + Args: + datasets: The datasets to be composed. + + Returns: The final dataset. + """ + if not isinstance(datasets, (list, tuple)): + datasets = [datasets] + if self.mode == 'train': + if len(datasets) == 1: + return datasets[0] + elif all([ + isinstance(dataset, (Dataset, IterableDataset)) + for dataset in datasets + ]): + dataset = concatenate_datasets(list(datasets)) + return dataset.shuffle(seed=self.seed) + else: + generator = np.random.default_rng(self.seed) + _len = sum([len(dataset) for dataset in datasets]) + self.permutation = generator.permutation(_len) + return super().prepare_dataset(datasets) + else: + self.datasets = datasets + return self.datasets[0] diff --git a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py new file mode 100644 index 00000000..34eb0450 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py @@ -0,0 +1,72 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +import os + +import h5py +import json +import numpy as np +import torch + +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset + + +class VideoSummarizationDataset(TorchTaskDataset): + + def __init__(self, mode, opt, root_dir): + self.mode = mode + self.data_filename = os.path.join(root_dir, opt.dataset_file) + self.split_filename = os.path.join(root_dir, opt.split_file) + self.split_index = opt.split_index + hdf = h5py.File(self.data_filename, 'r') + self.list_frame_features, self.list_gtscores = [], [] + self.list_user_summary = [] + self.list_change_points = [] + self.list_n_frames = [] + self.list_positions = [] + + with open(self.split_filename) as f: + data = json.loads(f.read()) + for i, split in enumerate(data): + if i == self.split_index: + self.split = split + break + + for video_name in self.split[self.mode + '_keys']: + frame_features = torch.Tensor( + np.array(hdf[video_name + '/features'])) + gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) + user_summary = np.array(hdf[f'{video_name}/user_summary']) + change_points = np.array(hdf[f'{video_name}/change_points']) + n_frames = np.array(hdf[f'{video_name}/n_frames']) + positions = np.array(hdf[f'{video_name}/picks']) + + self.list_frame_features.append(frame_features) + self.list_gtscores.append(gtscore) + self.list_user_summary.append(user_summary) + self.list_change_points.append(change_points) + self.list_n_frames.append(n_frames) + self.list_positions.append(positions) + + hdf.close() + + def __len__(self): + self.len = len(self.split[self.mode + '_keys']) + return self.len + + def __getitem__(self, index): + frame_features = self.list_frame_features[index] + gtscore = self.list_gtscores[index] + user_summary = self.list_user_summary[index] + change_points = self.list_change_points[index] + n_frames = self.list_n_frames[index] + positions = self.list_positions[index] + + return dict( + frame_features=frame_features, + gtscore=gtscore, + user_summary=user_summary, + change_points=change_points, + n_frames=n_frames, + positions=positions) diff --git a/modelscope/msdatasets/utils/__init__.py b/modelscope/msdatasets/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/msdatasets/utils/dataset_builder.py b/modelscope/msdatasets/utils/dataset_builder.py new file mode 100644 index 00000000..0548f7b9 --- /dev/null +++ b/modelscope/msdatasets/utils/dataset_builder.py @@ -0,0 +1,197 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Mapping, Sequence, Union + +import datasets +import pandas as pd +import pyarrow as pa +from datasets.info import DatasetInfo +from datasets.naming import camelcase_to_snakecase +from datasets.packaged_modules import csv +from datasets.utils.filelock import FileLock + +from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class MsCsvDatasetBuilder(csv.Csv): + + def __init__( + self, + dataset_name: str, + cache_dir: str, + namespace: str, + subset_name: str, + hash: str, + meta_data_files: Mapping[str, Union[str, Sequence[str]]], + zip_data_files: Mapping[str, Union[str, Sequence[str]]] = None, + **config_kwargs, + ): + super().__init__( + cache_dir=cache_dir, + name=subset_name, + hash=hash, + data_files=meta_data_files, + **config_kwargs) + + self.name = camelcase_to_snakecase(dataset_name) + self.info.builder_name = dataset_name + self._cache_dir = self._build_cache_dir(namespace=namespace) + lock_path = os.path.join( + self._cache_dir_root, + self._cache_dir.replace(os.sep, '_') + '.lock') + with FileLock(lock_path): + # check if data exist + if os.path.exists(self._cache_dir): + if len(os.listdir(self._cache_dir)) > 0: + logger.info( + f'Overwrite dataset info from restored data version, cache_dir is {self._cache_dir}' + ) + # dir exists but no data, remove the empty dir as data aren't available anymore + else: + logger.warning( + f'Old caching folder {self._cache_dir} for dataset {self.name} exists ' + f'but not data were found. Removing it. ') + os.rmdir(self._cache_dir) + self.zip_data_files = zip_data_files + + def _relative_data_dir(self, + with_version=True, + with_hash=True, + namespace=DEFAULT_DATASET_NAMESPACE) -> str: + """Relative path of this dataset in cache_dir: + Will be: + self.name/self.config.version/self.hash/ + or if a namespace has been specified: + self.namespace___self.name/self.config.version/self.hash/ + """ + builder_data_dir = self.info.builder_name if namespace is None else f'{namespace}___{self.info.builder_name}' + builder_config = self.config + hash = self.hash + if builder_config: + builder_data_dir = os.path.join(builder_data_dir, self.config_id) + if with_version: + builder_data_dir = os.path.join(builder_data_dir, + str(self.config.version)) + if with_hash and hash and isinstance(hash, str): + builder_data_dir = os.path.join(builder_data_dir, hash) + return builder_data_dir + + def _build_cache_dir(self, namespace=DEFAULT_DATASET_NAMESPACE): + builder_data_dir = os.path.join( + self._cache_dir_root, + self._relative_data_dir( + with_version=False, with_hash=True, namespace=namespace)) + + return builder_data_dir + + def _split_generators(self, dl_manager): + if not self.config.data_files: + raise ValueError( + 'At least one data file must be specified, but got none.') + data_files = dl_manager.download_and_extract(self.config.data_files) + zip_data_files = dl_manager.download_and_extract(self.zip_data_files) + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + splits.append( + datasets.SplitGenerator( + name=split_name, + gen_kwargs={ + 'files': dl_manager.iter_files(files), + 'base_dir': zip_data_files.get(split_name) + })) + return splits + + def _generate_tables(self, files, base_dir): + schema = pa.schema(self.config.features.type + ) if self.config.features is not None else None + dtype = { + name: dtype.to_pandas_dtype() + for name, dtype in zip(schema.names, schema.types) + } if schema else None + for file_idx, file in enumerate(files): + csv_file_reader = pd.read_csv( + file, + iterator=True, + dtype=dtype, + **self.config.read_csv_kwargs) + transform_fields = [] + for field_name in csv_file_reader._engine.names: + if field_name.endswith(':FILE'): + transform_fields.append(field_name) + try: + for batch_idx, df in enumerate(csv_file_reader): + for field_name in transform_fields: + if base_dir: + df[field_name] = df[field_name].apply( + lambda x: os.path.join(base_dir, x)) + pa_table = pa.Table.from_pandas(df, schema=schema) + yield (file_idx, batch_idx), pa_table + except ValueError as e: + logger.error( + f"Failed to read file '{file}' with error {type(e)}: {e}") + raise + + +class TaskSpecificDatasetBuilder(MsCsvDatasetBuilder): + + def __init__( + self, + dataset_name: str, + cache_dir: str, + namespace: str, + subset_name: str, + hash: str, + meta_data_files: Mapping[str, Union[str, Sequence[str]]], + zip_data_files: Mapping[str, Union[str, Sequence[str]]] = None, + **config_kwargs, + ): + self.name = dataset_name + self.subset_name = subset_name + self.namespace = namespace + self.hash = hash + self.data_files = meta_data_files + self.zip_data_files = zip_data_files + self.split_path_dict = None + self.config = None + self.info = DatasetInfo.from_dict({'builder_name': dataset_name}) + self._cache_dir_root = os.path.expanduser(cache_dir) + self._cache_dir = self._build_cache_dir() + self._config_kwargs = config_kwargs + + def download_and_prepare(self, download_mode, dl_manager, + **download_kwargs): + # Prevent parallel disk operations + lock_path = os.path.join( + self._cache_dir_root, + self._cache_dir.replace(os.sep, '_') + '.lock') + with FileLock(lock_path): + data_exists = os.path.exists(self._cache_dir) + if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.warning( + f'Reusing dataset {self.name} ({self._cache_dir})') + return + logger.info(f'Generating dataset {self.name} ({self._cache_dir})') + self._download_and_prepare(dl_manager=dl_manager) + + def _download_and_prepare(self, dl_manager): + self.split_path_dict = dl_manager.download_and_extract( + self.zip_data_files) + + def as_dataset(self): + return ExternalDataset(self.split_path_dict, self._config_kwargs) + + +class ExternalDataset(object): + + def __init__(self, split_path_dict, config_kwargs): + config_kwargs.update({'split_config': split_path_dict}) + self.config_kwargs = config_kwargs + + def __len__(self): + return len(self.config_kwargs['split_config']) diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py new file mode 100644 index 00000000..7a46b325 --- /dev/null +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -0,0 +1,231 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from collections import defaultdict +from typing import Any, Mapping, Optional, Sequence, Union + +from datasets.builder import DatasetBuilder + +from modelscope.hub.api import HubApi +from modelscope.utils.constant import DEFAULT_DATASET_REVISION +from modelscope.utils.logger import get_logger +from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder + +logger = get_logger() + + +def format_dataset_structure(dataset_structure): + return { + k: v + for k, v in dataset_structure.items() + if (v.get('meta') or v.get('file')) + } + + +def get_target_dataset_structure(dataset_structure: dict, + subset_name: Optional[str] = None, + split: Optional[str] = None): + """ + Args: + dataset_structure (dict): Dataset Structure, like + { + "default":{ + "train":{ + "meta":"my_train.csv", + "file":"pictures.zip" + } + }, + "subsetA":{ + "test":{ + "meta":"mytest.csv", + "file":"pictures.zip" + } + } + } + subset_name (str, optional): Defining the subset_name of the dataset. + split (str, optional): Which split of the data to load. + Returns: + target_subset_name (str): Name of the chosen subset. + target_dataset_structure (dict): Structure of the chosen split(s), like + { + "test":{ + "meta":"mytest.csv", + "file":"pictures.zip" + } + } + """ + # verify dataset subset + if (subset_name and subset_name not in dataset_structure) or ( + not subset_name and len(dataset_structure.keys()) > 1): + raise ValueError( + f'subset_name {subset_name} not found. Available: {dataset_structure.keys()}' + ) + target_subset_name = subset_name + if not subset_name: + target_subset_name = next(iter(dataset_structure.keys())) + logger.info( + f'No subset_name specified, defaulting to the {target_subset_name}' + ) + # verify dataset split + target_dataset_structure = format_dataset_structure( + dataset_structure[target_subset_name]) + if split and split not in target_dataset_structure: + raise ValueError( + f'split {split} not found. Available: {target_dataset_structure.keys()}' + ) + if split: + target_dataset_structure = {split: target_dataset_structure[split]} + return target_subset_name, target_dataset_structure + + +def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool, + dataset_name: str, namespace: str, + version: str) -> list: + """ + List all objects for specific dataset. + + Args: + hub_api (class HubApi): HubApi instance. + max_limit (int): Max number of objects. + is_recursive (bool): Whether to list objects recursively. + dataset_name (str): Dataset name. + namespace (str): Namespace. + version (str): Dataset version. + Returns: + res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...] + """ + res = [] + objects = hub_api.list_oss_dataset_objects( + dataset_name=dataset_name, + namespace=namespace, + max_limit=max_limit, + is_recursive=is_recursive, + is_filter_dir=True, + revision=version) + + for item in objects: + object_key = item.get('Key') + res.append(object_key) + + return res + + +def contains_dir(file_map) -> bool: + """ + To check whether input contains at least one directory. + + Args: + file_map (dict): Structure of data files. e.g., {'train': 'train.zip', 'validation': 'val.zip'} + Returns: + True if input contains at least one directory, False otherwise. + """ + res = False + for k, v in file_map.items(): + if isinstance(v, str) and not v.endswith('.zip'): + res = True + break + return res + + +def get_split_objects_map(file_map, objects): + """ + Get the map between dataset split and oss objects. + + Args: + file_map (dict): Structure of data files. e.g., {'train': 'train', 'validation': 'val'}, both of train and val + are dirs. + objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png'] + Returns: + A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'], + 'validation':['val/003/3_38.png']} + """ + res = {} + for k, v in file_map.items(): + res[k] = [] + + for obj_key in objects: + for k, v in file_map.items(): + if obj_key.startswith(v): + res[k].append(obj_key) + + return res + + +def get_dataset_files(subset_split_into: dict, + dataset_name: str, + namespace: str, + revision: Optional[str] = DEFAULT_DATASET_REVISION): + """ + Return: + meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like + { + "test": "https://xxx/mytest.csv" + } + file_map: Structure of data files (.zip), like + { + "test": "pictures.zip" + } + """ + meta_map = defaultdict(dict) + file_map = defaultdict(dict) + args_map = defaultdict(dict) + modelscope_api = HubApi() + objects = list_dataset_objects( + hub_api=modelscope_api, + max_limit=-1, + is_recursive=True, + dataset_name=dataset_name, + namespace=namespace, + version=revision) + + for split, info in subset_split_into.items(): + meta_map[split] = modelscope_api.get_dataset_file_url( + info.get('meta', ''), dataset_name, namespace, revision) + if info.get('file'): + file_map[split] = info['file'] + args_map[split] = info.get('args') + + if contains_dir(file_map): + file_map = get_split_objects_map(file_map, objects) + return meta_map, file_map, args_map + + +def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str, + meta_data_files: Mapping[str, Union[str, + Sequence[str]]], + zip_data_files: Mapping[str, Union[str, + Sequence[str]]], + args_map: Mapping[str, Any], cache_dir: str, + version: Optional[Union[str]], split: Sequence[str], + **config_kwargs) -> DatasetBuilder: + sub_dir = os.path.join(version, '_'.join(split)) + meta_data_file = next(iter(meta_data_files.values())) + if not meta_data_file: + args_map = next(iter(args_map.values())) + if args_map is None: + args_map = {} + args_map.update(config_kwargs) + builder_instance = TaskSpecificDatasetBuilder( + dataset_name=dataset_name, + namespace=namespace, + cache_dir=cache_dir, + subset_name=subset_name, + meta_data_files=meta_data_files, + zip_data_files=zip_data_files, + hash=sub_dir, + **args_map) + elif meta_data_file.endswith('.csv'): + builder_instance = MsCsvDatasetBuilder( + dataset_name=dataset_name, + namespace=namespace, + cache_dir=cache_dir, + subset_name=subset_name, + meta_data_files=meta_data_files, + zip_data_files=zip_data_files, + hash=sub_dir) + else: + raise NotImplementedError( + f'Dataset mete file extensions "{os.path.splitext(meta_data_file)[-1]}" is not implemented yet' + ) + + return builder_instance diff --git a/modelscope/msdatasets/utils/delete_utils.py b/modelscope/msdatasets/utils/delete_utils.py new file mode 100644 index 00000000..a5a6f53f --- /dev/null +++ b/modelscope/msdatasets/utils/delete_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.hub.api import HubApi + + +class DatasetDeleteManager(object): + + def __init__(self, dataset_name: str, namespace: str, version: str): + self.api = HubApi() + self.dataset_name = dataset_name + self.namespace = namespace + self.version = version + + def delete(self, object_name: str) -> str: + + # single object + if not object_name.endswith('/'): + resp_msg = self.api.delete_oss_dataset_object( + object_name=object_name, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.version) + else: + # multiple objects + object_name = object_name.strip('/') + resp_msg = self.api.delete_oss_dataset_dir( + object_name=object_name, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.version) + + return resp_msg diff --git a/modelscope/msdatasets/utils/download_utils.py b/modelscope/msdatasets/utils/download_utils.py new file mode 100644 index 00000000..ebe9b8f5 --- /dev/null +++ b/modelscope/msdatasets/utils/download_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Optional + +from datasets.utils.download_manager import DownloadConfig, DownloadManager +from datasets.utils.file_utils import cached_path, is_relative_path + +from .oss_utils import OssUtilities + + +class DatasetDownloadManager(DownloadManager): + + def __init__(self, + dataset_name: str, + namespace: str, + version: str, + data_dir: Optional[str] = None, + download_config: Optional[DownloadConfig] = None, + base_path: Optional[str] = None, + record_checksums=True): + super().__init__(dataset_name, data_dir, download_config, base_path, + record_checksums) + self._namespace = namespace + self._version = version + from modelscope.hub.api import HubApi + api = HubApi() + oss_config = api.get_dataset_access_config(self._dataset_name, + self._namespace, + self._version) + self.oss_utilities = OssUtilities( + oss_config=oss_config, + dataset_name=self._dataset_name, + namespace=self._namespace, + revision=self._version) + + def _download(self, url_or_filename: str, + download_config: DownloadConfig) -> str: + url_or_filename = str(url_or_filename) + if is_relative_path(url_or_filename): + # fetch oss files + return self.oss_utilities.download( + url_or_filename, download_config=download_config) + else: + return cached_path( + url_or_filename, download_config=download_config) diff --git a/modelscope/msdatasets/utils/oss_utils.py b/modelscope/msdatasets/utils/oss_utils.py new file mode 100644 index 00000000..e27ff8c4 --- /dev/null +++ b/modelscope/msdatasets/utils/oss_utils.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from __future__ import print_function +import os + +import oss2 +from datasets.utils.file_utils import hash_url_to_filename + +from modelscope.hub.api import HubApi +from modelscope.utils.constant import UploadMode +from modelscope.utils.logger import get_logger + +logger = get_logger() + +ACCESS_ID = 'AccessId' +ACCESS_SECRET = 'AccessSecret' +SECURITY_TOKEN = 'SecurityToken' +BUCKET = 'Bucket' +BACK_DIR = 'BackupDir' +DIR = 'Dir' + + +class OssUtilities: + + def __init__(self, oss_config, dataset_name, namespace, revision): + self._do_init(oss_config=oss_config) + + self.dataset_name = dataset_name + self.namespace = namespace + self.revision = revision + + self.upload_resumable_tmp_store = '/tmp/modelscope/tmp_dataset' + self.upload_multipart_threshold = 50 * 1024 * 1024 + self.upload_part_size = 1 * 1024 * 1024 + self.upload_num_threads = 4 + self.upload_max_retries = 3 + + self.api = HubApi() + + def _do_init(self, oss_config): + self.key = oss_config[ACCESS_ID] + self.secret = oss_config[ACCESS_SECRET] + self.token = oss_config[SECURITY_TOKEN] + self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" + self.bucket_name = oss_config[BUCKET] + auth = oss2.StsAuth(self.key, self.secret, self.token) + self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + self.oss_dir = oss_config[DIR] + self.oss_backup_dir = oss_config[BACK_DIR] + + def _reload_sts(self): + cookies = self.api.check_local_cookies(use_cookies=True) + oss_config_refresh = self.api.get_dataset_access_config_session( + cookies=cookies, + dataset_name=self.dataset_name, + namespace=self.namespace, + revision=self.revision) + self._do_init(oss_config_refresh) + + @staticmethod + def _percentage(consumed_bytes, total_bytes): + if total_bytes: + rate = int(100 * (float(consumed_bytes) / float(total_bytes))) + print('\r{0}% '.format(rate), end='', flush=True) + + def download(self, oss_file_name, download_config): + cache_dir = download_config.cache_dir + candidate_key = os.path.join(self.oss_dir, oss_file_name) + candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name) + file_oss_key = candidate_key if self.bucket.object_exists( + candidate_key) else candidate_key_backup + filename = hash_url_to_filename(file_oss_key, etag=None) + local_path = os.path.join(cache_dir, filename) + + if download_config.force_download or not os.path.exists(local_path): + oss2.resumable_download( + self.bucket, + file_oss_key, + local_path, + multiget_threshold=0, + progress_callback=self._percentage) + return local_path + + def upload(self, oss_object_name: str, local_file_path: str, + indicate_individual_progress: bool, + upload_mode: UploadMode) -> str: + retry_count = 0 + object_key = os.path.join(self.oss_dir, oss_object_name) + resumable_store = oss2.ResumableStore( + root=self.upload_resumable_tmp_store) + if indicate_individual_progress: + progress_callback = self._percentage + else: + progress_callback = None + + while True: + try: + retry_count += 1 + exist = self.bucket.object_exists(object_key) + if upload_mode == UploadMode.APPEND and exist: + logger.info( + f'Skip {oss_object_name} in case of {upload_mode.value} mode.' + ) + break + + oss2.resumable_upload( + self.bucket, + object_key, + local_file_path, + store=resumable_store, + multipart_threshold=self.upload_multipart_threshold, + part_size=self.upload_part_size, + progress_callback=progress_callback, + num_threads=self.upload_num_threads) + break + except Exception as e: + if e.__getattribute__('status') == 403: + self._reload_sts() + if retry_count >= self.upload_max_retries: + raise + + return object_key diff --git a/modelscope/msdatasets/utils/upload_utils.py b/modelscope/msdatasets/utils/upload_utils.py new file mode 100644 index 00000000..bbdcd9e9 --- /dev/null +++ b/modelscope/msdatasets/utils/upload_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from multiprocessing.dummy import Pool as ThreadPool + +from tqdm import tqdm + +from modelscope.utils.constant import UploadMode +from .oss_utils import OssUtilities + + +class DatasetUploadManager(object): + + def __init__(self, dataset_name: str, namespace: str, version: str): + from modelscope.hub.api import HubApi + _hub_api = HubApi() + _cookies = _hub_api.check_local_cookies(use_cookies=True) + _oss_config = _hub_api.get_dataset_access_config_session( + cookies=_cookies, + dataset_name=dataset_name, + namespace=namespace, + revision=version) + + self.oss_utilities = OssUtilities( + oss_config=_oss_config, + dataset_name=dataset_name, + namespace=namespace, + revision=version) + + def upload(self, object_name: str, local_file_path: str, + upload_mode: UploadMode) -> str: + object_key = self.oss_utilities.upload( + oss_object_name=object_name, + local_file_path=local_file_path, + indicate_individual_progress=True, + upload_mode=upload_mode) + return object_key + + def upload_dir(self, object_dir_name: str, local_dir_path: str, + num_processes: int, chunksize: int, + filter_hidden_files: bool, upload_mode: UploadMode) -> int: + + def run_upload(args): + self.oss_utilities.upload( + oss_object_name=args[0], + local_file_path=args[1], + indicate_individual_progress=False, + upload_mode=upload_mode) + + files_list = [] + for root, dirs, files in os.walk(local_dir_path): + for file_name in files: + if filter_hidden_files and file_name.startswith('.'): + continue + # Concatenate directory name and relative path into oss object key. e.g., train/001/1_1230.png + object_name = os.path.join( + object_dir_name, + root.replace(local_dir_path, '', 1).strip('/'), file_name) + + local_file_path = os.path.join(root, file_name) + files_list.append((object_name, local_file_path)) + + with ThreadPool(processes=num_processes) as pool: + result = list( + tqdm( + pool.imap(run_upload, files_list, chunksize=chunksize), + total=len(files_list))) + + return len(result) diff --git a/modelscope/outputs/__init__.py b/modelscope/outputs/__init__.py new file mode 100644 index 00000000..47e66714 --- /dev/null +++ b/modelscope/outputs/__init__.py @@ -0,0 +1,2 @@ +from .nlp.model_outputs import * # noqa +from .outputs import TASK_OUTPUTS, ModelOutputBase, OutputKeys diff --git a/modelscope/outputs/nlp/__init__.py b/modelscope/outputs/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/outputs/nlp/model_outputs.py b/modelscope/outputs/nlp/model_outputs.py new file mode 100644 index 00000000..46267007 --- /dev/null +++ b/modelscope/outputs/nlp/model_outputs.py @@ -0,0 +1,590 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +from modelscope.outputs.outputs import ModelOutputBase + +Tensor = Union['torch.Tensor', 'tf.Tensor'] + + +@dataclass +class TextClassificationModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + logits (`Tensor`): The logits output of the model. loss (`Tensor`, + *optional*) The loss of the model, available when training. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. + """ + + logits: Tensor = None + loss: Tensor = None + + +@dataclass +class TokenClassificationModelOutput(ModelOutputBase): + """The output class for token classification models. + logits (`Tensor`): The logits output of the model. + loss (`Tensor`, *optional*) The loss of the model, available when training. + """ + + logits: Tensor = None + loss: Tensor = None + offset_mapping: Tensor = None + + +@dataclass +class FillMaskModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + logits (`Tensor`): The logits output of the model. + loss (`Tensor`, *optional*) The loss of the model, available when training. + input_ids (`Tensor`, *optional*) The input id tensor fed into the model. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. + """ + + logits: Tensor = None + loss: Tensor = None + input_ids: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class TokenClassifierOutput(ModelOutputBase): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + + """ + + loss: Tensor = None + logits: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + offset_mapping: Tensor = None + + +@dataclass +class TokenClassifierWithPredictionsOutput(ModelOutputBase): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, + sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the sentence. + Selected in the range ``[0, sequence_length - 1]``. + predictions: A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + + """ + + loss: Tensor = None + logits: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + offset_mapping: Tensor = None + predictions: Tensor = None + + +@dataclass +class BaseModelOutput(ModelOutputBase): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + last_hidden_state: Tensor = None + hidden_states: Optional[Tuple[Tensor]] = None + attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class BackboneModelOutput(ModelOutputBase): + """The output class for text classification models. + + Args: + last_hidden_state (`Tensor`, *optional*): Sequence of hidden-states at + the output of the last layer of the model. + pooler_output (`Tensor`, *optional*) The tensor of the pooled hidden state. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at + the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Tensor = None + pooler_output: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionBackboneModelOutput(BackboneModelOutput): + """The output class for backbones of attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + past_key_values: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class AttentionTextClassificationModelOutput(TextClassificationModelOutput): + """The output class for backbones of attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionTokenClassificationModelOutput(TokenClassificationModelOutput): + """The output class for backbones of attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the attention softmax, + used to compute the weighted average in the self-attention heads. + """ + attentions: Tensor = None + hidden_states: Tensor = None + + +@dataclass +class AttentionFillMaskModelOutput(FillMaskModelOutput): + """The output class for the fill mask and attention based models. + + Args: + attentions (`tuple(Tensor)`, *optional* Attentions weights after the + attention softmax, used to compute the weighted average in the + self-attention heads. + """ + attentions: Tensor = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutputBase): + """ + Base class for model's outputs that also contains a pooling of the last + hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, + hidden_size)`): + Last layer hidden-state of the first token of the sequence + (classification token) after further processing through the layers + used for the auxiliary pretraining task. E.g. for BERT-family of + models, this returns the classification token after processing + through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction + (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` and `config.add_cross_attention=True` is passed + or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that + can be used (see `past_key_values` input) to speed up sequential + decoding. + """ + + last_hidden_state: Tensor = None + pooler_output: Tensor = None + hidden_states: Tensor = None + past_key_values: Tensor = None + attentions: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutputBase): + """ + Base class for model's outputs that may also contain a past key/values (to + speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + model. + + If `past_key_values` is used only the last hidden-state of the + sequences of shape `(batch_size, 1, hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that + can be used (see `past_key_values` input) to speed up sequential + decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the + optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` and `config.add_cross_attention=True` is passed + or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + """ + + last_hidden_state: Tensor = None + past_key_values: Tensor = None + hidden_states: Tensor = None + attentions: Tensor = None + cross_attentions: Tensor = None + + +@dataclass +class Seq2SeqModelOutput(ModelOutputBase): + """ + Base class for model encoder's outputs that also contains : pre-computed + hidden states that can speed up sequential decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the + decoder of the model. + + If `past_key_values` is used only the last hidden-state of the + sequences of shape `(batch_size, 1, hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and in the cross-attention blocks) that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the + optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the + encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the + optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + """ + + last_hidden_state: Tensor = None + past_key_values: Optional[Tuple[Tuple[Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tensor]] = None + decoder_attentions: Optional[Tuple[Tensor]] = None + cross_attentions: Optional[Tuple[Tensor]] = None + encoder_last_hidden_state: Optional[Tensor] = None + encoder_hidden_states: Optional[Tuple[Tensor]] = None + encoder_attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class Seq2SeqLMOutput(ModelOutputBase): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.vocab_size)`): + Prediction scores of the language modeling head (scores for each + vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, + embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the + self-attention blocks and in the cross-attention blocks) that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the + initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when + `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the + attention softmax, used to compute the weighted average in the + cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, + sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the + encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, + if the model has an embedding layer, + one for the output of each + layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the + initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned + when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used + to compute the weighted average in the self-attention heads. + """ + + loss: Optional[Tensor] = None + logits: Tensor = None + past_key_values: Optional[Tuple[Tuple[Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tensor]] = None + decoder_attentions: Optional[Tuple[Tensor]] = None + cross_attentions: Optional[Tuple[Tensor]] = None + encoder_last_hidden_state: Optional[Tensor] = None + encoder_hidden_states: Optional[Tuple[Tensor]] = None + encoder_attentions: Optional[Tuple[Tensor]] = None + + +@dataclass +class TextGenerationModelOutput(ModelOutputBase): + """The output class for text generation models. + + Args: + logits (`Tensor`): The logits output of the model. loss (`Tensor`, + *optional*) The loss of the model, available when training. + hidden_states (`Tensor`, *optional*) Hidden-states of the model at the + output of each layer plus the optional initial embedding outputs. + """ + + logits: Tensor = None + loss: Tensor = None + + +@dataclass +class TokenGeneratorOutput(ModelOutputBase): + """ + The output class for generate method of text generation models. + + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` + is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` + is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, + sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` + is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. + """ + + sequences: Tensor = None + scores: Optional[Tuple[Tensor]] = None + attentions: Optional[Tuple[Tuple[Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[Tensor]]] = None diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py new file mode 100644 index 00000000..2c6dd85a --- /dev/null +++ b/modelscope/outputs/outputs.py @@ -0,0 +1,851 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import OrderedDict, namedtuple +from dataclasses import dataclass, fields + +from modelscope.utils.constant import Tasks + + +class OutputKeys(object): + LOSS = 'loss' + LOGITS = 'logits' + SCORES = 'scores' + SCORE = 'score' + LABEL = 'label' + LABELS = 'labels' + INPUT_IDS = 'input_ids' + LABEL_POS = 'label_pos' + POSES = 'poses' + CAPTION = 'caption' + BOXES = 'boxes' + KEYPOINTS = 'keypoints' + MASKS = 'masks' + TEXT = 'text' + POLYGONS = 'polygons' + OUTPUT = 'output' + OUTPUT_IMG = 'output_img' + OUTPUT_VIDEO = 'output_video' + OUTPUT_PCM = 'output_pcm' + IMG_EMBEDDING = 'img_embedding' + SPO_LIST = 'spo_list' + TEXT_EMBEDDING = 'text_embedding' + TRANSLATION = 'translation' + RESPONSE = 'response' + PREDICTION = 'prediction' + PREDICTIONS = 'predictions' + PROBABILITIES = 'probabilities' + DIALOG_STATES = 'dialog_states' + VIDEO_EMBEDDING = 'video_embedding' + UUID = 'uuid' + WORD = 'word' + KWS_LIST = 'kws_list' + SQL_STRING = 'sql_string' + SQL_QUERY = 'sql_query' + HISTORY = 'history' + QUERT_RESULT = 'query_result' + TIMESTAMPS = 'timestamps' + SHOT_NUM = 'shot_num' + SCENE_NUM = 'scene_num' + SCENE_META_LIST = 'scene_meta_list' + SHOT_META_LIST = 'shot_meta_list' + + +TASK_OUTPUTS = { + + # ============ vision tasks =================== + + # ocr detection result for single sample + # { + # "polygons": np.array with shape [num_text, 8], each polygon is + # [x1, y1, x2, y2, x3, y3, x4, y4] + # } + Tasks.ocr_detection: [OutputKeys.POLYGONS], + + # ocr recognition result for single sample + # { + # "text": "电子元器件提供BOM配单" + # } + Tasks.ocr_recognition: [OutputKeys.TEXT], + + # face 2d keypoint result for single sample + # { + # "keypoints": [ + # [[x, y]*106], + # [[x, y]*106], + # [[x, y]*106], + # ], + # "poses": [ + # [pitch, roll, yaw], + # [pitch, roll, yaw], + # [pitch, roll, yaw], + # ], + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.face_2d_keypoints: + [OutputKeys.KEYPOINTS, OutputKeys.POSES, OutputKeys.BOXES], + + # face detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "keypoints": [ + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # [x1, y1, x2, y2, x3, y3, x4, y4, x5, y5], + # ], + # } + Tasks.face_detection: + [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + + # card detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "keypoints": [ + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # [x1, y1, x2, y2, x3, y3, x4, y4], + # ], + # } + Tasks.card_detection: + [OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS], + + # facial expression recognition result for single sample + # { + # "scores": [0.9, 0.1, 0.02, 0.02, 0.02, 0.02, 0.02], + # "labels": ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'] + # } + Tasks.facial_expression_recognition: + [OutputKeys.SCORES, OutputKeys.LABELS], + + # face recognition result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # } + Tasks.face_recognition: [OutputKeys.IMG_EMBEDDING], + + # human detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["person", "person", "person", "person"], + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # } + # + Tasks.human_detection: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + + # face generation result for single sample + # { + # "output_img": np.array with shape(h, w, 3) + # } + Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG], + + # image classification result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["dog", "horse", "cow", "cat"], + # } + Tasks.image_classification: [OutputKeys.SCORES, OutputKeys.LABELS], + + # object detection result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["dog", "horse", "cow", "cat"], + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # } + Tasks.image_object_detection: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + + # video object detection result for single sample + # { + + # "scores": [[0.8, 0.25, 0.05, 0.05], [0.9, 0.1, 0.05, 0.05]] + # "labels": [["person", "traffic light", "car", "bus"], + # ["person", "traffic light", "car", "bus"]] + # "boxes": + # [ + # [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # ], + + # } + Tasks.video_object_detection: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], + + # instance segmentation result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05], + # "labels": ["dog", "horse", "cow", "cat"], + # "masks": [ + # np.array # 2D array containing only 0, 1 + # ] + # } + Tasks.image_segmentation: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS], + + # semantic segmentation result for single sample + # { + # "masks": [np.array # 2D array with shape [height, width]] + # } + Tasks.semantic_segmentation: [OutputKeys.MASKS], + + # image matting result for single sample + # { + # "output_img": np.array with shape(h, w, 4) + # for matting or (h, w, 3) for general purpose + # , shape(h, w) for crowd counting + # } + Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG], + + # image editing task result for a single image + # {"output_img": np.array with shape (h, w, 3)} + Tasks.skin_retouching: [OutputKeys.OUTPUT_IMG], + Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG], + Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], + Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], + Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], + Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], + Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], + Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG], + + # image generation task result for a single image + # {"output_img": np.array with shape (h, w, 3)} + Tasks.image_to_image_generation: [OutputKeys.OUTPUT_IMG], + Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], + Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], + Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], + Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG], + + # live category recognition result for single video + # { + # "scores": [0.885272, 0.014790631, 0.014558001] + # "labels": ['女装/女士精品>>棉衣/棉服', '女装/女士精品>>牛仔裤', '女装/女士精品>>裤子>>休闲裤'], + # } + Tasks.live_category: [OutputKeys.SCORES, OutputKeys.LABELS], + + # action recognition result for single video + # { + # "output_label": "abseiling" + # } + Tasks.action_recognition: [OutputKeys.LABELS], + + # human body keypoints detection result for single sample + # { + # "keypoints": [ + # [[x, y]*15], + # [[x, y]*15], + # [[x, y]*15] + # ] + # "scores": [ + # [[score]*15], + # [[score]*15], + # [[score]*15] + # ] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.body_2d_keypoints: + [OutputKeys.KEYPOINTS, OutputKeys.SCORES, OutputKeys.BOXES], + + # 3D human body keypoints detection result for single sample + # { + # "keypoints": [ # 3d pose coordinate in camera coordinate + # [[x, y, z]*17], # joints of per image + # [[x, y, z]*17], + # ... + # ], + # "timestamps": [ # timestamps of all frames + # "00:00:0.230", + # "00:00:0.560", + # "00:00:0.690", + # ], + # "output_video": "path_to_rendered_video" , this is optional + # and is only avaialbe when the "render" option is enabled. + # } + Tasks.body_3d_keypoints: + [OutputKeys.KEYPOINTS, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO], + + # 2D hand keypoints result for single sample + # { + # "keypoints": [ + # [[x, y, score] * 21], + # [[x, y, score] * 21], + # [[x, y, score] * 21], + # ], + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.hand_2d_keypoints: [OutputKeys.KEYPOINTS, OutputKeys.BOXES], + + # video single object tracking result for single video + # { + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "timestamps": ["hh:mm:ss", "hh:mm:ss", "hh:mm:ss"] + # } + Tasks.video_single_object_tracking: + [OutputKeys.BOXES, OutputKeys.TIMESTAMPS], + + # live category recognition result for single video + # { + # "scores": [0.885272, 0.014790631, 0.014558001], + # 'labels': ['修身型棉衣', '高腰牛仔裤', '休闲连体裤'] + # } + Tasks.live_category: [OutputKeys.SCORES, OutputKeys.LABELS], + + # video category recognition result for single video + # { + # "scores": [0.7716429233551025], + # "labels": ['生活>>好物推荐'] + # } + Tasks.video_category: [OutputKeys.SCORES, OutputKeys.LABELS], + + # image embedding result for a single image + # { + # "image_bedding": np.array with shape [D] + # } + Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING], + + # video embedding result for single video + # { + # "video_embedding": np.array with shape [D], + # } + Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], + + # virtual_try_on result for a single sample + # { + # "output_img": np.ndarray with shape [height, width, 3] + # } + Tasks.virtual_try_on: [OutputKeys.OUTPUT_IMG], + # text driven segmentation result for single sample + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.text_driven_segmentation: [OutputKeys.MASKS], + # shop segmentation result for single sample + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.shop_segmentation: [OutputKeys.MASKS], + # movide scene segmentation result for a single video + # { + # "shot_num":15, + # "shot_meta_list": + # [ + # { + # "frame": [start_frame, end_frame], + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # + # } + # ] + # "scene_num":3, + # "scene_meta_list": + # [ + # { + # "shot": [0,1,2], + # "frame": [start_frame, end_frame], + # "timestamps": [start_timestamp, end_timestamp] # ['00:00:01.133', '00:00:02.245'] + # } + # ] + # + # } + Tasks.movie_scene_segmentation: [ + OutputKeys.SHOT_NUM, OutputKeys.SHOT_META_LIST, OutputKeys.SCENE_NUM, + OutputKeys.SCENE_META_LIST + ], + + # human whole body keypoints detection result for single sample + # { + # "keypoints": [ + # [[x, y]*133], + # [[x, y]*133], + # [[x, y]*133] + # ] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.human_wholebody_keypoint: [OutputKeys.KEYPOINTS, OutputKeys.BOXES], + + # video summarization result for a single video + # { + # "output": + # [ + # { + # "frame": [start_frame, end_frame] + # "timestamps": [start_time, end_time] + # }, + # { + # "frame": [start_frame, end_frame] + # "timestamps": [start_time, end_time] + # } + # ] + # } + Tasks.video_summarization: [OutputKeys.OUTPUT], + + # referring video object segmentation result for a single video + # { + # "masks": [np.array # 2D array with shape [height, width]] + # } + Tasks.referring_video_object_segmentation: [OutputKeys.MASKS], + + # ============ nlp tasks =================== + + # text classification result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["happy", "sad", "calm", "angry"], + # } + Tasks.text_classification: [OutputKeys.SCORES, OutputKeys.LABELS], + + # sentence similarity result for single sample + # { + # "scores": 0.9 + # "labels": "1", + # } + Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], + + # nli result for single sample + # { + # "labels": ["happy", "sad", "calm", "angry"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], + + # sentiment classification result for single sample + # { + # 'scores': [0.07183828949928284, 0.9281617403030396], + # 'labels': ['1', '0'] + # } + Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], + + # zero-shot classification result for single sample + # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["happy", "sad", "calm", "angry"], + # } + Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS], + + # relation extraction result for a single sample + # { + # "uuid": "人生信息-1", + # "text": "《父老乡亲》是由是由由中国人民解放军海政文工团创作的军旅歌曲,石顺义作词,王锡仁作曲,范琳琳演唱", + # "spo_list": [{"subject": "石顺义", "predicate": "国籍", "object": "中国"}] + # } + Tasks.relation_extraction: [OutputKeys.SPO_LIST], + + # translation result for a source sentence + # { + # "translation": “北京是中国的首都” + # } + Tasks.translation: [OutputKeys.TRANSLATION], + + # word segmentation result for single sample + # { + # "output": "今天 天气 不错 , 适合 出去 游玩" + # } + Tasks.word_segmentation: [OutputKeys.OUTPUT], + + # TODO @wenmeng.zwm support list of result check + # named entity recognition result for single sample + # { + # "output": [ + # {"type": "LOC", "start": 2, "end": 5, "span": "温岭市"}, + # {"type": "LOC", "start": 5, "end": 8, "span": "新河镇"} + # ] + # } + Tasks.named_entity_recognition: [OutputKeys.OUTPUT], + Tasks.part_of_speech: [OutputKeys.OUTPUT], + + # text_error_correction result for a single sample + # { + # "output": "我想吃苹果" + # } + Tasks.text_error_correction: [OutputKeys.OUTPUT], + Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], + Tasks.text_ranking: [OutputKeys.SCORES], + + # text generation result for single sample + # { + # "text": "this is the text generated by a model." + # } + Tasks.text_generation: [OutputKeys.TEXT], + + # summarization result for single sample + # { + # "text": "this is the text generated by a model." + # } + Tasks.text_summarization: [OutputKeys.TEXT], + + # text generation result for single sample + # { + # "text": "北京" + # } + Tasks.text2text_generation: [OutputKeys.TEXT], + + # fill mask result for single sample + # { + # "text": "this is the text which masks filled by model." + # } + Tasks.fill_mask: [OutputKeys.TEXT], + + # feature extraction result for single sample + # { + # "text_embedding": [[ + # [1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04], + # [6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01], + # [2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05] + # ], + # [ + # [2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05], + # [8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05], + # [3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05] + # ] + # ] + # } + Tasks.feature_extraction: [OutputKeys.TEXT_EMBEDDING], + + # (Deprecated) dialog intent prediction result for single sample + # {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, + # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, + # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, + # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, + # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, + # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, + # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, + # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, + # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, + # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, + # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, + # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, + # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, + # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, + # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, + # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, + # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, + # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, + # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, + # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}} + + # (Deprecated) dialog modeling prediction result for single sample + # {'output' : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']} + + # (Deprecated) dialog state tracking result for single sample + # { + # "output":{ + # "dialog_states": { + # "taxi-leaveAt": "none", + # "taxi-destination": "none", + # "taxi-departure": "none", + # "taxi-arriveBy": "none", + # "restaurant-book_people": "none", + # "restaurant-book_day": "none", + # "restaurant-book_time": "none", + # "restaurant-food": "none", + # "restaurant-pricerange": "none", + # "restaurant-name": "none", + # "restaurant-area": "none", + # "hotel-book_people": "none", + # "hotel-book_day": "none", + # "hotel-book_stay": "none", + # "hotel-name": "none", + # "hotel-area": "none", + # "hotel-parking": "none", + # "hotel-pricerange": "cheap", + # "hotel-stars": "none", + # "hotel-internet": "none", + # "hotel-type": "true", + # "attraction-type": "none", + # "attraction-name": "none", + # "attraction-area": "none", + # "train-book_people": "none", + # "train-leaveAt": "none", + # "train-destination": "none", + # "train-day": "none", + # "train-arriveBy": "none", + # "train-departure": "none" + # } + # } + # } + Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], + + # table-question-answering result for single sample + # { + # "sql": "SELECT shop.Name FROM shop." + # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} + # } + Tasks.table_question_answering: [OutputKeys.OUTPUT], + + # ============ audio tasks =================== + # asr result for single sample + # { "text": "每一天都要快乐喔"} + Tasks.auto_speech_recognition: [OutputKeys.TEXT], + + # audio processed for single file in PCM format + # { + # "output_pcm": pcm encoded audio bytes + # } + Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], + Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM], + Tasks.acoustic_noise_suppression: [OutputKeys.OUTPUT_PCM], + + # text_to_speech result for a single sample + # { + # "output_pcm": {"input_label" : np.ndarray with shape [D]} + # } + Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM], + + # { + # "kws_list": [ + # { + # 'keyword': '', # the keyword spotted + # 'offset': 19.4, # the keyword start time in second + # 'length': 0.68, # the keyword length in second + # 'confidence': 0.85 # the possibility if it is the keyword + # }, + # ... + # ] + # } + Tasks.keyword_spotting: [OutputKeys.KWS_LIST], + + # ============ multi-modal tasks =================== + + # image caption result for single sample + # { + # "caption": "this is an image caption text." + # } + Tasks.image_captioning: [OutputKeys.CAPTION], + Tasks.ocr_recognition: [OutputKeys.TEXT], + + # visual grounding result for single sample + # { + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.visual_grounding: [OutputKeys.BOXES, OutputKeys.SCORES], + + # text_to_image result for a single sample + # { + # "output_img": np.ndarray with shape [height, width, 3] + # } + Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG], + + # text_to_speech result for a single sample + # { + # "output_pcm": {"input_label" : np.ndarray with shape [D]} + # } + Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM], + + # multi-modal embedding result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D] + # } + Tasks.multi_modal_embedding: + [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING], + + # generative multi-modal embedding result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D], + # "caption": "this is an image caption text." + # } + Tasks.generative_multi_modal_embedding: [ + OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.CAPTION + ], + + # multi-modal similarity result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D], + # "similarity": float + # } + Tasks.multi_modal_similarity: [ + OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES + ], + + # VQA result for a sample + # {"text": "this is a text answser. "} + Tasks.visual_question_answering: [OutputKeys.TEXT], + + # auto_speech_recognition result for a single sample + # { + # "text": "每天都要快乐喔" + # } + Tasks.auto_speech_recognition: [OutputKeys.TEXT], + + # { + # "scores": [0.9, 0.1, 0.1], + # "labels": ["entailment", "contradiction", "neutral"] + # } + Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS], + + # { + # 'labels': ['吸烟', '打电话', '吸烟'], + # 'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487], + # 'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]], + # 'timestamps': [1, 3, 5] + # } + Tasks.action_detection: [ + OutputKeys.TIMESTAMPS, + OutputKeys.LABELS, + OutputKeys.SCORES, + OutputKeys.BOXES, + ], + + # { + # 'output': [ + # [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509}, + # {'label': '13421097', 'score': 2.2825044965202324e-08}], + # [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402}, + # {'label': '13421097', 'score': 2.75914817393641e-06}], + # [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402}, + # {'label': '13421097', 'score': 2.75914817393641e-06}]] + # } + Tasks.faq_question_answering: [OutputKeys.OUTPUT], + + # image person reid result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # } + Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING], + + # { + # 'output': ['Done' / 'Decode_Error'] + # } + Tasks.video_inpainting: [OutputKeys.OUTPUT], + + # { + # 'output': ['bixin'] + # } + Tasks.hand_static: [OutputKeys.OUTPUT], + + # { 'labels': [2, 1, 0], + # 'boxes':[[[78, 282, 240, 504], [127, 87, 332, 370], [0, 0, 367, 639]] + # 'scores':[0.8202137351036072, 0.8987470269203186, 0.9679114818572998] + # } + Tasks.face_human_hand_detection: [ + OutputKeys.LABELS, OutputKeys.BOXES, OutputKeys.SCORES + ], + + # { + # {'output': 'Happiness', 'boxes': (203, 104, 663, 564)} + # } + Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES], + + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.product_segmentation: [OutputKeys.MASKS], +} + + +class ModelOutputBase(list): + + def __post_init__(self): + self.reconstruct() + self.post_init = True + + def reconstruct(self): + # Low performance, but low frequency. + self.clear() + for idx, key in enumerate(self.keys()): + self.append(getattr(self, key)) + + def __getitem__(self, item): + if isinstance(item, str): + if hasattr(self, item): + return getattr(self, item) + elif isinstance(item, (int, slice)): + return super().__getitem__(item) + raise IndexError(f'No Index {item} found in the dataclass.') + + def __setitem__(self, key, value): + if isinstance(key, str): + if key in [f.name for f in fields(self)]: + if key not in self.keys(): + super().__setattr__(key, value) + self.reconstruct() + elif id(getattr(self, key)) != id(value): + super().__setattr__(key, value) + super().__setitem__(self.keys().index(key), value) + else: + super().__setattr__(key, value) + elif isinstance(key, int): + super().__setitem__(key, value) + key_name = self.keys()[key] + super().__setattr__(key_name, value) + + def __setattr__(self, key, value): + if getattr(self, 'post_init', False): + return self.__setitem__(key, value) + else: + return super().__setattr__(key, value) + + def keys(self): + return [ + f.name for f in fields(self) if getattr(self, f.name) is not None + ] + + def items(self): + return self.to_dict().items() + + def to_dict(self): + output = OrderedDict() + for key in self.keys(): + output[key] = getattr(self, key) + return output diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py new file mode 100644 index 00000000..13560229 --- /dev/null +++ b/modelscope/pipeline_inputs.py @@ -0,0 +1,244 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np +from PIL import Image + +from modelscope.models.base.base_head import Input +from modelscope.utils.constant import Tasks + + +class InputKeys(object): + IMAGE = 'image' + TEXT = 'text' + VIDEO = 'video' + + +class InputType(object): + IMAGE = 'image' + TEXT = 'text' + AUDIO = 'audio' + VIDEO = 'video' + BOX = 'box' + DICT = 'dict' + LIST = 'list' + INT = 'int' + + +INPUT_TYPE = { + InputType.IMAGE: (str, np.ndarray, Image.Image), + InputType.TEXT: str, + InputType.AUDIO: (str, bytes, np.ndarray), + InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture), + InputType.BOX: (list, np.ndarray), + InputType.DICT: (dict, type(None)), + InputType.LIST: (list, type(None)), + InputType.INT: int, +} + + +def check_input_type(input_type, input): + expected_type = INPUT_TYPE[input_type] + assert isinstance(input, expected_type), \ + f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}' + + +TASK_INPUTS = { + # if task input is single var, value is InputType + # if task input is a tuple, value is tuple of InputType + # if task input is a dict, value is a dict of InputType, where key + # equals the one needed in pipeline input dict + # if task input is a list, value is a set of input format, in which + # each elements corresponds to one input format as described above. + # ============ vision tasks =================== + Tasks.ocr_detection: + InputType.IMAGE, + Tasks.ocr_recognition: + InputType.IMAGE, + Tasks.face_2d_keypoints: + InputType.IMAGE, + Tasks.face_detection: + InputType.IMAGE, + Tasks.facial_expression_recognition: + InputType.IMAGE, + Tasks.face_recognition: + InputType.IMAGE, + Tasks.human_detection: + InputType.IMAGE, + Tasks.face_image_generation: + InputType.INT, + Tasks.image_classification: + InputType.IMAGE, + Tasks.image_object_detection: + InputType.IMAGE, + Tasks.image_segmentation: + InputType.IMAGE, + Tasks.portrait_matting: + InputType.IMAGE, + + # image editing task result for a single image + Tasks.skin_retouching: + InputType.IMAGE, + Tasks.image_super_resolution: + InputType.IMAGE, + Tasks.image_colorization: + InputType.IMAGE, + Tasks.image_color_enhancement: + InputType.IMAGE, + Tasks.image_denoising: + InputType.IMAGE, + Tasks.image_portrait_enhancement: + InputType.IMAGE, + Tasks.crowd_counting: + InputType.IMAGE, + Tasks.image_inpainting: { + 'img': InputType.IMAGE, + 'mask': InputType.IMAGE, + }, + + # image generation task result for a single image + Tasks.image_to_image_generation: + InputType.IMAGE, + Tasks.image_to_image_translation: + InputType.IMAGE, + Tasks.image_style_transfer: { + 'content': InputType.IMAGE, + 'style': InputType.IMAGE, + }, + Tasks.image_portrait_stylization: + InputType.IMAGE, + Tasks.live_category: + InputType.VIDEO, + Tasks.action_recognition: + InputType.VIDEO, + Tasks.body_2d_keypoints: + InputType.IMAGE, + Tasks.body_3d_keypoints: + InputType.VIDEO, + Tasks.hand_2d_keypoints: + InputType.IMAGE, + Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX), + Tasks.video_category: + InputType.VIDEO, + Tasks.product_retrieval_embedding: + InputType.IMAGE, + Tasks.video_embedding: + InputType.VIDEO, + Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), + Tasks.text_driven_segmentation: { + InputKeys.IMAGE: InputType.IMAGE, + InputKeys.TEXT: InputType.TEXT + }, + Tasks.shop_segmentation: + InputType.IMAGE, + Tasks.movie_scene_segmentation: + InputType.VIDEO, + + # ============ nlp tasks =================== + Tasks.text_classification: [ + InputType.TEXT, + (InputType.TEXT, InputType.TEXT), + { + 'text': InputType.TEXT, + 'text2': InputType.TEXT + }, + ], + Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT), + Tasks.nli: (InputType.TEXT, InputType.TEXT), + Tasks.sentiment_classification: + InputType.TEXT, + Tasks.zero_shot_classification: + InputType.TEXT, + Tasks.relation_extraction: + InputType.TEXT, + Tasks.translation: + InputType.TEXT, + Tasks.word_segmentation: [InputType.TEXT, { + 'text': InputType.TEXT, + }], + Tasks.part_of_speech: + InputType.TEXT, + Tasks.named_entity_recognition: + InputType.TEXT, + Tasks.text_error_correction: + InputType.TEXT, + Tasks.sentence_embedding: { + 'source_sentence': InputType.LIST, + 'sentences_to_compare': InputType.LIST, + }, + Tasks.text_ranking: (InputType.TEXT, InputType.TEXT), + Tasks.text_generation: + InputType.TEXT, + Tasks.fill_mask: + InputType.TEXT, + Tasks.task_oriented_conversation: { + 'user_input': InputType.TEXT, + 'history': InputType.DICT, + }, + Tasks.table_question_answering: { + 'question': InputType.TEXT, + 'history_sql': InputType.DICT, + }, + Tasks.faq_question_answering: { + 'query_set': InputType.LIST, + 'support_set': InputType.LIST, + }, + + # ============ audio tasks =================== + Tasks.auto_speech_recognition: + InputType.AUDIO, + Tasks.speech_signal_process: + InputType.AUDIO, + Tasks.acoustic_echo_cancellation: { + 'nearend_mic': InputType.AUDIO, + 'farend_speech': InputType.AUDIO + }, + Tasks.acoustic_noise_suppression: + InputType.AUDIO, + Tasks.text_to_speech: + InputType.TEXT, + Tasks.keyword_spotting: + InputType.AUDIO, + + # ============ multi-modal tasks =================== + Tasks.image_captioning: [InputType.IMAGE, { + 'image': InputType.IMAGE, + }], + Tasks.visual_grounding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.text_to_image_synthesis: { + 'text': InputType.TEXT, + }, + Tasks.multi_modal_embedding: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.generative_multi_modal_embedding: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.multi_modal_similarity: { + 'img': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_question_answering: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT + }, + Tasks.visual_entailment: { + 'image': InputType.IMAGE, + 'text': InputType.TEXT, + 'text2': InputType.TEXT, + }, + Tasks.action_detection: + InputType.VIDEO, + Tasks.image_reid_person: + InputType.IMAGE, + Tasks.video_inpainting: { + 'video_input_path': InputType.TEXT, + 'video_output_path': InputType.TEXT, + 'mask_path': InputType.TEXT, + } +} diff --git a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py new file mode 100644 index 00000000..71fe307b --- /dev/null +++ b/modelscope/pipelines/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule +from . import audio, cv, multi_modal, nlp +from .base import Pipeline +from .builder import pipeline diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py new file mode 100644 index 00000000..b46ca87e --- /dev/null +++ b/modelscope/pipelines/audio/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .ans_pipeline import ANSPipeline + from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline + from .kws_farfield_pipeline import KWSFarfieldPipeline + from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline + from .linear_aec_pipeline import LinearAECPipeline + from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline + +else: + _import_structure = { + 'ans_pipeline': ['ANSPipeline'], + 'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'], + 'kws_farfield_pipeline': ['KWSFarfieldPipeline'], + 'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'], + 'linear_aec_pipeline': ['LinearAECPipeline'], + 'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py new file mode 100644 index 00000000..e55f613e --- /dev/null +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -0,0 +1,120 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +from typing import Any, Dict + +import librosa +import numpy as np +import soundfile as sf +import torch + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.audio.audio_utils import audio_norm +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.acoustic_noise_suppression, + module_name=Pipelines.speech_frcrn_ans_cirm_16k) +class ANSPipeline(Pipeline): + r"""ANS (Acoustic Noise Suppression) Inference Pipeline . + + When invoke the class with pipeline.__call__(), it accept only one parameter: + inputs(str): the path of wav file + """ + SAMPLE_RATE = 16000 + + def __init__(self, model, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.model.eval() + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if isinstance(inputs, bytes): + data1, fs = sf.read(io.BytesIO(inputs)) + elif isinstance(inputs, str): + file_bytes = File.read(inputs) + data1, fs = sf.read(io.BytesIO(file_bytes)) + else: + raise TypeError(f'Unsupported type {type(inputs)}.') + if len(data1.shape) > 1: + data1 = data1[:, 0] + if fs != self.SAMPLE_RATE: + data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) + data1 = audio_norm(data1) + data = data1.astype(np.float32) + inputs = np.reshape(data, [1, data.shape[0]]) + return {'ndarray': inputs, 'nsamples': data.shape[0]} + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + ndarray = inputs['ndarray'] + if isinstance(ndarray, torch.Tensor): + ndarray = ndarray.cpu().numpy() + nsamples = inputs['nsamples'] + decode_do_segement = False + window = 16000 + stride = int(window * 0.75) + print('inputs:{}'.format(ndarray.shape)) + b, t = ndarray.shape # size() + if t > window * 120: + decode_do_segement = True + + if t < window: + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], window - t))], 1) + elif t < window + stride: + padding = window + stride - t + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + else: + if (t - window) % stride != 0: + padding = t - (t - window) // stride * stride + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + print('inputs after padding:{}'.format(ndarray.shape)) + with torch.no_grad(): + ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device) + b, t = ndarray.shape + if decode_do_segement: + outputs = np.zeros(t) + give_up_length = (window - stride) // 2 + current_idx = 0 + while current_idx + window <= t: + print('current_idx: {}'.format(current_idx)) + tmp_input = dict(noisy=ndarray[:, current_idx:current_idx + + window]) + tmp_output = self.model( + tmp_input, )['wav_l2'][0].cpu().numpy() + end_index = current_idx + window - give_up_length + if current_idx == 0: + outputs[current_idx: + end_index] = tmp_output[:-give_up_length] + else: + outputs[current_idx + + give_up_length:end_index] = tmp_output[ + give_up_length:-give_up_length] + current_idx += stride + else: + outputs = self.model( + dict(noisy=ndarray))['wav_l2'][0].cpu().numpy() + outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes() + return {OutputKeys.OUTPUT_PCM: outputs} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if 'output_path' in kwargs.keys(): + sf.write( + kwargs['output_path'], + np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16), + self.SAMPLE_RATE) + return inputs diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py new file mode 100644 index 00000000..6a4864bf --- /dev/null +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -0,0 +1,281 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Sequence, Tuple, Union + +import yaml + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToScp +from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, + load_bytes_from_url) +from modelscope.utils.constant import Frameworks, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['AutomaticSpeechRecognitionPipeline'] + + +@PIPELINES.register_module( + Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) +class AutomaticSpeechRecognitionPipeline(Pipeline): + """ASR Inference Pipeline + """ + + def __init__(self, + model: Union[Model, str] = None, + preprocessor: WavToScp = None, + **kwargs): + """use `model` and `preprocessor` to create an asr pipeline for prediction + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.model_cfg = self.model.forward() + + def __call__(self, + audio_in: Union[str, bytes], + audio_fs: int = None, + recog_type: str = None, + audio_format: str = None) -> Dict[str, Any]: + from easyasr.common import asr_utils + + self.recog_type = recog_type + self.audio_format = audio_format + self.audio_fs = audio_fs + + if isinstance(audio_in, str): + # load pcm data from url if audio_in is url str + self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in) + elif isinstance(audio_in, bytes): + # load pcm data from wav data if audio_in is wave format + self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in) + else: + self.audio_in = audio_in + + # set the sample_rate of audio_in if checking_audio_fs is valid + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + + if recog_type is None or audio_format is None: + self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( + audio_in=self.audio_in, + recog_type=recog_type, + audio_format=audio_format) + + if hasattr(asr_utils, 'sample_rate_checking'): + checking_audio_fs = asr_utils.sample_rate_checking( + self.audio_in, self.audio_format) + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + + if self.preprocessor is None: + self.preprocessor = WavToScp() + + output = self.preprocessor.forward(self.model_cfg, self.recog_type, + self.audio_format, self.audio_in, + self.audio_fs) + output = self.forward(output) + rst = self.postprocess(output) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + + logger.info(f"Decoding with {inputs['audio_format']} files ...") + + data_cmd: Sequence[Tuple[str, str, str]] + if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm': + data_cmd = ['speech', 'sound'] + elif inputs['audio_format'] == 'kaldi_ark': + data_cmd = ['speech', 'kaldi_ark'] + elif inputs['audio_format'] == 'tfrecord': + data_cmd = ['speech', 'tfrecord'] + + if inputs.__contains__('mvn_file'): + data_cmd.append(inputs['mvn_file']) + + # generate asr inference command + cmd = { + 'model_type': inputs['model_type'], + 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available + 'log_level': 'ERROR', + 'audio_in': inputs['audio_lists'], + 'name_and_type': data_cmd, + 'asr_model_file': inputs['am_model_path'], + 'idx_text': '', + 'sampled_ids': 'seq2seq/sampled_ids', + 'sampled_lengths': 'seq2seq/sampled_lengths', + 'lang': 'zh-cn', + 'fs': { + 'audio_fs': inputs['audio_fs'], + 'model_fs': 16000 + } + } + + if self.framework == Frameworks.torch: + config_file = open(inputs['asr_model_config']) + root = yaml.full_load(config_file) + config_file.close() + frontend_conf = None + if 'frontend_conf' in root: + frontend_conf = root['frontend_conf'] + + cmd['beam_size'] = root['beam_size'] + cmd['penalty'] = root['penalty'] + cmd['maxlenratio'] = root['maxlenratio'] + cmd['minlenratio'] = root['minlenratio'] + cmd['ctc_weight'] = root['ctc_weight'] + cmd['lm_weight'] = root['lm_weight'] + cmd['asr_train_config'] = inputs['am_model_config'] + cmd['batch_size'] = inputs['model_config']['batch_size'] + cmd['frontend_conf'] = frontend_conf + if frontend_conf is not None and 'fs' in frontend_conf: + cmd['fs']['model_fs'] = frontend_conf['fs'] + + elif self.framework == Frameworks.tf: + cmd['fs']['model_fs'] = inputs['model_config']['fs'] + cmd['hop_length'] = inputs['model_config']['hop_length'] + cmd['feature_dims'] = inputs['model_config']['feature_dims'] + cmd['predictions_file'] = 'text' + cmd['mvn_file'] = inputs['am_mvn_file'] + cmd['vocab_file'] = inputs['vocab_file'] + cmd['lang'] = inputs['model_lang'] + if 'idx_text' in inputs: + cmd['idx_text'] = inputs['idx_text'] + if 'sampled_ids' in inputs['model_config']: + cmd['sampled_ids'] = inputs['model_config']['sampled_ids'] + if 'sampled_lengths' in inputs['model_config']: + cmd['sampled_lengths'] = inputs['model_config'][ + 'sampled_lengths'] + + else: + raise ValueError('model type is mismatching') + + inputs['asr_result'] = self.run_inference(cmd) + + return inputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the asr results + """ + from easyasr.common import asr_utils + + logger.info('Computing the result of ASR ...') + + rst = {} + + # single wav or pcm task + if inputs['recog_type'] == 'wav': + if 'asr_result' in inputs and len(inputs['asr_result']) > 0: + text = inputs['asr_result'][0]['value'] + if len(text) > 0: + rst[OutputKeys.TEXT] = text + + # run with datasets, and audio format is waveform or kaldi_ark or tfrecord + elif inputs['recog_type'] != 'wav': + inputs['reference_list'] = self.ref_list_tidy(inputs) + + if hasattr(asr_utils, 'set_parameters'): + asr_utils.set_parameters(language=inputs['model_lang']) + inputs['datasets_result'] = asr_utils.compute_wer( + hyp_list=inputs['asr_result'], + ref_list=inputs['reference_list']) + + else: + raise ValueError('recog_type and audio_format are mismatching') + + if 'datasets_result' in inputs: + rst[OutputKeys.TEXT] = inputs['datasets_result'] + + return rst + + def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]: + ref_list = [] + + if inputs['audio_format'] == 'tfrecord': + # should assemble idx + txt + with open(inputs['reference_text'], 'r', encoding='utf-8') as r: + text_lines = r.readlines() + + with open(inputs['idx_text'], 'r', encoding='utf-8') as i: + idx_lines = i.readlines() + + j: int = 0 + while j < min(len(text_lines), len(idx_lines)): + idx_str = idx_lines[j].strip() + text_str = text_lines[j].strip().replace(' ', '') + item = {'key': idx_str, 'value': text_str} + ref_list.append(item) + j += 1 + + else: + # text contain idx + sentence + with open(inputs['reference_text'], 'r', encoding='utf-8') as f: + lines = f.readlines() + + for line in lines: + line_item = line.split(None, 1) + if len(line_item) > 1: + item = { + 'key': line_item[0], + 'value': line_item[1].strip('\n') + } + ref_list.append(item) + + return ref_list + + def run_inference(self, cmd): + asr_result = [] + if self.framework == Frameworks.torch: + from easyasr import asr_inference_paraformer_espnet + + if hasattr(asr_inference_paraformer_espnet, 'set_parameters'): + asr_inference_paraformer_espnet.set_parameters( + sample_rate=cmd['fs']) + asr_inference_paraformer_espnet.set_parameters( + language=cmd['lang']) + + asr_result = asr_inference_paraformer_espnet.asr_inference( + batch_size=cmd['batch_size'], + maxlenratio=cmd['maxlenratio'], + minlenratio=cmd['minlenratio'], + beam_size=cmd['beam_size'], + ngpu=cmd['ngpu'], + ctc_weight=cmd['ctc_weight'], + lm_weight=cmd['lm_weight'], + penalty=cmd['penalty'], + log_level=cmd['log_level'], + name_and_type=cmd['name_and_type'], + audio_lists=cmd['audio_in'], + asr_train_config=cmd['asr_train_config'], + asr_model_file=cmd['asr_model_file'], + frontend_conf=cmd['frontend_conf']) + + elif self.framework == Frameworks.tf: + from easyasr import asr_inference_paraformer_tf + if hasattr(asr_inference_paraformer_tf, 'set_parameters'): + asr_inference_paraformer_tf.set_parameters( + language=cmd['lang']) + else: + # in order to support easyasr-0.0.2 + cmd['fs'] = cmd['fs']['model_fs'] + + asr_result = asr_inference_paraformer_tf.asr_inference( + ngpu=cmd['ngpu'], + name_and_type=cmd['name_and_type'], + audio_lists=cmd['audio_in'], + idx_text_file=cmd['idx_text'], + asr_model_file=cmd['asr_model_file'], + vocab_file=cmd['vocab_file'], + am_mvn_file=cmd['mvn_file'], + predictions_file=cmd['predictions_file'], + fs=cmd['fs'], + hop_length=cmd['hop_length'], + feature_dims=cmd['feature_dims'], + sampled_ids=cmd['sampled_ids'], + sampled_lengths=cmd['sampled_lengths']) + + return asr_result diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py new file mode 100644 index 00000000..e2f618fa --- /dev/null +++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py @@ -0,0 +1,91 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +import wave +from typing import Any, Dict + +import numpy +import soundfile as sf + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.keyword_spotting, + module_name=Pipelines.speech_dfsmn_kws_char_farfield) +class KWSFarfieldPipeline(Pipeline): + r"""A Keyword Spotting Inference Pipeline . + + When invoke the class with pipeline.__call__(), it accept only one parameter: + inputs(str): the path of wav file + """ + SAMPLE_RATE = 16000 + SAMPLE_WIDTH = 2 + INPUT_CHANNELS = 3 + OUTPUT_CHANNELS = 2 + + def __init__(self, model, **kwargs): + """ + use `model` to create a kws far field pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.model = self.model.to(self.device) + self.model.eval() + frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH + self._nframe = self.model.size_in // frame_size + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if isinstance(inputs, bytes): + return dict(input_file=inputs) + elif isinstance(inputs, str): + return dict(input_file=inputs) + elif isinstance(inputs, Dict): + return inputs + else: + raise ValueError(f'Not supported input type: {type(inputs)}') + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + input_file = inputs['input_file'] + if isinstance(input_file, str): + input_file = File.read(input_file) + frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16') + if len(frames.shape) == 1: + frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1) + + kws_list = [] + if 'output_file' in inputs: + with wave.open(inputs['output_file'], 'wb') as fout: + fout.setframerate(self.SAMPLE_RATE) + fout.setnchannels(self.OUTPUT_CHANNELS) + fout.setsampwidth(self.SAMPLE_WIDTH) + self._process(frames, kws_list, fout) + else: + self._process(frames, kws_list) + return {OutputKeys.KWS_LIST: kws_list} + + def _process(self, + frames: numpy.ndarray, + kws_list, + fout: wave.Wave_write = None): + for start_index in range(0, frames.shape[0], self._nframe): + end_index = start_index + self._nframe + if end_index > frames.shape[0]: + end_index = frames.shape[0] + data = frames[start_index:end_index, :].tobytes() + result = self.model.forward_decode(data) + if fout: + fout.writeframes(result['pcm']) + if 'kws' in result: + result['kws']['offset'] += start_index / self.SAMPLE_RATE + kws_list.append(result['kws']) + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py new file mode 100644 index 00000000..db6fc65d --- /dev/null +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -0,0 +1,196 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, List, Union + +import json + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToLists +from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, + load_bytes_from_url) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['KeyWordSpottingKwsbpPipeline'] + + +@PIPELINES.register_module( + Tasks.keyword_spotting, module_name=Pipelines.kws_kwsbp) +class KeyWordSpottingKwsbpPipeline(Pipeline): + """KWS Pipeline - key word spotting decoding + """ + + def __init__(self, + model: Union[Model, str] = None, + preprocessor: WavToLists = None, + **kwargs): + """use `model` and `preprocessor` to create a kws pipeline for prediction + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def __call__(self, audio_in: Union[List[str], str, bytes], + **kwargs) -> Dict[str, Any]: + if 'keywords' in kwargs.keys(): + self.keywords = kwargs['keywords'] + if isinstance(self.keywords, str): + word_list = [] + word = {} + word['keyword'] = self.keywords + word_list.append(word) + self.keywords = word_list + else: + self.keywords = None + + if self.preprocessor is None: + self.preprocessor = WavToLists() + + if isinstance(audio_in, str): + # load pcm data from url if audio_in is url str + audio_in, audio_fs = load_bytes_from_url(audio_in) + elif isinstance(audio_in, bytes): + # load pcm data from wav data if audio_in is wave format + audio_in, audio_fs = extract_pcm_from_wav(audio_in) + + output = self.preprocessor.forward(self.model.forward(), audio_in) + output = self.forward(output) + rst = self.postprocess(output) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + + logger.info(f"Decoding with {inputs['kws_type']} mode ...") + + # will generate kws result + out = self.run_with_kwsbp(inputs) + + return out + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the kws results + + Args: + inputs['pos_kws_list'] or inputs['neg_kws_list']: + result_dict format example: + [{ + 'confidence': 0.9903678297996521, + 'filename': 'data/test/audios/kws_xiaoyunxiaoyun.wav', + 'keyword': '小云小云', + 'offset': 5.760000228881836, # second + 'rtf_time': 66, # millisecond + 'threshold': 0, + 'wav_time': 9.1329375 # second + }] + """ + + import kws_util.common + neg_kws_list = None + pos_kws_list = None + if 'pos_kws_list' in inputs: + pos_kws_list = inputs['pos_kws_list'] + if 'neg_kws_list' in inputs: + neg_kws_list = inputs['neg_kws_list'] + + rst_dict = kws_util.common.parsing_kws_result( + kws_type=inputs['kws_type'], + pos_list=pos_kws_list, + neg_list=neg_kws_list) + + if 'kws_list' not in rst_dict: + rst_dict['kws_list'] = [] + + return rst_dict + + def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + import kwsbp + import kws_util.common + kws_inference = kwsbp.KwsbpEngine() + + cmd = { + 'sys_dir': + inputs['model_workspace'], + 'cfg_file': + inputs['cfg_file_path'], + 'sample_rate': + inputs['sample_rate'], + 'keyword_custom': + '', + 'pcm_data': + None, + 'pcm_data_len': + 0, + 'list_flag': + True, + # setting customized keywords + 'customized_keywords': + kws_util.common.generate_customized_keywords(self.keywords) + } + + if inputs['kws_type'] == 'pcm': + cmd['pcm_data'] = inputs['pos_data'] + cmd['pcm_data_len'] = len(inputs['pos_data']) + cmd['list_flag'] = False + + if inputs['kws_type'] == 'roc': + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], 'keywords_roc.json') + + if inputs['kws_type'] in ['wav', 'pcm', 'pos_testsets', 'roc']: + cmd['wave_scp'] = inputs['pos_wav_list'] + cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] + cmd['num_thread'] = inputs['pos_num_thread'] + + if hasattr(kws_inference, 'inference_new'): + # run and get inference result + result = kws_inference.inference_new( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['pcm_data'], + cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'], + cmd['list_flag']) + else: + # in order to support kwsbp-0.0.1 + result = kws_inference.inference( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['sample_rate'], + cmd['num_thread']) + + pos_result = json.loads(result) + inputs['pos_kws_list'] = pos_result['kws_list'] + + if inputs['kws_type'] in ['neg_testsets', 'roc']: + cmd['wave_scp'] = inputs['neg_wav_list'] + cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] + cmd['num_thread'] = inputs['neg_num_thread'] + + if hasattr(kws_inference, 'inference_new'): + # run and get inference result + result = kws_inference.inference_new( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['pcm_data'], + cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'], + cmd['list_flag']) + else: + # in order to support kwsbp-0.0.1 + result = kws_inference.inference( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['sample_rate'], + cmd['num_thread']) + + neg_result = json.loads(result) + inputs['neg_kws_list'] = neg_result['kws_list'] + + return inputs diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py new file mode 100644 index 00000000..e1e75ddb --- /dev/null +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -0,0 +1,171 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import importlib +import os +from typing import Any, Dict + +import numpy as np +import scipy.io.wavfile as wav +import torch +import yaml + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LinearAECAndFbank +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +FEATURE_MVN = 'feature.DEY.mvn.txt' + +CONFIG_YAML = 'dey_mini.yaml' + + +def initialize_config(module_cfg): + r"""According to config items, load specific module dynamically with params. + 1. Load the module corresponding to the "module" param. + 2. Call function (or instantiate class) corresponding to the "main" param. + 3. Send the param (in "args") into the function (or class) when calling ( or instantiating). + + Args: + module_cfg (dict): config items, eg: + { + "module": "models.model", + "main": "Model", + "args": {...} + } + + Returns: + the module loaded. + """ + module = importlib.import_module(module_cfg['module']) + return getattr(module, module_cfg['main'])(**module_cfg['args']) + + +@PIPELINES.register_module( + Tasks.acoustic_echo_cancellation, + module_name=Pipelines.speech_dfsmn_aec_psm_16k) +class LinearAECPipeline(Pipeline): + r"""AEC Inference Pipeline only support 16000 sample rate. + + When invoke the class with pipeline.__call__(), you should provide two params: + Dict[str, Any] + the path of wav files, eg:{ + "nearend_mic": "/your/data/near_end_mic_audio.wav", + "farend_speech": "/your/data/far_end_speech_audio.wav"} + output_path (str, optional): "/your/output/audio_after_aec.wav" + the file path to write generate audio. + """ + + def __init__(self, model, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + self.use_cuda = torch.cuda.is_available() + with open( + os.path.join(self.model, CONFIG_YAML), encoding='utf-8') as f: + self.config = yaml.full_load(f.read()) + self.config['io']['mvn'] = os.path.join(self.model, FEATURE_MVN) + self._init_model() + self.preprocessor = LinearAECAndFbank(self.config['io']) + + n_fft = self.config['loss']['args']['n_fft'] + hop_length = self.config['loss']['args']['hop_length'] + winlen = n_fft + window = torch.hamming_window(winlen, periodic=False) + + def stft(x): + return torch.stft( + x, + n_fft, + hop_length, + winlen, + center=False, + window=window.to(x.device), + return_complex=False) + + def istft(x, slen): + return torch.istft( + x, + n_fft, + hop_length, + winlen, + window=window.to(x.device), + center=False, + length=slen) + + self.stft = stft + self.istft = istft + + def _init_model(self): + checkpoint = torch.load( + os.path.join(self.model, ModelFile.TORCH_MODEL_BIN_FILE), + map_location='cpu') + self.model = initialize_config(self.config['nnet']) + if self.use_cuda: + self.model = self.model.cuda() + self.model.load_state_dict(checkpoint) + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + r"""The AEC process. + + Args: + inputs: dict={'feature': Tensor, 'base': Tensor} + 'feature' feature of input audio. + 'base' the base audio to mask. + + Returns: + dict: + { + 'output_pcm': generated audio array + } + """ + output_data = self._process(inputs['feature'], inputs['base']) + output_data = output_data.astype(np.int16).tobytes() + return {OutputKeys.OUTPUT_PCM: output_data} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + r"""The post process. Will save audio to file, if the output_path is given. + + Args: + inputs: dict: + { + 'output_pcm': generated audio array + } + kwargs: accept 'output_path' which is the path to write generated audio + + Returns: + dict: + { + 'output_pcm': generated audio array + } + """ + if 'output_path' in kwargs.keys(): + wav.write( + kwargs['output_path'], self.preprocessor.SAMPLE_RATE, + np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16)) + return inputs + + def _process(self, fbanks, mixture): + if self.use_cuda: + fbanks = fbanks.cuda() + mixture = mixture.cuda() + if self.model.vad: + with torch.no_grad(): + masks, vad = self.model(fbanks.unsqueeze(0)) + masks = masks.permute([2, 1, 0]) + else: + with torch.no_grad(): + masks = self.model(fbanks.unsqueeze(0)) + masks = masks.permute([2, 1, 0]) + spectrum = self.stft(mixture) + masked_spec = spectrum * masks + masked_sig = self.istft(masked_spec, len(mixture)).cpu().numpy() + return masked_sig diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py new file mode 100644 index 00000000..2063da68 --- /dev/null +++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, List + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.audio.tts import SambertHifigan +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, InputModel, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Fields, Tasks + +__all__ = ['TextToSpeechSambertHifiganPipeline'] + + +@PIPELINES.register_module( + Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_tts) +class TextToSpeechSambertHifiganPipeline(Pipeline): + + def __init__(self, model: InputModel, **kwargs): + """use `model` to create a text-to-speech pipeline for prediction + + Args: + model (SambertHifigan or str): a model instance or valid offical model id + """ + super().__init__(model=model, **kwargs) + + def forward(self, input: str, **forward_params) -> Dict[str, np.ndarray]: + """synthesis text from inputs with pipeline + Args: + input (str): text to synthesis + forward_params: valid param is 'voice' used to setting speaker vocie + Returns: + Dict[str, np.ndarray]: {OutputKeys.OUTPUT_PCM : np.ndarray(16bit pcm data)} + """ + output_wav = self.model.forward(input, forward_params.get('voice')) + return {OutputKeys.OUTPUT_PCM: output_wav} + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, Any]: + return inputs + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + return inputs + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py new file mode 100644 index 00000000..7a8bfd14 --- /dev/null +++ b/modelscope/pipelines/base.py @@ -0,0 +1,460 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +from abc import ABC, abstractmethod +from functools import partial +from multiprocessing import Pool +from threading import Lock +from typing import Any, Dict, Generator, List, Mapping, Union + +import numpy as np + +from modelscope.hub.utils.utils import create_library_statistics +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.outputs import TASK_OUTPUTS +from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type +from modelscope.preprocessors import Preprocessor +from modelscope.utils.config import Config +from modelscope.utils.constant import Frameworks, ModelFile +from modelscope.utils.device import (create_device, device_placement, + verify_device) +from modelscope.utils.hub import read_config, snapshot_download +from modelscope.utils.import_utils import is_tf_available, is_torch_available +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import _find_free_port, _is_free_port +from .util import is_model, is_official_hub_path + +if is_torch_available(): + import torch + +if is_tf_available(): + pass + +Tensor = Union['torch.Tensor', 'tf.Tensor'] +Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray'] +InputModel = Union[str, Model, 'torch.nn.Module'] + +logger = get_logger() + + +class Pipeline(ABC): + + def initiate_single_model(self, model): + if isinstance(model, str): + logger.info(f'initiate model from {model}') + if isinstance(model, str) and is_official_hub_path(model): + logger.info(f'initiate model from location {model}.') + # expecting model has been prefetched to local cache beforehand + return Model.from_pretrained( + model, model_prefetched=True, + device=self.device_name) if is_model(model) else model + else: + return model + + def initiate_multiple_models(self, input_models: List[InputModel]): + models = [] + for model in input_models: + models.append(self.initiate_single_model(model)) + return models + + def __init__(self, + config_file: str = None, + model: Union[InputModel, List[InputModel]] = None, + preprocessor: Union[Preprocessor, List[Preprocessor]] = None, + device: str = 'gpu', + auto_collate=True, + **kwargs): + """ Base class for pipeline. + + If config_file is provided, model and preprocessor will be + instantiated from corresponding config. Otherwise, model + and preprocessor will be constructed separately. + + Args: + config_file(str, optional): Filepath to configuration file. + model: (list of) Model name or model object + preprocessor: (list of) Preprocessor object + device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X + auto_collate (bool): automatically to convert data to tensor or not. + """ + if config_file is not None: + self.cfg = Config.from_file(config_file) + + verify_device(device) + self.device_name = device + + if not isinstance(model, List): + self.model = self.initiate_single_model(model) + self.models = [self.model] + else: + self.model = None + self.models = self.initiate_multiple_models(model) + + self.has_multiple_models = len(self.models) > 1 + self.preprocessor = preprocessor + + if self.model or (self.has_multiple_models and self.models[0]): + self.framework = self._get_framework() + else: + self.framework = None + + if self.framework == Frameworks.torch: + self.device = create_device(self.device_name) + self._model_prepare = False + self._model_prepare_lock = Lock() + self._auto_collate = auto_collate + + def prepare_model(self): + """ Place model on certain device for pytorch models before first inference + """ + self._model_prepare_lock.acquire(timeout=600) + + def _prepare_single(model): + if isinstance(model, torch.nn.Module): + model.to(self.device) + model.eval() + elif hasattr(model, 'model') and isinstance( + model.model, torch.nn.Module): + model.model.to(self.device) + model.model.eval() + + if not self._model_prepare: + # prepare model for pytorch + if self.framework == Frameworks.torch: + if self.has_multiple_models: + for m in self.models: + _prepare_single(m) + else: + _prepare_single(self.model) + self._model_prepare = True + self._model_prepare_lock.release() + + def _get_framework(self) -> str: + frameworks = [] + for m in self.models: + if isinstance(m, str): + model_dir = m + else: + model_dir = m.model_dir + cfg_file = osp.join(model_dir, ModelFile.CONFIGURATION) + cfg = Config.from_file(cfg_file) + frameworks.append(cfg.framework) + if not all(x == frameworks[0] for x in frameworks): + raise ValueError( + f'got multiple models, but they are in different frameworks {frameworks}' + ) + + return frameworks[0] + + def __call__(self, input: Union[Input, List[Input]], *args, + **kwargs) -> Union[Dict[str, Any], Generator]: + # model provider should leave it as it is + # modelscope library developer will handle this function + for single_model in self.models: + if hasattr(single_model, 'name'): + create_library_statistics('pipeline', single_model.name, None) + # place model to cpu or gpu + if (self.model or (self.has_multiple_models and self.models[0])): + if not self._model_prepare: + self.prepare_model() + + # simple showcase, need to support iterator type for both tensorflow and pytorch + # input_dict = self._handle_input(input) + + # sanitize the parameters + preprocess_params, forward_params, postprocess_params = self._sanitize_parameters( + **kwargs) + kwargs['preprocess_params'] = preprocess_params + kwargs['forward_params'] = forward_params + kwargs['postprocess_params'] = postprocess_params + + if isinstance(input, list): + output = [] + for ele in input: + output.append(self._process_single(ele, *args, **kwargs)) + + elif isinstance(input, MsDataset): + return self._process_iterator(input, *args, **kwargs) + + else: + output = self._process_single(input, *args, **kwargs) + return output + + def _sanitize_parameters(self, **pipeline_parameters): + """ + this method should sanitize the keyword args to preprocessor params, + forward params and postprocess params on '__call__' or '_process_single' method + considered to be a normal classmethod with default implementation / output + + Default Returns: + Dict[str, str]: preprocess_params = {} + Dict[str, str]: forward_params = {} + Dict[str, str]: postprocess_params = pipeline_parameters + """ + return {}, {}, pipeline_parameters + + def _process_iterator(self, input: Input, *args, **kwargs): + for ele in input: + yield self._process_single(ele, *args, **kwargs) + + def _collate_fn(self, data): + return collate_fn(data, self.device) + + def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: + preprocess_params = kwargs.get('preprocess_params', {}) + forward_params = kwargs.get('forward_params', {}) + postprocess_params = kwargs.get('postprocess_params', {}) + self._check_input(input) + out = self.preprocess(input, **preprocess_params) + with device_placement(self.framework, self.device_name): + if self.framework == Frameworks.torch: + with torch.no_grad(): + if self._auto_collate: + out = self._collate_fn(out) + out = self.forward(out, **forward_params) + else: + out = self.forward(out, **forward_params) + + out = self.postprocess(out, **postprocess_params) + self._check_output(out) + return out + + def _check_input(self, input): + task_name = self.group_key + if task_name in TASK_INPUTS: + input_type = TASK_INPUTS[task_name] + + # if multiple input formats are defined, we first + # found the one that match input data and check + if isinstance(input_type, list): + matched_type = None + for t in input_type: + if isinstance(input, (dict, tuple)): + if type(t) == type(input): + matched_type = t + break + elif isinstance(t, str): + matched_type = t + break + if matched_type is None: + err_msg = 'input data format for current pipeline should be one of following: \n' + for t in input_type: + err_msg += f'{t}\n' + raise ValueError(err_msg) + else: + input_type = matched_type + + if isinstance(input_type, str): + check_input_type(input_type, input) + elif isinstance(input_type, tuple): + for t, input_ele in zip(input_type, input): + check_input_type(t, input_ele) + elif isinstance(input_type, dict): + for k in input_type.keys(): + # allow single input for multi-modal models + if k in input: + check_input_type(input_type[k], input[k]) + else: + raise ValueError(f'invalid input_type definition {input_type}') + else: + logger.warning(f'task {task_name} input definition is missing') + + def _check_output(self, input): + # this attribute is dynamically attached by registry + # when cls is registered in registry using task name + task_name = self.group_key + if task_name not in TASK_OUTPUTS: + logger.warning(f'task {task_name} output keys are missing') + return + output_keys = TASK_OUTPUTS[task_name] + missing_keys = [] + for k in output_keys: + if k not in input: + missing_keys.append(k) + if len(missing_keys) > 0: + raise ValueError(f'expected output keys are {output_keys}, ' + f'those {missing_keys} are missing') + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + """ Provide default implementation based on preprocess_cfg and user can reimplement it + """ + assert self.preprocessor is not None, 'preprocess method should be implemented' + assert not isinstance(self.preprocessor, List),\ + 'default implementation does not support using multiple preprocessors.' + return self.preprocessor(inputs, **preprocess_params) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + """ Provide default implementation using self.model and user can reimplement it + """ + assert self.model is not None, 'forward method should be implemented' + assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' + return self.model(inputs, **forward_params) + + @abstractmethod + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ If current pipeline support model reuse, common postprocess + code should be write here. + + Args: + inputs: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. + """ + raise NotImplementedError('postprocess') + + +class DistributedPipeline(Pipeline): + """This pipeline is used to load multi gpu models. + + What will this class do: + 1. Read the global config from the configuration.json + 2. Set the multiprocessing method to spawn + 3. Open a multiprocessing pool of the world_size to instantiate model pieces. + 4. Set the master port and ip + 5. Call _instantiate_one to instantiate one model piece + This method should be implemented by the derived class. + 6. After the forward method is called, do preprocess in main process + and call _forward_one to collect results, and do + post process in main process. + + NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and + store the model handler in the class field. + """ + + def __init__(self, + model: str = None, + preprocessor: Union[Preprocessor, List[Preprocessor]] = None, + auto_collate=True, + **kwargs): + self.preprocessor = preprocessor + self._model_prepare = False + self._model_prepare_lock = Lock() + self._auto_collate = auto_collate + + if os.path.exists(model): + self.model_dir = model + else: + self.model_dir = snapshot_download(model) + self.cfg = read_config(self.model_dir) + self.world_size = self.cfg.model.world_size + self.model_pool = None + self.device_name = 'cpu' + self.device = create_device(self.device_name) + self.has_multiple_models = False + self.framework = self.cfg.framework + if torch.multiprocessing.get_start_method(allow_none=True) is None: + torch.multiprocessing.set_start_method('spawn') + + ranks = list(range(self.world_size)) + self.model_pool = Pool(self.world_size) + master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[ + 'master_ip'] + master_port = '29500' if 'master_port' not in kwargs else kwargs[ + 'master_port'] + if not _is_free_port(int(master_port)): + master_port = str(_find_free_port()) + self.model_pool.map( + partial( + self.__class__._instantiate_one, + model_dir=self.model_dir, + master_ip=master_ip, + master_port=master_port, + **self.cfg.model, + **kwargs), ranks) + self.models = [] + + def __del__(self): + if hasattr(self, 'model_pool') and self.model_pool is not None: + self.model_pool.terminate() + + def __getstate__(self): + self_dict = self.__dict__.copy() + del self_dict['model_pool'] + del self_dict['preprocessor'] + del self_dict['_model_prepare_lock'] + return self_dict + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + """Instantiate one model piece. + + Args: + rank: The model rank. + model_dir: The model_dir in the node. + kwargs: Any extra args. + + Returns: + None. The model handler should be kept in the class field. + """ + pass + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + inputs = { + 'inputs': inputs, + 'forward_params': forward_params, + } + res = self.model_pool.map(self.__class__._forward_one, + [inputs] * self.world_size) + return res[0] + + @classmethod + def _forward_one(cls, inputs): + """Forward the inputs to one model piece. + + Use the model handler kept in the class field to forward. + + Args: + inputs: The inputs after the preprocessing. + + Returns: + The forward results. + """ + pass + + +def collate_fn(data, device): + """Prepare the input just before the forward function. + This method will move the tensors to the right device. + Usually this method does not need to be overridden. + + Args: + data: The data out of the dataloader. + device: The device to move data to. + + Returns: The processed data. + + """ + from torch.utils.data.dataloader import default_collate + from modelscope.preprocessors.nlp import InputFeatures + if isinstance(data, dict) or isinstance(data, Mapping): + return type(data)({k: collate_fn(v, device) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + if 0 == len(data): + return torch.Tensor([]) + if isinstance(data[0], (int, float)): + return default_collate(data).to(device) + else: + return type(data)(collate_fn(v, device) for v in data) + elif isinstance(data, np.ndarray): + if data.dtype.type is np.str_: + return data + else: + return collate_fn(torch.from_numpy(data), device) + elif isinstance(data, torch.Tensor): + return data.to(device) + elif isinstance(data, (bytes, str, int, float, bool, type(None))): + return data + elif isinstance(data, InputFeatures): + return data + else: + import mmcv + if isinstance(data, mmcv.parallel.data_container.DataContainer): + return data + else: + raise ValueError(f'Unsupported data type {type(data)}') diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py new file mode 100644 index 00000000..70f8f11c --- /dev/null +++ b/modelscope/pipelines/builder.py @@ -0,0 +1,380 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import List, Optional, Union + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.utils.config import ConfigDict, check_config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Tasks +from modelscope.utils.hub import read_config +from modelscope.utils.registry import Registry, build_from_cfg +from .base import Pipeline +from .util import is_official_hub_path + +PIPELINES = Registry('pipelines') + +DEFAULT_MODEL_FOR_PIPELINE = { + # TaskName: (pipeline_module_name, model_repo) + Tasks.sentence_embedding: + (Pipelines.sentence_embedding, + 'damo/nlp_corom_sentence-embedding_english-base'), + Tasks.text_ranking: (Pipelines.text_ranking, + 'damo/nlp_corom_passage-ranking_english-base'), + Tasks.word_segmentation: + (Pipelines.word_segmentation, + 'damo/nlp_structbert_word-segmentation_chinese-base'), + Tasks.part_of_speech: (Pipelines.part_of_speech, + 'damo/nlp_structbert_part-of-speech_chinese-base'), + Tasks.token_classification: + (Pipelines.part_of_speech, + 'damo/nlp_structbert_part-of-speech_chinese-base'), + Tasks.named_entity_recognition: + (Pipelines.named_entity_recognition, + 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), + Tasks.relation_extraction: + (Pipelines.relation_extraction, + 'damo/nlp_bert_relation-extraction_chinese-base'), + Tasks.information_extraction: + (Pipelines.relation_extraction, + 'damo/nlp_bert_relation-extraction_chinese-base'), + Tasks.sentence_similarity: + (Pipelines.sentence_similarity, + 'damo/nlp_structbert_sentence-similarity_chinese-base'), + Tasks.translation: (Pipelines.csanmt_translation, + 'damo/nlp_csanmt_translation_zh2en'), + Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), + Tasks.sentiment_classification: + (Pipelines.sentiment_classification, + 'damo/nlp_structbert_sentiment-classification_chinese-base' + ), # TODO: revise back after passing the pr + Tasks.portrait_matting: (Pipelines.portrait_matting, + 'damo/cv_unet_image-matting'), + Tasks.human_detection: (Pipelines.human_detection, + 'damo/cv_resnet18_human-detection'), + Tasks.image_object_detection: (Pipelines.object_detection, + 'damo/cv_vit_object-detection_coco'), + Tasks.image_denoising: (Pipelines.image_denoise, + 'damo/cv_nafnet_image-denoise_sidd'), + Tasks.text_classification: + (Pipelines.sentiment_classification, + 'damo/nlp_structbert_sentiment-classification_chinese-base'), + Tasks.text_generation: (Pipelines.text_generation, + 'damo/nlp_palm2.0_text-generation_chinese-base'), + Tasks.zero_shot_classification: + (Pipelines.zero_shot_classification, + 'damo/nlp_structbert_zero-shot-classification_chinese-base'), + Tasks.task_oriented_conversation: (Pipelines.dialog_modeling, + 'damo/nlp_space_dialog-modeling'), + Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, + 'damo/nlp_space_dialog-state-tracking'), + Tasks.table_question_answering: + (Pipelines.table_question_answering_pipeline, + 'damo/nlp-convai-text2sql-pretrain-cn'), + Tasks.text_error_correction: + (Pipelines.text_error_correction, + 'damo/nlp_bart_text-error-correction_chinese'), + Tasks.image_captioning: (Pipelines.image_captioning, + 'damo/ofa_image-caption_coco_large_en'), + Tasks.image_portrait_stylization: + (Pipelines.person_image_cartoon, + 'damo/cv_unet_person-image-cartoon_compound-models'), + Tasks.ocr_detection: (Pipelines.ocr_detection, + 'damo/cv_resnet18_ocr-detection-line-level_damo'), + Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), + Tasks.feature_extraction: (Pipelines.feature_extraction, + 'damo/pert_feature-extraction_base-test'), + Tasks.action_recognition: (Pipelines.action_recognition, + 'damo/cv_TAdaConv_action-recognition'), + Tasks.action_detection: (Pipelines.action_detection, + 'damo/cv_ResNetC3D_action-detection_detection2d'), + Tasks.live_category: (Pipelines.live_category, + 'damo/cv_resnet50_live-category'), + Tasks.video_category: (Pipelines.video_category, + 'damo/cv_resnet50_video-category'), + Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, + 'damo/multi-modal_clip-vit-base-patch16_zh'), + Tasks.generative_multi_modal_embedding: + (Pipelines.generative_multi_modal_embedding, + 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' + ), + Tasks.multi_modal_similarity: + (Pipelines.multi_modal_similarity, + 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'), + Tasks.visual_question_answering: + (Pipelines.visual_question_answering, + 'damo/mplug_visual-question-answering_coco_large_en'), + Tasks.video_embedding: (Pipelines.cmdssl_video_embedding, + 'damo/cv_r2p1d_video_embedding'), + Tasks.text_to_image_synthesis: + (Pipelines.text_to_image_synthesis, + 'damo/cv_diffusion_text-to-image-synthesis_tiny'), + Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, + 'damo/cv_hrnetv2w32_body-2d-keypoints_image'), + Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints, + 'damo/cv_canonical_body-3d-keypoints_video'), + Tasks.hand_2d_keypoints: + (Pipelines.hand_2d_keypoints, + 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'), + Tasks.face_detection: (Pipelines.face_detection, + 'damo/cv_resnet_facedetection_scrfd10gkps'), + Tasks.card_detection: (Pipelines.card_detection, + 'damo/cv_resnet_carddetection_scrfd34gkps'), + Tasks.face_detection: + (Pipelines.face_detection, + 'damo/cv_resnet101_face-detection_cvpr22papermogface'), + Tasks.face_recognition: (Pipelines.face_recognition, + 'damo/cv_ir101_facerecognition_cfglint'), + Tasks.facial_expression_recognition: + (Pipelines.facial_expression_recognition, + 'damo/cv_vgg19_facial-expression-recognition_fer'), + Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints, + 'damo/cv_mobilenet_face-2d-keypoints_alignment'), + Tasks.video_multi_modal_embedding: + (Pipelines.video_multi_modal_embedding, + 'damo/multi_modal_clip_vtretrival_msrvtt_53'), + Tasks.image_color_enhancement: + (Pipelines.image_color_enhance, + 'damo/cv_csrnet_image-color-enhance-models'), + Tasks.virtual_try_on: (Pipelines.virtual_try_on, + 'damo/cv_daflow_virtual-try-on_base'), + Tasks.image_colorization: (Pipelines.image_colorization, + 'damo/cv_unet_image-colorization'), + Tasks.image_segmentation: + (Pipelines.image_instance_segmentation, + 'damo/cv_swin-b_image-instance-segmentation_coco'), + Tasks.image_style_transfer: (Pipelines.image_style_transfer, + 'damo/cv_aams_style-transfer_damo'), + Tasks.face_image_generation: (Pipelines.face_image_generation, + 'damo/cv_gan_face-image-generation'), + Tasks.image_super_resolution: (Pipelines.image_super_resolution, + 'damo/cv_rrdb_image-super-resolution'), + Tasks.image_portrait_enhancement: + (Pipelines.image_portrait_enhancement, + 'damo/cv_gpen_image-portrait-enhancement'), + Tasks.product_retrieval_embedding: + (Pipelines.product_retrieval_embedding, + 'damo/cv_resnet50_product-bag-embedding-models'), + Tasks.image_to_image_generation: + (Pipelines.image_to_image_generation, + 'damo/cv_latent_diffusion_image2image_generate'), + Tasks.image_classification: + (Pipelines.daily_image_classification, + 'damo/cv_vit-base_image-classification_Dailylife-labels'), + Tasks.image_object_detection: + (Pipelines.image_object_detection_auto, + 'damo/cv_yolox_image-object-detection-auto'), + Tasks.ocr_recognition: + (Pipelines.ocr_recognition, + 'damo/cv_convnextTiny_ocr-recognition-general_damo'), + Tasks.skin_retouching: (Pipelines.skin_retouching, + 'damo/cv_unet_skin-retouching'), + Tasks.faq_question_answering: + (Pipelines.faq_question_answering, + 'damo/nlp_structbert_faq-question-answering_chinese-base'), + Tasks.crowd_counting: (Pipelines.crowd_counting, + 'damo/cv_hrnet_crowd-counting_dcanet'), + Tasks.video_single_object_tracking: + (Pipelines.video_single_object_tracking, + 'damo/cv_vitb_video-single-object-tracking_ostrack'), + Tasks.image_reid_person: (Pipelines.image_reid_person, + 'damo/cv_passvitb_image-reid-person_market'), + Tasks.text_driven_segmentation: + (Pipelines.text_driven_segmentation, + 'damo/cv_vitl16_segmentation_text-driven-seg'), + Tasks.movie_scene_segmentation: + (Pipelines.movie_scene_segmentation, + 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), + Tasks.shop_segmentation: (Pipelines.shop_segmentation, + 'damo/cv_vitb16_segmentation_shop-seg'), + Tasks.image_inpainting: (Pipelines.image_inpainting, + 'damo/cv_fft_inpainting_lama'), + Tasks.video_inpainting: (Pipelines.video_inpainting, + 'damo/cv_video-inpainting'), + Tasks.human_wholebody_keypoint: + (Pipelines.human_wholebody_keypoint, + 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), + Tasks.hand_static: (Pipelines.hand_static, + 'damo/cv_mobileface_hand-static'), + Tasks.face_human_hand_detection: + (Pipelines.face_human_hand_detection, + 'damo/cv_nanodet_face-human-hand-detection'), + Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), + Tasks.product_segmentation: (Pipelines.product_segmentation, + 'damo/cv_F3Net_product-segmentation'), + Tasks.referring_video_object_segmentation: + (Pipelines.referring_video_object_segmentation, + 'damo/cv_swin-t_referring_video-object-segmentation'), +} + + +def normalize_model_input(model, model_revision): + """ normalize the input model, to ensure that a model str is a valid local path: in other words, + for model represented by a model id, the model shall be downloaded locally + """ + if isinstance(model, str) and is_official_hub_path(model, model_revision): + # skip revision download if model is a local directory + if not os.path.exists(model): + # note that if there is already a local copy, snapshot_download will check and skip downloading + model = snapshot_download(model, revision=model_revision) + elif isinstance(model, list) and isinstance(model[0], str): + for idx in range(len(model)): + if is_official_hub_path( + model[idx], + model_revision) and not os.path.exists(model[idx]): + model[idx] = snapshot_download( + model[idx], revision=model_revision) + return model + + +def build_pipeline(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build pipeline given model config dict. + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details. + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, PIPELINES, group_key=task_name, default_args=default_args) + + +def pipeline(task: str = None, + model: Union[str, List[str], Model, List[Model]] = None, + preprocessor=None, + config_file: str = None, + pipeline_name: str = None, + framework: str = None, + device: str = 'gpu', + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + **kwargs) -> Pipeline: + """ Factory method to build an obj:`Pipeline`. + + + Args: + task (str): Task name defining which pipeline will be returned. + model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object. + preprocessor: preprocessor object. + config_file (str, optional): path to config file. + pipeline_name (str, optional): pipeline class name or alias name. + framework (str, optional): framework type. + model_revision: revision of model(s) if getting from model hub, for multiple models, expecting + all models to have the same revision + device (str, optional): whether to use gpu or cpu is used to do inference. + + Return: + pipeline (obj:`Pipeline`): pipeline object for certain task. + + Examples: + ```python + >>> # Using default model for a task + >>> p = pipeline('image-classification') + >>> # Using pipeline with a model name + >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased') + >>> # Using pipeline with a model object + >>> resnet = Model.from_pretrained('Resnet') + >>> p = pipeline('image-classification', model=resnet) + >>> # Using pipeline with a list of model names + >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2']) + """ + if task is None and pipeline_name is None: + raise ValueError('task or pipeline_name is required') + + model = normalize_model_input(model, model_revision) + if pipeline_name is None: + # get default pipeline for this task + if isinstance(model, str) \ + or (isinstance(model, list) and isinstance(model[0], str)): + if is_official_hub_path(model, revision=model_revision): + # read config file from hub and parse + cfg = read_config( + model, revision=model_revision) if isinstance( + model, str) else read_config( + model[0], revision=model_revision) + check_config(cfg) + pipeline_name = cfg.pipeline.type + else: + # used for test case, when model is str and is not hub path + pipeline_name = get_pipeline_by_model_name(task, model) + elif model is not None: + # get pipeline info from Model object + first_model = model[0] if isinstance(model, list) else model + if not hasattr(first_model, 'pipeline'): + # model is instantiated by user, we should parse config again + cfg = read_config(first_model.model_dir) + check_config(cfg) + first_model.pipeline = cfg.pipeline + pipeline_name = first_model.pipeline.type + else: + pipeline_name, default_model_repo = get_default_pipeline_info(task) + model = normalize_model_input(default_model_repo, model_revision) + + cfg = ConfigDict(type=pipeline_name, model=model) + cfg.device = device + + if kwargs: + cfg.update(kwargs) + + if preprocessor is not None: + cfg.preprocessor = preprocessor + + return build_pipeline(cfg, task_name=task) + + +def add_default_pipeline_info(task: str, + model_name: str, + modelhub_name: str = None, + overwrite: bool = False): + """ Add default model for a task. + + Args: + task (str): task name. + model_name (str): model_name. + modelhub_name (str): name for default modelhub. + overwrite (bool): overwrite default info. + """ + if not overwrite: + assert task not in DEFAULT_MODEL_FOR_PIPELINE, \ + f'task {task} already has default model.' + + DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name) + + +def get_default_pipeline_info(task): + """ Get default info for certain task. + + Args: + task (str): task name. + + Return: + A tuple: first element is pipeline name(model_name), second element + is modelhub name. + """ + + if task not in DEFAULT_MODEL_FOR_PIPELINE: + # support pipeline which does not register default model + pipeline_name = list(PIPELINES.modules[task].keys())[0] + default_model = None + else: + pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] + return pipeline_name, default_model + + +def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]): + """ Get pipeline name by task name and model name + + Args: + task (str): task name. + model (str| list[str]): model names + """ + if isinstance(model, str): + model_key = model + else: + model_key = '_'.join(model) + assert model_key in PIPELINES.modules[task], \ + f'pipeline for task {task} model {model_key} not found.' + return model_key diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py new file mode 100644 index 00000000..97cd8761 --- /dev/null +++ b/modelscope/pipelines/cv/__init__.py @@ -0,0 +1,145 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .action_recognition_pipeline import ActionRecognitionPipeline + from .action_detection_pipeline import ActionDetectionPipeline + from .animal_recognition_pipeline import AnimalRecognitionPipeline + from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline + from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline + from .hand_2d_keypoints_pipeline import Hand2DKeypointsPipeline + from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline + from .hicossl_video_embedding_pipeline import HICOSSLVideoEmbeddingPipeline + from .crowd_counting_pipeline import CrowdCountingPipeline + from .image_detection_pipeline import ImageDetectionPipeline + from .image_salient_detection_pipeline import ImageSalientDetectionPipeline + from .face_detection_pipeline import FaceDetectionPipeline + from .face_image_generation_pipeline import FaceImageGenerationPipeline + from .face_recognition_pipeline import FaceRecognitionPipeline + from .general_recognition_pipeline import GeneralRecognitionPipeline + from .image_cartoon_pipeline import ImageCartoonPipeline + from .image_classification_pipeline import GeneralImageClassificationPipeline + from .image_color_enhance_pipeline import ImageColorEnhancePipeline + from .image_colorization_pipeline import ImageColorizationPipeline + from .image_classification_pipeline import ImageClassificationPipeline + from .image_denoise_pipeline import ImageDenoisePipeline + from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline + from .image_matting_pipeline import ImageMattingPipeline + from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline + from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline + from .image_reid_person_pipeline import ImageReidPersonPipeline + from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline + from .image_style_transfer_pipeline import ImageStyleTransferPipeline + from .image_super_resolution_pipeline import ImageSuperResolutionPipeline + from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline + from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline + from .image_inpainting_pipeline import ImageInpaintingPipeline + from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline + from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline + from .live_category_pipeline import LiveCategoryPipeline + from .ocr_detection_pipeline import OCRDetectionPipeline + from .ocr_recognition_pipeline import OCRRecognitionPipeline + from .skin_retouching_pipeline import SkinRetouchingPipeline + from .tinynas_classification_pipeline import TinynasClassificationPipeline + from .video_category_pipeline import VideoCategoryPipeline + from .virtual_try_on_pipeline import VirtualTryonPipeline + from .shop_segmentation_pipleline import ShopSegmentationPipeline + from .easycv_pipelines import (EasyCVDetectionPipeline, + EasyCVSegmentationPipeline, + Face2DKeypointsPipeline, + HumanWholebodyKeypointsPipeline) + from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline + from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline + from .mog_face_detection_pipeline import MogFaceDetectionPipeline + from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline + from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline + from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline + from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipelin + from .hand_static_pipeline import HandStaticPipeline + from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline + +else: + _import_structure = { + 'action_recognition_pipeline': ['ActionRecognitionPipeline'], + 'action_detection_pipeline': ['ActionDetectionPipeline'], + 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], + 'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'], + 'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'], + 'hand_2d_keypoints_pipeline': ['Hand2DKeypointsPipeline'], + 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], + 'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'], + 'crowd_counting_pipeline': ['CrowdCountingPipeline'], + 'image_detection_pipeline': ['ImageDetectionPipeline'], + 'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'], + 'face_detection_pipeline': ['FaceDetectionPipeline'], + 'face_image_generation_pipeline': ['FaceImageGenerationPipeline'], + 'face_recognition_pipeline': ['FaceRecognitionPipeline'], + 'general_recognition_pipeline': ['GeneralRecognitionPipeline'], + 'image_classification_pipeline': + ['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'], + 'image_cartoon_pipeline': ['ImageCartoonPipeline'], + 'image_denoise_pipeline': ['ImageDenoisePipeline'], + 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], + 'image_colorization_pipeline': ['ImageColorizationPipeline'], + 'image_instance_segmentation_pipeline': + ['ImageInstanceSegmentationPipeline'], + 'image_matting_pipeline': ['ImageMattingPipeline'], + 'image_panoptic_segmentation_pipeline': + ['ImagePanopticSegmentationPipeline'], + 'image_portrait_enhancement_pipeline': + ['ImagePortraitEnhancementPipeline'], + 'image_reid_person_pipeline': ['ImageReidPersonPipeline'], + 'image_semantic_segmentation_pipeline': + ['ImageSemanticSegmentationPipeline'], + 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], + 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], + 'image_to_image_translation_pipeline': + ['Image2ImageTranslationPipeline'], + 'product_retrieval_embedding_pipeline': + ['ProductRetrievalEmbeddingPipeline'], + 'realtime_object_detection_pipeline': + ['RealtimeObjectDetectionPipeline'], + 'live_category_pipeline': ['LiveCategoryPipeline'], + 'image_to_image_generation_pipeline': + ['Image2ImageGenerationPipeline'], + 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], + 'ocr_detection_pipeline': ['OCRDetectionPipeline'], + 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], + 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], + 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], + 'video_category_pipeline': ['VideoCategoryPipeline'], + 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], + 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], + 'easycv_pipeline': [ + 'EasyCVDetectionPipeline', + 'EasyCVSegmentationPipeline', + 'Face2DKeypointsPipeline', + 'HumanWholebodyKeypointsPipeline', + ], + 'text_driven_segmentation_pipeline': + ['TextDrivenSegmentationPipeline'], + 'movie_scene_segmentation_pipeline': + ['MovieSceneSegmentationPipeline'], + 'mog_face_detection_pipeline': ['MogFaceDetectionPipeline'], + 'ulfd_face_detection_pipeline': ['UlfdFaceDetectionPipeline'], + 'retina_face_detection_pipeline': ['RetinaFaceDetectionPipeline'], + 'facial_expression_recognition_pipelin': + ['FacialExpressionRecognitionPipeline'], + 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], + 'hand_static_pipeline': ['HandStaticPipeline'], + 'referring_video_object_segmentation_pipeline': [ + 'ReferringVideoObjectSegmentationPipeline' + ], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/cv/action_detection_pipeline.py b/modelscope/pipelines/cv/action_detection_pipeline.py new file mode 100644 index 00000000..74d1862e --- /dev/null +++ b/modelscope/pipelines/cv/action_detection_pipeline.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os.path as osp +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.action_detection import ActionDetONNX +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.action_detection, module_name=Pipelines.action_detection) +class ActionDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a action detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.cfg.MODEL.model_file = model_path + self.model = ActionDetONNX(self.model, self.cfg.MODEL, + self.device_name) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_name = input + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_name': video_name} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + preds = self.model.forward(input['video_name']) + labels = sum([pred['actions']['labels'] for pred in preds], []) + scores = sum([pred['actions']['scores'] for pred in preds], []) + boxes = sum([pred['actions']['boxes'] for pred in preds], []) + timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels']) + for pred in preds], []) + out = { + OutputKeys.TIMESTAMPS: timestamps, + OutputKeys.LABELS: labels, + OutputKeys.SCORES: scores, + OutputKeys.BOXES: boxes + } + return out + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py new file mode 100644 index 00000000..993a32f0 --- /dev/null +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.action_recognition import (BaseVideoModel, + PatchShiftTransformer) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import ReadVideoData +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.action_recognition, module_name=Pipelines.action_recognition) +class ActionRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a action recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + + self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) + self.infer_model.eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)['model_state']) + self.label_mapping = self.cfg.label_mapping + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_input_data = ReadVideoData(self.cfg, input).to(self.device) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.perform_inference(input['video_data']) + output_label = self.label_mapping[str(pred)] + return {OutputKeys.LABELS: output_label} + + @torch.no_grad() + def perform_inference(self, data, max_bsz=4): + iter_num = math.ceil(data.size(0) / max_bsz) + preds_list = [] + for i in range(iter_num): + preds_list.append( + self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0]) + pred = torch.cat(preds_list, dim=0) + return pred.mean(dim=0).argmax().item() + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +@PIPELINES.register_module( + Tasks.action_recognition, module_name=Pipelines.pst_action_recognition) +class PSTActionRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a PST action recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.infer_model = PatchShiftTransformer(model).to(self.device) + self.infer_model.eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)['state_dict']) + self.label_mapping = self.cfg.label_mapping + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_input_data = ReadVideoData(self.cfg, input).to(self.device) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.perform_inference(input['video_data']) + output_label = self.label_mapping[str(pred)] + return {OutputKeys.LABELS: output_label} + + @torch.no_grad() + def perform_inference(self, data, max_bsz=4): + iter_num = math.ceil(data.size(0) / max_bsz) + preds_list = [] + for i in range(iter_num): + preds_list.append( + self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])) + pred = torch.cat(preds_list, dim=0) + return pred.mean(dim=0).argmax().item() + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/animal_recognition_pipeline.py b/modelscope/pipelines/cv/animal_recognition_pipeline.py new file mode 100644 index 00000000..671a5b4c --- /dev/null +++ b/modelscope/pipelines/cv/animal_recognition_pipeline.py @@ -0,0 +1,120 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Pipelines +from modelscope.models.cv.animal_recognition import Bottleneck, ResNet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.animal_recognition, module_name=Pipelines.animal_recognition) +class AnimalRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a animal recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + import torch + + def resnest101(**kwargs): + model = ResNet( + Bottleneck, [3, 4, 23, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + return model + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' not in list(own_state.keys())[0]: + name = name.replace('.module.', '.') + if (name in own_state) and (own_state[name].shape + == param.shape): + own_state[name].copy_(param) + copied_keys.append(name) + + def load_pretrained(model, src_params): + if 'state_dict' in src_params: + src_params = src_params['state_dict'] + own_state = model.state_dict() + filter_param(src_params, own_state) + model.load_state_dict(own_state) + + self.model = resnest101(num_classes=8288) + local_model_dir = model + if osp.exists(model): + local_model_dir = model + else: + local_model_dir = snapshot_download(model) + self.local_path = local_model_dir + src_params = torch.load( + osp.join(local_model_dir, 'pytorch_model.pt'), 'cpu') + load_pretrained(self.model, src_params) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + test_transforms = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), normalize + ]) + img = test_transforms(img) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + def set_phase(model, is_train): + if is_train: + model.train() + else: + model.eval() + + is_train = False + set_phase(self.model, is_train) + img = input['img'] + input_img = torch.unsqueeze(img, 0) + outputs = self.model(input_img) + return {'outputs': outputs} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + label_mapping_path = osp.join(self.local_path, 'label_mapping.txt') + with open(label_mapping_path, 'r') as f: + label_mapping = f.readlines() + score = torch.max(inputs['outputs']) + inputs = { + OutputKeys.SCORES: [score.item()], + OutputKeys.LABELS: + [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] + } + return inputs diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py new file mode 100644 index 00000000..bc2e975d --- /dev/null +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -0,0 +1,270 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import \ + PoseHighResolutionNetV2 +from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) +class Body2DKeypointsPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + super().__init__(model=model, **kwargs) + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + self.keypoint_model = KeypointsDetection(model, device) + + self.human_detect_model_id = 'damo/cv_resnet18_human-detection' + self.human_detector = pipeline( + Tasks.human_detection, model=self.human_detect_model_id) + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + output = self.human_detector(input) + + image = LoadImage.convert_to_ndarray(input) + image = image[:, :, [2, 1, 0]] # rgb2bgr + + return {'image': image, 'output': output} + + def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: + input_image = input['image'] + output = input['output'] + + bboxes = [] + scores = np.array(output[OutputKeys.SCORES].cpu(), dtype=np.float32) + boxes = np.array(output[OutputKeys.BOXES].cpu(), dtype=np.float32) + + for id, box in enumerate(boxes): + box_tmp = [ + box[0], box[1], box[2] - box[0], box[3] - box[1], scores[id], 0 + ] + bboxes.append(box_tmp) + if len(bboxes) == 0: + logger.error('cannot detect human in the image') + return [None, None] + human_images, metas = self.keypoint_model.preprocess( + [bboxes, input_image]) + outputs = self.keypoint_model.forward(human_images) + return [outputs, metas] + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + if input[0] is None or input[1] is None: + return { + OutputKeys.BOXES: [], + OutputKeys.KEYPOINTS: [], + OutputKeys.SCORES: [] + } + + poses, scores, boxes = self.keypoint_model.postprocess(input) + result_boxes = [] + for box in boxes: + result_boxes.append([box[0][0], box[0][1], box[1][0], box[1][1]]) + return { + OutputKeys.BOXES: result_boxes, + OutputKeys.KEYPOINTS: poses, + OutputKeys.SCORES: scores + } + + +class KeypointsDetection(): + + def __init__(self, model: str, device: str, **kwargs): + self.model = model + self.device = device + cfg = cfg_128x128_15 + self.key_points_model = PoseHighResolutionNetV2(cfg) + pretrained_state_dict = torch.load( + osp.join(self.model, ModelFile.TORCH_MODEL_FILE), + map_location=device) + self.key_points_model.load_state_dict( + pretrained_state_dict, strict=False) + self.key_points_model = self.key_points_model.to(device) + self.key_points_model.eval() + + self.input_size = cfg['MODEL']['IMAGE_SIZE'] + self.lst_parent_ids = cfg['DATASET']['PARENT_IDS'] + self.lst_left_ids = cfg['DATASET']['LEFT_IDS'] + self.lst_right_ids = cfg['DATASET']['RIGHT_IDS'] + self.box_enlarge_ratio = 0.05 + + def train(self): + return self.key_points_model.train() + + def eval(self): + return self.key_points_model.eval() + + def forward(self, input: Tensor) -> Tensor: + with torch.no_grad(): + return self.key_points_model.forward(input.to(self.device)) + + def get_pts(self, heatmaps): + [pts_num, height, width] = heatmaps.shape + pts = [] + scores = [] + for i in range(pts_num): + heatmap = heatmaps[i, :, :] + pt = np.where(heatmap == np.max(heatmap)) + scores.append(np.max(heatmap)) + x = pt[1][0] + y = pt[0][0] + + [h, w] = heatmap.shape + if x >= 1 and x <= w - 2 and y >= 1 and y <= h - 2: + x_diff = heatmap[y, x + 1] - heatmap[y, x - 1] + y_diff = heatmap[y + 1, x] - heatmap[y - 1, x] + x_sign = 0 + y_sign = 0 + if x_diff < 0: + x_sign = -1 + if x_diff > 0: + x_sign = 1 + if y_diff < 0: + y_sign = -1 + if y_diff > 0: + y_sign = 1 + x = x + x_sign * 0.25 + y = y + y_sign * 0.25 + + pts.append([x, y]) + return pts, scores + + def pts_transform(self, meta, pts, lt_x, lt_y): + pts_new = [] + s = meta['s'] + o = meta['o'] + size = len(pts) + for i in range(size): + ratio = 4 + x = (int(pts[i][0] * ratio) - o[0]) / s[0] + y = (int(pts[i][1] * ratio) - o[1]) / s[1] + + pt = [x, y] + pts_new.append(pt) + + return pts_new + + def postprocess(self, inputs: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs): + output_poses = [] + output_scores = [] + output_boxes = [] + for i in range(inputs[0].shape[0]): + outputs, scores = self.get_pts( + (inputs[0][i]).detach().cpu().numpy()) + outputs = self.pts_transform(inputs[1][i], outputs, 0, 0) + box = np.array(inputs[1][i]['human_box'][0:4]).reshape(2, 2) + outputs = np.array(outputs) + box[0] + output_poses.append(outputs.tolist()) + output_scores.append(scores) + output_boxes.append(box.tolist()) + return output_poses, output_scores, output_boxes + + def image_crop_resize(self, input, margin=[0, 0]): + pad_img = np.zeros((self.input_size[1], self.input_size[0], 3), + dtype=np.uint8) + + h, w, ch = input.shape + + h_new = self.input_size[1] - margin[1] * 2 + w_new = self.input_size[0] - margin[0] * 2 + s0 = float(h_new) / h + s1 = float(w_new) / w + s = min(s0, s1) + w_new = int(s * w) + h_new = int(s * h) + + img_new = cv2.resize(input, (w_new, h_new), cv2.INTER_LINEAR) + + cx = self.input_size[0] // 2 + cy = self.input_size[1] // 2 + + pad_img[cy - h_new // 2:cy - h_new // 2 + h_new, + cx - w_new // 2:cx - w_new // 2 + w_new, :] = img_new + + return pad_img, np.array([cx, cy]), np.array([s, s]), np.array( + [cx - w_new // 2, cy - h_new // 2]) + + def image_transform(self, input: Input) -> Dict[Tensor, Any]: + if isinstance(input, str): + image = cv2.imread(input, -1)[:, :, 0:3] + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + image = input + image = image[:, :, 0:3] + elif isinstance(input, torch.Tensor): + image = input.cpu().numpy()[:, :, 0:3] + + w, h, _ = image.shape + w_new = self.input_size[0] + h_new = self.input_size[1] + + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + img_resize, c, s, o = self.image_crop_resize(image) + + img_resize = np.float32(img_resize) / 255. + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + img_resize = (img_resize - mean) / std + + input_data = np.zeros([1, 3, h_new, w_new], dtype=np.float32) + + img_resize = img_resize.transpose((2, 0, 1)) + input_data[0, :] = img_resize + meta = {'c': c, 's': s, 'o': o} + return [torch.from_numpy(input_data), meta] + + def crop_image(self, image, box): + height, width, _ = image.shape + w, h = box[1] - box[0] + box[0, :] -= (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) + box[1, :] += (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) + + box[0, 0] = min(max(box[0, 0], 0.0), width) + box[0, 1] = min(max(box[0, 1], 0.0), height) + box[1, 0] = min(max(box[1, 0], 0.0), width) + box[1, 1] = min(max(box[1, 1], 0.0), height) + + cropped_image = image[int(box[0][1]):int(box[1][1]), + int(box[0][0]):int(box[1][0])] + return cropped_image + + def preprocess(self, input: Dict[Tensor, Tensor]) -> Dict[Tensor, Any]: + bboxes = input[0] + image = input[1] + + lst_human_images = [] + lst_meta = [] + for i in range(len(bboxes)): + box = np.array(bboxes[i][0:4]).reshape(2, 2) + box[1] += box[0] + human_image = self.crop_image(image.clone(), box) + human_image, meta = self.image_transform(human_image) + lst_human_images.append(human_image) + meta['human_box'] = box + lst_meta.append(meta) + + return [torch.cat(lst_human_images, dim=0), lst_meta] diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py new file mode 100644 index 00000000..d113fb3c --- /dev/null +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -0,0 +1,372 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import datetime +import os.path as osp +import tempfile +from typing import Any, Dict, List, Union + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 +import numpy as np +import torch +from matplotlib import animation +from matplotlib.animation import writers +from matplotlib.ticker import MultipleLocator + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.body_3d_keypoints.body_3d_pose import ( + BodyKeypointsDetection3D, KeypointsTypes) +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +matplotlib.use('Agg') + +logger = get_logger() + + +def convert_2_h36m(joints, joints_nbr=15): + lst_mappings = [[0, 8], [1, 7], [2, 12], [3, 13], [4, 14], [5, 9], [6, 10], + [7, 11], [8, 1], [9, 2], [10, 3], [11, 4], [12, 5], + [13, 6], [14, 0]] + nbr, dim = joints.shape + h36m_joints = np.zeros((nbr, dim)) + for mapping in lst_mappings: + h36m_joints[mapping[1]] = joints[mapping[0]] + + if joints_nbr == 17: + lst_mappings_17 = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], + [5, 5], [6, 6], [7, 8], [8, 10], [9, 11], + [10, 12], [11, 13], [12, 14], [13, 15], + [14, 16]]) + h36m_joints_17 = np.zeros((17, 2)) + h36m_joints_17[lst_mappings_17[:, 1]] = h36m_joints[lst_mappings_17[:, + 0]] + h36m_joints_17[7] = (h36m_joints_17[0] + h36m_joints_17[8]) * 0.5 + h36m_joints_17[9] = (h36m_joints_17[8] + h36m_joints_17[10]) * 0.5 + h36m_joints = h36m_joints_17 + + return h36m_joints + + +def smooth_pts(cur_pts, pre_pts, bbox, smooth_x=15.0, smooth_y=15.0): + if pre_pts is None: + return cur_pts + + w, h = bbox[1] - bbox[0] + if w == 0 or h == 0: + return cur_pts + + size_pre = len(pre_pts) + size_cur = len(cur_pts) + if (size_pre == 0 or size_cur == 0): + return cur_pts + + factor_x = -(smooth_x / w) + factor_y = -(smooth_y / w) + + for i in range(size_cur): + w_x = np.exp(factor_x * np.abs(cur_pts[i][0] - pre_pts[i][0])) + w_y = np.exp(factor_y * np.abs(cur_pts[i][1] - pre_pts[i][1])) + cur_pts[i][0] = (1.0 - w_x) * cur_pts[i][0] + w_x * pre_pts[i][0] + cur_pts[i][1] = (1.0 - w_y) * cur_pts[i][1] + w_y * pre_pts[i][1] + return cur_pts + + +def smoothing(lst_kps, lst_bboxes, smooth_x=15.0, smooth_y=15.0): + assert lst_kps.shape[0] == lst_bboxes.shape[0] + + lst_smoothed_kps = [] + prev_pts = None + for i in range(lst_kps.shape[0]): + smoothed_cur_kps = smooth_pts(lst_kps[i], prev_pts, + lst_bboxes[i][0:-1].reshape(2, 2), + smooth_x, smooth_y) + lst_smoothed_kps.append(smoothed_cur_kps) + prev_pts = smoothed_cur_kps + + return np.array(lst_smoothed_kps) + + +def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15): + lst_kps = lst_kps.squeeze() + lst_bboxes = lst_bboxes.squeeze() + + assert lst_kps.shape[0] == lst_bboxes.shape[0] + + lst_kps = smoothing(lst_kps, lst_bboxes) + + keypoints = [] + for i in range(lst_kps.shape[0]): + h36m_joints_2d = convert_2_h36m(lst_kps[i], joints_nbr=joints_nbr) + keypoints.append(h36m_joints_2d) + return keypoints + + +@PIPELINES.register_module( + Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints) +class Body3DKeypointsPipeline(Pipeline): + + def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs): + """Human body 3D pose estimation. + + Args: + model (Union[str, BodyKeypointsDetection3D]): model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + self.keypoint_model_3d = model if isinstance( + model, BodyKeypointsDetection3D) else Model.from_pretrained(model) + self.keypoint_model_3d.eval() + + # init human body 2D keypoints detection pipeline + self.human_body_2d_kps_det_pipeline = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' + self.human_body_2d_kps_detector = pipeline( + Tasks.body_2d_keypoints, + model=self.human_body_2d_kps_det_pipeline, + device='gpu' if torch.cuda.is_available() else 'cpu') + + def preprocess(self, input: Input) -> Dict[str, Any]: + self.video_url = input + video_frames = self.read_video_frames(self.video_url) + if 0 == len(video_frames): + res = {'success': False, 'msg': 'get video frame failed.'} + return res + + all_2d_poses = [] + all_boxes_with_socre = [] + max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints + for i, frame in enumerate(video_frames): + kps_2d = self.human_body_2d_kps_detector(frame) + if [] == kps_2d.get('boxes'): + res = { + 'success': False, + 'msg': f'fail to detect person at image frame {i}' + } + return res + + box = kps_2d['boxes'][ + 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox + pose = kps_2d['keypoints'][0] # keypoints: [15, 2] + score = kps_2d['scores'][0] # keypoints: [15, 2] + all_2d_poses.append(pose) + all_boxes_with_socre.append( + list(np.array(box).reshape( + (-1))) + [score]) # construct to list with shape [5] + if (i + 1) >= max_frame: + break + + all_2d_poses_np = np.array(all_2d_poses).reshape( + (len(all_2d_poses), 15, + 2)) # 15: 2d keypoints number, 2: keypoint coordinate (x, y) + all_boxes_np = np.array(all_boxes_with_socre).reshape( + (len(all_boxes_with_socre), 5)) # [x1, y1, x2, y2, score] + + kps_2d_h36m_17 = convert_2_h36m_data( + all_2d_poses_np, + all_boxes_np, + joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS) + kps_2d_h36m_17 = np.array(kps_2d_h36m_17) + res = {'success': True, 'input_2d_pts': kps_2d_h36m_17} + return res + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if not input['success']: + res = {'success': False, 'msg': 'preprocess failed.'} + return res + + input_2d_pts = input['input_2d_pts'] + outputs = self.keypoint_model_3d.preprocess(input_2d_pts) + outputs = self.keypoint_model_3d.forward(outputs) + res = dict({'success': True}, **outputs) + return res + + def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: + output_video_path = kwargs.get('output_video', None) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + + res = { + OutputKeys.KEYPOINTS: [], + OutputKeys.TIMESTAMPS: [], + OutputKeys.OUTPUT_VIDEO: output_video_path + } + + if not input['success']: + res[OutputKeys.OUTPUT_VIDEO] = self.video_url + else: + poses = input[KeypointsTypes.POSES_CAMERA] + pred_3d_pose = poses.data.cpu().numpy()[ + 0] # [frame_num, joint_num, joint_dim] + + if 'render' in self.keypoint_model_3d.cfg.keys(): + self.render_prediction(pred_3d_pose, output_video_path) + res[OutputKeys.OUTPUT_VIDEO] = output_video_path + + res[OutputKeys.KEYPOINTS] = pred_3d_pose + res[OutputKeys.TIMESTAMPS] = self.timestamps + return res + + def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]): + """Read video from local video file or from a video stream URL. + + Args: + video_url (str or cv2.VideoCapture): Video path or video stream. + + Raises: + Exception: Open video fail. + + Returns: + [nd.array]: List of video frames. + """ + + def timestamp_format(seconds): + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + time = '%02d:%02d:%06.3f' % (h, m, s) + return time + + frames = [] + self.timestamps = [] # for video render + if isinstance(video_url, str): + cap = cv2.VideoCapture(video_url) + if not cap.isOpened(): + raise Exception( + 'modelscope error: %s cannot be decoded by OpenCV.' % + (video_url)) + else: + cap = video_url + + self.fps = cap.get(cv2.CAP_PROP_FPS) + if self.fps is None or self.fps <= 0: + raise Exception('modelscope error: %s cannot get video fps info.' % + (video_url)) + + max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + self.timestamps.append( + timestamp_format(seconds=frame_idx / self.fps)) + frame_idx += 1 + frames.append(frame) + if frame_idx >= max_frame_num: + break + cap.release() + return frames + + def render_prediction(self, pose3d_cam_rr, output_video_path): + """render predict result 3d poses. + + Args: + pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints + output_video_path (str): output path for video + Returns: + """ + frame_num = pose3d_cam_rr.shape[0] + + left_points = [11, 12, 13, 4, 5, 6] # joints of left body + edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6], [2, + 3], [7, 8], + [8, 9], [8, 11], [8, 14], [14, 15], [15, 16], [11, 12], + [12, 13], [9, 10]] # connection between joints + + fig = plt.figure() + ax = p3.Axes3D(fig) + x_major_locator = MultipleLocator(0.5) + + ax.xaxis.set_major_locator(x_major_locator) + ax.yaxis.set_major_locator(x_major_locator) + ax.zaxis.set_major_locator(x_major_locator) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + ax.set_zlim(-1, 1) + # view direction + azim = self.keypoint_model_3d.cfg.render.azim + elev = self.keypoint_model_3d.cfg.render.elev + ax.view_init(elev, azim) + + # init plot, essentially + x = pose3d_cam_rr[0, :, 0] + y = pose3d_cam_rr[0, :, 1] + z = pose3d_cam_rr[0, :, 2] + points, = ax.plot(x, y, z, 'r.') + + def renderBones(xs, ys, zs): + """render bones in skeleton + + Args: + xs (nd.array): [joint_num, joint_channel] + ys (nd.array): [joint_num, joint_channel] + zs (nd.array): [joint_num, joint_channel] + """ + bones = {} + for idx, edge in enumerate(edges): + index1, index2 = edge[0], edge[1] + if index1 in left_points: + edge_color = 'red' + else: + edge_color = 'blue' + connect = ax.plot([xs[index1], xs[index2]], + [ys[index1], ys[index2]], + [zs[index1], zs[index2]], + linewidth=2, + color=edge_color) # plot edge + bones[idx] = connect[0] + return bones + + bones = renderBones(x, y, z) + + def update(frame_idx, points, bones): + """update animation + + Args: + frame_idx (int): frame index + points (mpl_toolkits.mplot3d.art3d.Line3D): skeleton points ploter + bones (dict[int, mpl_toolkits.mplot3d.art3d.Line3D]): connection ploter + + Returns: + tuple: points and bones ploter + """ + xs = pose3d_cam_rr[frame_idx, :, 0] + ys = pose3d_cam_rr[frame_idx, :, 1] + zs = pose3d_cam_rr[frame_idx, :, 2] + + # update bones + for idx, edge in enumerate(edges): + index1, index2 = edge[0], edge[1] + x1x2 = (xs[index1], xs[index2]) + y1y2 = (ys[index1], ys[index2]) + z1z2 = (zs[index1], zs[index2]) + bones[idx].set_xdata(x1x2) + bones[idx].set_ydata(y1y2) + bones[idx].set_3d_properties(z1z2, 'z') + + # update joints + points.set_data(xs, ys) + points.set_3d_properties(zs, 'z') + if 0 == frame_idx / 100: + logger.info(f'rendering {frame_idx}/{frame_num}') + return points, bones + + ani = animation.FuncAnimation( + fig=fig, + func=update, + frames=frame_num, + interval=self.fps, + fargs=(points, bones)) + + # save mp4 + Writer = writers['ffmpeg'] + writer = Writer(fps=self.fps, metadata={}, bitrate=4096) + ani.save(output_video_path, writer=writer) diff --git a/modelscope/pipelines/cv/card_detection_pipeline.py b/modelscope/pipelines/cv/card_detection_pipeline.py new file mode 100644 index 00000000..00b18024 --- /dev/null +++ b/modelscope/pipelines/cv/card_detection_pipeline.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Pipelines +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.face_detection_pipeline import \ + FaceDetectionPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.card_detection, module_name=Pipelines.card_detection) +class CardDetectionPipeline(FaceDetectionPipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a card detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + thr = 0.45 # card/face detect use different threshold + super().__init__(model=model, score_thr=thr, **kwargs) diff --git a/modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py b/modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py new file mode 100644 index 00000000..deb17561 --- /dev/null +++ b/modelscope/pipelines/cv/cmdssl_video_embedding_pipeline.py @@ -0,0 +1,158 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os.path as osp +from typing import Any, Dict + +import decord +import numpy as np +import torch +import torchvision.transforms.functional as TF +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.cmdssl_video_embedding import resnet26_2p1d +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_embedding, module_name=Pipelines.cmdssl_video_embedding) +class CMDSSLVideoEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a CMDSSL Video Embedding pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.model = resnet26_2p1d(num_classes=None, last_pool=True) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device).eval().requires_grad_(False) + self.model.load_state_dict(torch.load(model_path)) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + decord.bridge.set_bridge('native') + + transforms = VCompose([ + VRescale(size=self.cfg.DATA.scale_size), + VCenterCrop(size=self.cfg.DATA.crop_size), + VToTensor(), + VNormalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) + ]) + + clip_len = (self.cfg.DATA.video_frames + - 1) * self.cfg.DATA.video_stride + 1 + vr = decord.VideoReader(input, ctx=decord.cpu(0)) + if len(vr) <= clip_len: + init_frames = np.zeros(self.cfg.DATA.multi_crop, dtype=int) + else: + init_frames = np.linspace(0, + len(vr) - clip_len, + self.cfg.DATA.multi_crop + 1) + init_frames = ((init_frames[1:] + init_frames[:-1]) + / 2.).astype(int) + + indices = np.arange(0, clip_len, self.cfg.DATA.video_stride) + indices = (init_frames[:, None] + indices[None, :]).reshape(-1) + indices[indices >= len(vr)] = 0 + + frames = torch.from_numpy(vr.get_batch(indices).asnumpy()).chunk( + self.cfg.DATA.multi_crop, dim=0) + frames = [ + transforms([Image.fromarray(f) for f in u.numpy()]) for u in frames + ] + frames = torch.stack(frames, dim=0) + result = {'video_data': frames} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + frames = input['video_data'].to(self._device) + feature = self.model(frames) + feature = feature.mean(0) + return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +class VCompose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, item): + for t in self.transforms: + item = t(item) + return item + + +class VRescale(object): + + def __init__(self, size=128): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip] + return vclip + + +class VCenterCrop(object): + + def __init__(self, size=112): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + vclip = [ + u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip + ] + return vclip + + +class VToTensor(object): + + def __call__(self, vclip): + vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=1) + return vclip + + +class VNormalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, vclip): + assert vclip.min() > -0.1 and vclip.max() < 1.1, \ + 'vclip values should be in [0, 1]' + vclip = vclip.clone() + if not isinstance(self.mean, torch.Tensor): + self.mean = vclip.new_tensor(self.mean).view(-1, 1, 1, 1) + if not isinstance(self.std, torch.Tensor): + self.std = vclip.new_tensor(self.std).view(-1, 1, 1, 1) + vclip.sub_(self.mean).div_(self.std) + return vclip diff --git a/modelscope/pipelines/cv/crowd_counting_pipeline.py b/modelscope/pipelines/cv/crowd_counting_pipeline.py new file mode 100644 index 00000000..93fffdf2 --- /dev/null +++ b/modelscope/pipelines/cv/crowd_counting_pipeline.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +from typing import Any, Dict + +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.crowd_counting import HRNetCrowdCounting +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.crowd_counting, module_name=Pipelines.crowd_counting) +class CrowdCountingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from dir {model}') + self.infer_model = HRNetCrowdCounting(model).to(self.device) + self.infer_model.eval() + logger.info('load model done') + + def resize(self, img): + height = img.size[1] + width = img.size[0] + resize_height = height + resize_width = width + if resize_width >= 2048: + tmp = resize_width + resize_width = 2048 + resize_height = (resize_width / tmp) * resize_height + + if resize_height >= 2048: + tmp = resize_height + resize_height = 2048 + resize_width = (resize_height / tmp) * resize_width + + if resize_height <= 416: + tmp = resize_height + resize_height = 416 + resize_width = (resize_height / tmp) * resize_width + if resize_width <= 416: + tmp = resize_width + resize_width = 416 + resize_height = (resize_width / tmp) * resize_height + + # other constraints + if resize_height < resize_width: + if resize_width / resize_height > 2048 / 416: # 1024/416=2.46 + resize_width = 2048 + resize_height = 416 + else: + if resize_height / resize_width > 2048 / 416: + resize_height = 2048 + resize_width = 416 + + resize_height = math.ceil(resize_height / 32) * 32 + resize_width = math.ceil(resize_width / 32) * 32 + img = transforms.Resize([resize_height, resize_width])(img) + return img + + def merge_crops(self, eval_shape, eval_p, pred_m): + for i in range(3): + for j in range(3): + start_h, start_w = math.floor(eval_shape[2] / 4), math.floor( + eval_shape[3] / 4) + valid_h, valid_w = eval_shape[2] // 2, eval_shape[3] // 2 + pred_h = math.floor( + 3 * eval_shape[2] / 4) + (eval_shape[2] // 2) * ( + i - 1) + pred_w = math.floor( + 3 * eval_shape[3] / 4) + (eval_shape[3] // 2) * ( + j - 1) + if i == 0: + valid_h = math.floor(3 * eval_shape[2] / 4) + start_h = 0 + pred_h = 0 + elif i == 2: + valid_h = math.ceil(3 * eval_shape[2] / 4) + + if j == 0: + valid_w = math.floor(3 * eval_shape[3] / 4) + start_w = 0 + pred_w = 0 + elif j == 2: + valid_w = math.ceil(3 * eval_shape[3] / 4) + pred_m[:, :, pred_h:pred_h + valid_h, pred_w:pred_w + + valid_w] += eval_p[i * 3 + j:i * 3 + j + 1, :, + start_h:start_h + valid_h, + start_w:start_w + valid_w] + return pred_m + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + img = self.resize(img) + img_ori_tensor = transforms.ToTensor()(img) + img_shape = img_ori_tensor.shape + img = transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))( + img_ori_tensor) + patch_height, patch_width = (img_shape[1]) // 2, (img_shape[2]) // 2 + imgs = [] + for i in range(3): + for j in range(3): + start_h, start_w = (patch_height // 2) * i, (patch_width + // 2) * j + imgs.append(img[:, start_h:start_h + patch_height, + start_w:start_w + patch_width]) + + imgs = torch.stack(imgs) + eval_img = imgs.to(self.device) + eval_patchs = torch.squeeze(eval_img) + prediction_map = torch.zeros( + (1, 1, img_shape[1] // 2, img_shape[2] // 2)).to(self.device) + result = { + 'img': eval_patchs, + 'map': prediction_map, + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + counts, img_data = self.perform_inference(input) + return {OutputKeys.SCORES: counts, OutputKeys.OUTPUT_IMG: img_data} + + @torch.no_grad() + def perform_inference(self, data): + eval_patchs = data['img'] + prediction_map = data['map'] + eval_prediction, _, _ = self.infer_model(eval_patchs) + eval_patchs_shape = eval_prediction.shape + prediction_map = self.merge_crops(eval_patchs_shape, eval_prediction, + prediction_map) + + return torch.sum( + prediction_map, dim=( + 1, 2, + 3)).data.cpu().numpy(), prediction_map.data.cpu().numpy()[0][0] + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/easycv_pipelines/__init__.py b/modelscope/pipelines/cv/easycv_pipelines/__init__.py new file mode 100644 index 00000000..e0209b85 --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .detection_pipeline import EasyCVDetectionPipeline + from .segmentation_pipeline import EasyCVSegmentationPipeline + from .face_2d_keypoints_pipeline import Face2DKeypointsPipeline + from .human_wholebody_keypoint_pipeline import HumanWholebodyKeypointsPipeline +else: + _import_structure = { + 'detection_pipeline': ['EasyCVDetectionPipeline'], + 'segmentation_pipeline': ['EasyCVSegmentationPipeline'], + 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'], + 'human_wholebody_keypoint_pipeline': + ['HumanWholebodyKeypointsPipeline'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/cv/easycv_pipelines/base.py b/modelscope/pipelines/cv/easycv_pipelines/base.py new file mode 100644 index 00000000..c130aea0 --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/base.py @@ -0,0 +1,114 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import os.path as osp +from typing import Any + +import numpy as np +from easycv.utils.ms_utils import EasyCVMeta +from PIL import ImageFile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines.util import is_official_hub_path +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.device import create_device + + +class EasyCVPipeline(object): + """Base pipeline for EasyCV. + Loading configuration file of modelscope style by default, + but it is actually use the predictor api of easycv to predict. + So here we do some adaptation work for configuration and predict api. + """ + + def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + + """ + self.model_file_pattern = model_file_pattern + + assert isinstance(model, str) + if osp.exists(model): + model_dir = model + else: + assert is_official_hub_path( + model), 'Only support local model path and official hub path!' + model_dir = snapshot_download( + model_id=model, revision=DEFAULT_MODEL_REVISION) + + assert osp.isdir(model_dir) + model_files = glob.glob( + os.path.join(model_dir, self.model_file_pattern)) + assert len( + model_files + ) == 1, f'Need one model file, but find {len(model_files)}: {model_files}' + + model_path = model_files[0] + self.model_path = model_path + + # get configuration file from source model dir + self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION) + assert os.path.exists( + self.config_file + ), f'Not find "{ModelFile.CONFIGURATION}" in model directory!' + + self.cfg = Config.from_file(self.config_file) + if 'device' in kwargs: + kwargs['device'] = create_device(kwargs['device']) + self.predict_op = self._build_predict_op(**kwargs) + + def _build_predict_op(self, **kwargs): + """Build EasyCV predictor.""" + from easycv.predictors.builder import build_predictor + + easycv_config = self._to_easycv_config() + pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { + 'model_path': self.model_path, + 'config_file': easycv_config, + **kwargs + }) + return pipeline_op + + def _to_easycv_config(self): + """Adapt to EasyCV predictor.""" + # TODO: refine config compatibility problems + + easycv_arch = self.cfg.model.pop(EasyCVMeta.ARCH, None) + model_cfg = self.cfg.model + # Revert to the configuration of easycv + if easycv_arch is not None: + model_cfg.update(easycv_arch) + + easycv_config = Config(dict(model=model_cfg)) + + reserved_keys = [] + if hasattr(self.cfg, EasyCVMeta.META): + easycv_meta_cfg = getattr(self.cfg, EasyCVMeta.META) + reserved_keys = easycv_meta_cfg.get(EasyCVMeta.RESERVED_KEYS, []) + for key in reserved_keys: + easycv_config.merge_from_dict({key: getattr(self.cfg, key)}) + if 'test_pipeline' not in reserved_keys: + easycv_config.merge_from_dict( + {'test_pipeline': self.cfg.dataset.val.get('pipeline', [])}) + + return easycv_config + + def _is_single_inputs(self, inputs): + if isinstance(inputs, str) or (isinstance(inputs, list) + and len(inputs) == 1) or isinstance( + inputs, np.ndarray) or isinstance( + inputs, ImageFile.ImageFile): + return True + + return False + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + if self._is_single_inputs(inputs): + outputs = outputs[0] + + return outputs diff --git a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py new file mode 100644 index 00000000..a1173bc4 --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.cv.image_utils import \ + show_image_object_detection_auto_result +from .base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.image_object_detection, module_name=Pipelines.easycv_detection) +@PIPELINES.register_module( + Tasks.image_object_detection, + module_name=Pipelines.image_object_detection_auto) +class EasyCVDetectionPipeline(EasyCVPipeline): + """Pipeline for easycv detection task.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + + super(EasyCVDetectionPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def show_result(self, img_path, result, save_path=None): + show_image_object_detection_auto_result(img_path, result, save_path) + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + scores = [] + labels = [] + boxes = [] + for output in outputs: + for score, label, box in zip(output['detection_scores'], + output['detection_classes'], + output['detection_boxes']): + scores.append(score) + labels.append(self.cfg.CLASSES[label]) + boxes.append([b for b in box]) + + results = [{ + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: boxes + } for output in outputs] + + if self._is_single_inputs(inputs): + results = results[0] + + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py new file mode 100644 index 00000000..29a96a5f --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py @@ -0,0 +1,244 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import math +from typing import Any + +import cv2 +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .base import EasyCVPipeline + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_2d_keypoints, module_name=Pipelines.face_2d_keypoints) +class Face2DKeypointsPipeline(EasyCVPipeline): + """Pipeline for face 2d keypoints detection.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + + super(Face2DKeypointsPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + # face detect pipeline + det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + self.face_detection = pipeline( + Tasks.face_detection, model=det_model_id) + + def show_result(self, img, points, scale=2, save_path=None): + return self.predict_op.show_result(img, points, scale, save_path) + + def _choose_face(self, det_result, min_face=10): + """ + choose face with maximum area + Args: + det_result: output of face detection pipeline + min_face: minimum size of valid face w/h + """ + bboxes = np.array(det_result[OutputKeys.BOXES]) + landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) + if bboxes.shape[0] == 0: + logger.warn('No face detected!') + return None + # face idx with enough size + face_idx = [] + for i in range(bboxes.shape[0]): + box = bboxes[i] + if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: + face_idx += [i] + if len(face_idx) == 0: + logger.warn( + f'Face size not enough, less than {min_face}x{min_face}!') + return None + bboxes = bboxes[face_idx] + landmarks = landmarks[face_idx] + + return bboxes, landmarks + + def expend_box(self, box, w, h, scalex=0.3, scaley=0.5): + x1 = box[0] + y1 = box[1] + wb = box[2] - x1 + hb = box[3] - y1 + deltax = int(wb * scalex) + deltay1 = int(hb * scaley) + deltay2 = int(hb * scalex) + x1 = x1 - deltax + y1 = y1 - deltay1 + if x1 < 0: + deltax = deltax + x1 + x1 = 0 + if y1 < 0: + deltay1 = deltay1 + y1 + y1 = 0 + x2 = x1 + wb + 2 * deltax + y2 = y1 + hb + deltay1 + deltay2 + x2 = np.clip(x2, 0, w - 1) + y2 = np.clip(y2, 0, h - 1) + return [x1, y1, x2, y2] + + def rotate_point(self, angle, center, landmark): + rad = angle * np.pi / 180.0 + alpha = np.cos(rad) + beta = np.sin(rad) + M = np.zeros((2, 3), dtype=np.float32) + M[0, 0] = alpha + M[0, 1] = beta + M[0, 2] = (1 - alpha) * center[0] - beta * center[1] + M[1, 0] = -beta + M[1, 1] = alpha + M[1, 2] = beta * center[0] + (1 - alpha) * center[1] + + landmark_ = np.asarray([(M[0, 0] * x + M[0, 1] * y + M[0, 2], + M[1, 0] * x + M[1, 1] * y + M[1, 2]) + for (x, y) in landmark]) + return M, landmark_ + + def rotate_crop_img(self, img, pts, M): + imgT = cv2.warpAffine(img, M, (int(img.shape[1]), int(img.shape[0]))) + + x1 = pts[5][0] + x2 = pts[5][0] + y1 = pts[5][1] + y2 = pts[5][1] + for i in range(0, 9): + x1 = min(x1, pts[i][0]) + x2 = max(x2, pts[i][0]) + y1 = min(y1, pts[i][1]) + y2 = max(y2, pts[i][1]) + + height, width, _ = imgT.shape + x1 = min(max(0, int(x1)), width) + y1 = min(max(0, int(y1)), height) + x2 = min(max(0, int(x2)), width) + y2 = min(max(0, int(y2)), height) + sub_imgT = imgT[y1:y2, x1:x2] + + return sub_imgT, imgT, [x1, y1, x2, y2] + + def crop_img(self, imgT, pts): + enlarge_ratio = 1.1 + + x1 = np.min(pts[:, 0]) + x2 = np.max(pts[:, 0]) + y1 = np.min(pts[:, 1]) + y2 = np.max(pts[:, 1]) + w = x2 - x1 + 1 + h = y2 - y1 + 1 + x1 = int(x1 - (enlarge_ratio - 1.0) / 2.0 * w) + y1 = int(y1 - (enlarge_ratio - 1.0) / 2.0 * h) + x1 = max(0, x1) + y1 = max(0, y1) + + new_w = int(enlarge_ratio * w) + new_h = int(enlarge_ratio * h) + new_x1 = x1 + new_y1 = y1 + new_x2 = new_x1 + new_w + new_y2 = new_y1 + new_h + + height, width, _ = imgT.shape + + new_x1 = min(max(0, new_x1), width) + new_y1 = min(max(0, new_y1), height) + new_x2 = max(min(width, new_x2), 0) + new_y2 = max(min(height, new_y2), 0) + + sub_imgT = imgT[new_y1:new_y2, new_x1:new_x2] + + return sub_imgT, [new_x1, new_y1, new_x2, new_y2] + + def __call__(self, inputs) -> Any: + img = LoadImage.convert_to_ndarray(inputs) + h, w, c = img.shape + img_rgb = copy.deepcopy(img) + img_rgb = img_rgb[:, :, ::-1] + det_result = self.face_detection(img_rgb) + + bboxes = np.array(det_result[OutputKeys.BOXES]) + if bboxes.shape[0] == 0: + logger.warn('No face detected!') + results = { + OutputKeys.KEYPOINTS: [], + OutputKeys.POSES: [], + OutputKeys.BOXES: [] + } + return results + + boxes, keypoints = self._choose_face(det_result) + + output_boxes = [] + output_keypoints = [] + output_poses = [] + for index, box_ori in enumerate(boxes): + box = self.expend_box(box_ori, w, h, scalex=0.1, scaley=0.1) + y0 = int(box[1]) + y1 = int(box[3]) + x0 = int(box[0]) + x1 = int(box[2]) + sub_img = img[y0:y1, x0:x1] + + keypoint = keypoints[index] + pts = [[keypoint[0], keypoint[1]], [keypoint[2], keypoint[3]], + [keypoint[4], keypoint[5]], [keypoint[6], keypoint[7]], + [keypoint[8], keypoint[9]], [box[0], box[1]], + [box[2], box[1]], [box[0], box[3]], [box[2], box[3]]] + # radian + angle = math.atan2((pts[1][1] - pts[0][1]), + (pts[1][0] - pts[0][0])) + # angle + theta = angle * (180 / np.pi) + + center = [w // 2, h // 2] + cx, cy = center + M, landmark_ = self.rotate_point(theta, (cx, cy), pts) + sub_imgT, imgT, bbox = self.rotate_crop_img(img, landmark_, M) + + outputs = self.predict_op([sub_imgT])[0] + tmp_keypoints = outputs['point'] + + for idx in range(0, len(tmp_keypoints)): + tmp_keypoints[idx][0] += bbox[0] + tmp_keypoints[idx][1] += bbox[1] + + for idx in range(0, 6): + sub_img, bbox = self.crop_img(imgT, tmp_keypoints) + outputs = self.predict_op([sub_img])[0] + tmp_keypoints = outputs['point'] + for idx in range(0, len(tmp_keypoints)): + tmp_keypoints[idx][0] += bbox[0] + tmp_keypoints[idx][1] += bbox[1] + + M2, tmp_keypoints = self.rotate_point(-theta, (cx, cy), + tmp_keypoints) + + output_keypoints.append(np.array(tmp_keypoints)) + output_poses.append(np.array(outputs['pose'])) + output_boxes.append(np.array(box_ori)) + + results = { + OutputKeys.KEYPOINTS: output_keypoints, + OutputKeys.POSES: output_poses, + OutputKeys.BOXES: output_boxes + } + + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py new file mode 100644 index 00000000..936accbf --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path +from typing import Any + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from .base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.human_wholebody_keypoint, + module_name=Pipelines.human_wholebody_keypoint) +class HumanWholebodyKeypointsPipeline(EasyCVPipeline): + """Pipeline for human wholebody 2d keypoints detection.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + self.model_dir = model + super(HumanWholebodyKeypointsPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def _build_predict_op(self, **kwargs): + """Build EasyCV predictor.""" + from easycv.predictors.builder import build_predictor + detection_predictor_type = self.cfg['DETECTION']['type'] + detection_model_path = os.path.join( + self.model_dir, self.cfg['DETECTION']['model_path']) + detection_cfg_file = os.path.join(self.model_dir, + self.cfg['DETECTION']['config_file']) + detection_score_threshold = self.cfg['DETECTION']['score_threshold'] + self.cfg.pipeline.predictor_config[ + 'detection_predictor_config'] = dict( + type=detection_predictor_type, + model_path=detection_model_path, + config_file=detection_cfg_file, + score_threshold=detection_score_threshold) + easycv_config = self._to_easycv_config() + pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { + 'model_path': self.model_path, + 'config_file': easycv_config, + **kwargs + }) + return pipeline_op + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + results = [{ + OutputKeys.KEYPOINTS: output['keypoints'], + OutputKeys.BOXES: output['boxes'] + } for output in outputs] + + if self._is_single_inputs(inputs): + results = results[0] + + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py new file mode 100644 index 00000000..bd09fc9b --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/segmentation_pipeline.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from .base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.image_segmentation, module_name=Pipelines.easycv_segmentation) +class EasyCVSegmentationPipeline(EasyCVPipeline): + """Pipeline for easycv segmentation task.""" + + def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + + super(EasyCVSegmentationPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + semantic_result = outputs[0]['seg_pred'] + + ids = np.unique(semantic_result)[::-1] + legal_indices = ids != len(self.predict_op.CLASSES) # for VOID label + ids = ids[legal_indices] + segms = (semantic_result[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.predict_op.CLASSES)[ids].tolist() + + results = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return results diff --git a/modelscope/pipelines/cv/face_detection_pipeline.py b/modelscope/pipelines/cv/face_detection_pipeline.py new file mode 100644 index 00000000..608567a4 --- /dev/null +++ b/modelscope/pipelines/cv/face_detection_pipeline.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import ScrfdDetect +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.face_detection) +class FaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + detector = ScrfdDetect(model_dir=model, **kwargs) + self.detector = detector + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + pre_pipeline = [ + dict( + type='MultiScaleFlipAug', + img_scale=(640, 640), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.0), + dict( + type='Normalize', + mean=[127.5, 127.5, 127.5], + std=[128.0, 128.0, 128.0], + to_rgb=False), + dict(type='Pad', size=(640, 640), pad_val=0), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ] + from mmdet.datasets.pipelines import Compose + pipeline = Compose(pre_pipeline) + result = {} + result['filename'] = '' + result['ori_filename'] = '' + result['img'] = img + result['img_shape'] = img.shape + result['ori_shape'] = img.shape + result['img_fields'] = ['img'] + result = pipeline(result) + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.detector(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_emotion_pipeline.py b/modelscope/pipelines/cv/face_emotion_pipeline.py new file mode 100644 index 00000000..9d9aa6ee --- /dev/null +++ b/modelscope/pipelines/cv/face_emotion_pipeline.py @@ -0,0 +1,43 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_emotion import emotion_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_emotion, module_name=Pipelines.face_emotion) +class FaceEmotionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create face emotion pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + self.face_model = model + '/' + ModelFile.TF_GRAPH_FILE + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input['img_path']) + return img + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result, bbox = emotion_infer.inference(input, self.model, + self.face_model) + return {OutputKeys.OUTPUT: result, OutputKeys.BOXES: bbox} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py new file mode 100644 index 00000000..d41a14dd --- /dev/null +++ b/modelscope/pipelines/cv/face_human_hand_detection_pipeline.py @@ -0,0 +1,50 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_human_hand_detection import det_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_human_hand_detection, + module_name=Pipelines.face_human_hand_detection) +class NanoDettForFaceHumanHandDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create face-human-hand detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input['input_path']) + return img + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + cls_list, bbox_list, score_list = det_infer.inference( + self.model, self.device, input) + logger.info(cls_list, bbox_list, score_list) + return { + OutputKeys.LABELS: cls_list, + OutputKeys.BOXES: bbox_list, + OutputKeys.SCORES: score_list + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_image_generation_pipeline.py b/modelscope/pipelines/cv/face_image_generation_pipeline.py new file mode 100644 index 00000000..1b4e2e8a --- /dev/null +++ b/modelscope/pipelines/cv/face_image_generation_pipeline.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_generation import Generator +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_image_generation, module_name=Pipelines.face_image_generation) +class FaceImageGenerationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face image generation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + self.size = 1024 + self.latent = 512 + self.n_mlp = 8 + self.channel_multiplier = 2 + self.truncation = 0.7 + self.truncation_mean = 4096 + self.generator = Generator( + self.size, + self.latent, + self.n_mlp, + channel_multiplier=self.channel_multiplier).to(self.device) + + self.model_file = f'{model}/{ModelFile.TORCH_MODEL_FILE}' + + self.generator.load_state_dict(torch.load(self.model_file)['g_ema']) + logger.info('load model done') + + self.mean_latent = None + if self.truncation < 1: + with torch.no_grad(): + self.mean_latent = self.generator.mean_latent( + self.truncation_mean) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(input, str): + input = int(input) + assert isinstance(input, int) + torch.manual_seed(input) + torch.cuda.manual_seed(input) + torch.cuda.manual_seed_all(input) + self.generator.eval() + with torch.no_grad(): + sample_z = torch.randn(1, self.latent).to(self.device) + + sample, _ = self.generator([sample_z], + truncation=self.truncation, + truncation_latent=self.mean_latent) + + sample = sample * 0.5 + 0.5 + sample = sample.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + sample = np.clip(sample.float().cpu().numpy(), 0, 1) * 255 + + return {OutputKeys.OUTPUT_IMG: sample.astype(np.uint8)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/face_recognition_pipeline.py b/modelscope/pipelines/cv/face_recognition_pipeline.py new file mode 100644 index 00000000..873e4a1f --- /dev/null +++ b/modelscope/pipelines/cv/face_recognition_pipeline.py @@ -0,0 +1,132 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_recognition.align_face import align_face +from modelscope.models.cv.face_recognition.torchkit.backbone import get_model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_recognition, module_name=Pipelines.face_recognition) +class FaceRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + # face recong model + super().__init__(model=model, **kwargs) + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + self.device = device + face_model = get_model('IR_101')([112, 112]) + face_model.load_state_dict( + torch.load( + osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device)) + face_model = face_model.to(device) + face_model.eval() + self.face_model = face_model + logger.info('face recognition model loaded!') + # face detect pipeline + det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + self.face_detection = pipeline( + Tasks.face_detection, model=det_model_id) + + def _choose_face(self, + det_result, + min_face=10, + top_face=1, + center_face=False): + ''' + choose face with maximum area + Args: + det_result: output of face detection pipeline + min_face: minimum size of valid face w/h + top_face: take faces with top max areas + center_face: choose the most centerd face from multi faces, only valid if top_face > 1 + ''' + bboxes = np.array(det_result[OutputKeys.BOXES]) + landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) + # scores = np.array(det_result[OutputKeys.SCORES]) + if bboxes.shape[0] == 0: + logger.info('No face detected!') + return None + # face idx with enough size + face_idx = [] + for i in range(bboxes.shape[0]): + box = bboxes[i] + if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: + face_idx += [i] + if len(face_idx) == 0: + logger.info( + f'Face size not enough, less than {min_face}x{min_face}!') + return None + bboxes = bboxes[face_idx] + landmarks = landmarks[face_idx] + # find max faces + boxes = np.array(bboxes) + area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + sort_idx = np.argsort(area)[-top_face:] + # find center face + if top_face > 1 and center_face and bboxes.shape[0] > 1: + img_center = [img.shape[1] // 2, img.shape[0] // 2] + min_dist = float('inf') + sel_idx = -1 + for _idx in sort_idx: + box = boxes[_idx] + dist = np.square( + np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( + np.abs((box[1] + box[3]) / 2 - img_center[1])) + if dist < min_dist: + min_dist = dist + sel_idx = _idx + sort_idx = [sel_idx] + main_idx = sort_idx[-1] + return bboxes[main_idx], landmarks[main_idx] + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img[:, :, ::-1] + det_result = self.face_detection(img.copy()) + rtn = self._choose_face(det_result) + face_img = None + if rtn is not None: + _, face_lmks = rtn + face_lmks = face_lmks.reshape(5, 2) + align_img, _ = align_face(img, (112, 112), face_lmks) + face_img = align_img[:, :, ::-1] # to rgb + face_img = np.transpose(face_img, axes=(2, 0, 1)) + face_img = (face_img / 255. - 0.5) / 0.5 + face_img = face_img.astype(np.float32) + result = {} + result['img'] = face_img + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + assert input['img'] is not None + img = input['img'].unsqueeze(0) + emb = self.face_model(img).detach().cpu().numpy() + emb /= np.sqrt(np.sum(emb**2, -1, keepdims=True)) # l2 norm + return {OutputKeys.IMG_EMBEDDING: emb} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py new file mode 100644 index 00000000..3c85ae62 --- /dev/null +++ b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_recognition.align_face import align_face +from modelscope.models.cv.facial_expression_recognition import \ + FacialExpressionRecognition +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.facial_expression_recognition, + module_name=Pipelines.facial_expression_recognition) +class FacialExpressionRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + fer = FacialExpressionRecognition(model_path=ckpt_path, device=device) + self.fer = fer + self.device = device + logger.info('load model done') + + # face detect pipeline + det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + self.map_list = [ + 'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral' + ] + self.face_detection = pipeline( + Tasks.face_detection, model=det_model_id) + + def _choose_face(self, + det_result, + min_face=10, + top_face=1, + center_face=False): + ''' + choose face with maximum area + Args: + det_result: output of face detection pipeline + min_face: minimum size of valid face w/h + top_face: take faces with top max areas + center_face: choose the most centerd face from multi faces, only valid if top_face > 1 + ''' + bboxes = np.array(det_result[OutputKeys.BOXES]) + landmarks = np.array(det_result[OutputKeys.KEYPOINTS]) + if bboxes.shape[0] == 0: + logger.info('Warning: No face detected!') + return None + # face idx with enough size + face_idx = [] + for i in range(bboxes.shape[0]): + box = bboxes[i] + if (box[2] - box[0]) >= min_face and (box[3] - box[1]) >= min_face: + face_idx += [i] + if len(face_idx) == 0: + logger.info( + f'Warning: Face size not enough, less than {min_face}x{min_face}!' + ) + return None + bboxes = bboxes[face_idx] + landmarks = landmarks[face_idx] + # find max faces + boxes = np.array(bboxes) + area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + sort_idx = np.argsort(area)[-top_face:] + # find center face + if top_face > 1 and center_face and bboxes.shape[0] > 1: + img_center = [img.shape[1] // 2, img.shape[0] // 2] + min_dist = float('inf') + sel_idx = -1 + for _idx in sort_idx: + box = boxes[_idx] + dist = np.square( + np.abs((box[0] + box[2]) / 2 - img_center[0])) + np.square( + np.abs((box[1] + box[3]) / 2 - img_center[1])) + if dist < min_dist: + min_dist = dist + sel_idx = _idx + sort_idx = [sel_idx] + main_idx = sort_idx[-1] + return bboxes[main_idx], landmarks[main_idx] + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img[:, :, ::-1] + det_result = self.face_detection(img.copy()) + rtn = self._choose_face(det_result) + face_img = None + if rtn is not None: + _, face_lmks = rtn + face_lmks = face_lmks.reshape(5, 2) + face_img, _ = align_face(img, (112, 112), face_lmks) + face_img = face_img.astype(np.float32) + result = {} + result['img'] = face_img + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.fer(input) + assert result is not None + scores = result[0].tolist() + return {OutputKeys.SCORES: scores, OutputKeys.LABELS: self.map_list} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/general_recognition_pipeline.py b/modelscope/pipelines/cv/general_recognition_pipeline.py new file mode 100644 index 00000000..80f6f88a --- /dev/null +++ b/modelscope/pipelines/cv/general_recognition_pipeline.py @@ -0,0 +1,121 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Pipelines +from modelscope.models.cv.animal_recognition import resnet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage, load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.general_recognition, module_name=Pipelines.general_recognition) +class GeneralRecognitionPipeline(Pipeline): + + def __init__(self, model: str, device: str): + """ + use `model` to create a general recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + import torch + + def resnest101(**kwargs): + model = resnet.ResNet( + resnet.Bottleneck, [3, 4, 23, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + return model + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' not in list(own_state.keys())[0]: + name = name.replace('.module.', '.') + if (name in own_state) and (own_state[name].shape + == param.shape): + own_state[name].copy_(param) + copied_keys.append(name) + + def load_pretrained(model, src_params): + if 'state_dict' in src_params: + src_params = src_params['state_dict'] + own_state = model.state_dict() + filter_param(src_params, own_state) + model.load_state_dict(own_state) + + self.model = resnest101(num_classes=54092) + local_model_dir = model + device = 'cpu' + if osp.exists(model): + local_model_dir = model + else: + local_model_dir = snapshot_download(model) + self.local_path = local_model_dir + src_params = torch.load( + osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), device) + load_pretrained(self.model, src_params) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), normalize + ]) + img = transform(img) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + def set_phase(model, is_train): + if is_train: + model.train() + else: + model.eval() + + is_train = False + set_phase(self.model, is_train) + img = input['img'] + input_img = torch.unsqueeze(img, 0) + outputs = self.model(input_img) + return {'outputs': outputs} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + label_mapping_path = osp.join(self.local_path, 'meta_info.txt') + with open(label_mapping_path, 'r') as f: + label_mapping = f.readlines() + score = torch.max(inputs['outputs']) + inputs = { + OutputKeys.SCORES: [score.item()], + OutputKeys.LABELS: + [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] + } + return inputs diff --git a/modelscope/pipelines/cv/hand_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/hand_2d_keypoints_pipeline.py new file mode 100644 index 00000000..bad0c652 --- /dev/null +++ b/modelscope/pipelines/cv/hand_2d_keypoints_pipeline.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from .easycv_pipelines.base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.hand_2d_keypoints, module_name=Pipelines.hand_2d_keypoints) +class Hand2DKeypointsPipeline(EasyCVPipeline): + """Pipeline for hand pose keypoint task.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + self.model_dir = model + super(Hand2DKeypointsPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def _build_predict_op(self, **kwargs): + """Build EasyCV predictor.""" + from easycv.predictors.builder import build_predictor + detection_predictor_type = self.cfg['DETECTION']['type'] + detection_model_path = os.path.join( + self.model_dir, self.cfg['DETECTION']['model_path']) + detection_cfg_file = os.path.join(self.model_dir, + self.cfg['DETECTION']['config_file']) + detection_score_threshold = self.cfg['DETECTION']['score_threshold'] + self.cfg.pipeline.predictor_config[ + 'detection_predictor_config'] = dict( + type=detection_predictor_type, + model_path=detection_model_path, + config_file=detection_cfg_file, + score_threshold=detection_score_threshold) + easycv_config = self._to_easycv_config() + pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { + 'model_path': self.model_path, + 'config_file': easycv_config, + **kwargs + }) + return pipeline_op diff --git a/modelscope/pipelines/cv/hand_static_pipeline.py b/modelscope/pipelines/cv/hand_static_pipeline.py new file mode 100644 index 00000000..c020b7aa --- /dev/null +++ b/modelscope/pipelines/cv/hand_static_pipeline.py @@ -0,0 +1,41 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.hand_static import hand_model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.hand_static, module_name=Pipelines.hand_static) +class HandStaticPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create hand static pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input['img_path']) + return img + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = hand_model.infer(input, self.model, self.device) + return {OutputKeys.OUTPUT: result} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py new file mode 100644 index 00000000..21af2f75 --- /dev/null +++ b/modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.action_recognition import BaseVideoModel +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import ReadVideoData +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_embedding, module_name=Pipelines.hicossl_video_embedding) +class HICOSSLVideoEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a hicossl video embedding pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) + self.infer_model.eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)['model_state'], + strict=False) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + video_input_data = ReadVideoData( + self.cfg, input, num_temporal_views_override=1).to(self.device) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + feature = self.perform_inference(input['video_data']) + return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()} + + @torch.no_grad() + def perform_inference(self, data, max_bsz=4): + """ Perform feature extracting for a given video + Args: + model (BaseVideoModel): video model with loadded state dict. + max_bsz (int): the maximum batch size, limited by GPU memory. + Returns: + pred (Tensor): the extracted features for input video clips. + """ + iter_num = math.ceil(data.size(0) / max_bsz) + preds_list = [] + for i in range(iter_num): + preds_list.append( + self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0]) + pred = torch.cat(preds_list, dim=0) + return pred + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_body_reshaping_pipeline.py b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py new file mode 100644 index 00000000..c3600eb5 --- /dev/null +++ b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping) +class ImageBodyReshapingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image body reshaping pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + logger.info('body reshaping model init done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = self.model.inference(input['img']) + result = {'outputs': output} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_img = inputs['outputs'] + return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py new file mode 100644 index 00000000..8606915c --- /dev/null +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -0,0 +1,148 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import tensorflow as tf + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.cartoon import (FaceAna, get_f5p, + get_reference_facial_points, + padTo16x, resize_size, + warp_and_crop_face) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from ...utils.device import device_placement + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_portrait_stylization, + module_name=Pipelines.person_image_cartoon) +class ImageCartoonPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image cartoon pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.facer = FaceAna(self.model) + with tf.Graph().as_default(): + self.sess_anime_head = self.load_sess( + os.path.join(self.model, 'cartoon_h.pb'), 'model_anime_head') + self.sess_anime_bg = self.load_sess( + os.path.join(self.model, 'cartoon_bg.pb'), 'model_anime_bg') + + self.box_width = 288 + global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) + global_mask = cv2.resize( + global_mask, (self.box_width, self.box_width), + interpolation=cv2.INTER_AREA) + self.global_mask = cv2.cvtColor( + global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0 + + def load_sess(self, model_path, name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + sess = tf.Session(config=config) + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + sess.graph.as_default() + tf.import_graph_def(graph_def, name=name) + sess.run(tf.global_variables_initializer()) + logger.info(f'load model {model_path} done.') + return sess + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float) + result = {'img': img} + return result + + def detect_face(self, img): + src_h, src_w, _ = img.shape + boxes, landmarks, _ = self.facer.run(img) + if boxes.shape[0] == 0: + return None + else: + return landmarks + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + img = input['img'].astype(np.uint8) + ori_h, ori_w, _ = img.shape + img = resize_size(img, size=720) + + img_brg = img[:, :, ::-1] + + # background process + pad_bg, pad_h, pad_w = padTo16x(img_brg) + + bg_res = self.sess_anime_bg.run( + self.sess_anime_bg.graph.get_tensor_by_name( + 'model_anime_bg/output_image:0'), + feed_dict={'model_anime_bg/input_image:0': pad_bg}) + res = bg_res[:pad_h, :pad_w, :] + + landmarks = self.detect_face(img) + if landmarks is None: + print('No face detected!') + return {OutputKeys.OUTPUT_IMG: res} + + for landmark in landmarks: + # get facial 5 points + f5p = get_f5p(landmark, img_brg) + + # face alignment + head_img, trans_inv = warp_and_crop_face( + img, + f5p, + ratio=0.75, + reference_pts=get_reference_facial_points(default_square=True), + crop_size=(self.box_width, self.box_width), + return_trans_inv=True) + + # head process + head_res = self.sess_anime_head.run( + self.sess_anime_head.graph.get_tensor_by_name( + 'model_anime_head/output_image:0'), + feed_dict={ + 'model_anime_head/input_image:0': head_img[:, :, ::-1] + }) + + # merge head and background + head_trans_inv = cv2.warpAffine( + head_res, + trans_inv, (np.size(img, 1), np.size(img, 0)), + borderValue=(0, 0, 0)) + + mask = self.global_mask + mask_trans_inv = cv2.warpAffine( + mask, + trans_inv, (np.size(img, 1), np.size(img, 0)), + borderValue=(0, 0, 0)) + mask_trans_inv = np.expand_dims(mask_trans_inv, 2) + + res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res + + res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA) + + return {OutputKeys.OUTPUT_IMG: res} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py new file mode 100644 index 00000000..69dbd1fb --- /dev/null +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -0,0 +1,120 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image +from modelscope.utils.constant import Tasks +from modelscope.utils.device import get_device +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_classification, module_name=Pipelines.image_classification) +class ImageClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + pipe_model.to(get_device()) + if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +@PIPELINES.register_module( + Tasks.image_classification, + module_name=Pipelines.general_image_classification) +@PIPELINES.register_module( + Tasks.image_classification, + module_name=Pipelines.daily_image_classification) +class GeneralImageClassificationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` and `preprocessor` to create a image classification pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmcls.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] # in rgb order + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + mmcls_cfg = self.model.cfg + # build the data pipeline + if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': + mmcls_cfg.data.test.pipeline.pop(0) + data = dict(img=img) + test_pipeline = Compose(mmcls_cfg.data.test.pipeline) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + with torch.no_grad(): + input['return_loss'] = False + scores = self.model(input) + + return {'scores': scores} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + scores = inputs['scores'] + + pred_scores = np.sort(scores, axis=1)[0][::-1][:5] + pred_labels = np.argsort(scores, axis=1)[0][::-1][:5] + + result = {'pred_score': [score for score in pred_scores]} + result['pred_class'] = [ + self.model.CLASSES[lable] for lable in pred_labels + ] + + outputs = { + OutputKeys.SCORES: result['pred_score'], + OutputKeys.LABELS: result['pred_class'] + } + return outputs diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py new file mode 100644 index 00000000..3a4cf8bc --- /dev/null +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.cv.image_color_enhance import ImageColorEnhance +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, + LoadImage) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_color_enhancement, module_name=Pipelines.image_color_enhance) +class ImageColorEnhancePipeline(Pipeline): + + def __init__(self, + model: Union[ImageColorEnhance, str], + preprocessor: Optional[ + ImageColorEnhanceFinetunePreprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a image color enhance pipeline for prediction + Args: + model: model id on modelscope hub. + """ + model = model if isinstance( + model, ImageColorEnhance) else Model.from_pretrained(model) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + test_transforms = transforms.Compose([transforms.ToTensor()]) + img = test_transforms(img) + result = {'src': img.unsqueeze(0).to(self._device)} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return super().forward(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_img = (inputs['outputs'].squeeze(0) * 255.).type( + torch.uint8).cpu().permute(1, 2, 0).numpy()[:, :, ::-1] + return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py new file mode 100644 index 00000000..cd385024 --- /dev/null +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -0,0 +1,126 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +from torchvision import models, transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_colorization import (DynamicUnetDeep, + DynamicUnetWide, NormType) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_colorization, module_name=Pipelines.image_colorization) +class ImageColorizationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image colorization pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.cut = 8 + self.size = 512 + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + self.orig_img = None + self.model_type = 'stable' + self.norm = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.denorm = transforms.Normalize( + mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], + std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) + + if self.model_type == 'stable': + body = models.resnet101(pretrained=True) + body = torch.nn.Sequential(*list(body.children())[:self.cut]) + self.model = DynamicUnetWide( + body, + n_classes=3, + blur=True, + blur_final=True, + self_attention=True, + y_range=(-3.0, 3.0), + norm_type=NormType.Spectral, + last_cross=True, + bottle=False, + nf_factor=2, + ).to(self.device) + else: + body = models.resnet34(pretrained=True) + body = torch.nn.Sequential(*list(body.children())[:cut]) + self.model = DynamicUnetDeep( + body, + n_classes=3, + blur=True, + blur_final=True, + self_attention=True, + y_range=(-3.0, 3.0), + norm_type=NormType.Spectral, + last_cross=True, + bottle=False, + nf_factor=1.5, + ).to(self.device) + + model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' + self.model.load_state_dict( + torch.load(model_path, map_location=torch.device('cpu'))['model'], + strict=True) + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input).convert('LA').convert('RGB') + + self.wide, self.height = img.size + if self.wide * self.height < 100000: + self.size = 256 + self.orig_img = img.copy() + img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR) + + img = self.norm(img).unsqueeze(0).to(self.device) + result = {'img': img} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + self.model.eval() + with torch.no_grad(): + out = self.model(input['img'])[0] + + out = self.denorm(out) + out = out.float().clamp(min=0, max=1) + out_img = (out.permute(1, 2, 0).flip(2).cpu().numpy() * 255).astype( + np.uint8) + + if self.orig_img is not None: + color_np = cv2.resize(out_img, self.orig_img.size) + orig_np = np.asarray(self.orig_img) + color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV) + orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV) + hires = np.copy(orig_yuv) + hires[:, :, 1:3] = color_yuv[:, :, 1:3] + out_img = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR) + + return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_denoise_pipeline.py b/modelscope/pipelines/cv/image_denoise_pipeline.py new file mode 100644 index 00000000..34ac1e81 --- /dev/null +++ b/modelscope/pipelines/cv/image_denoise_pipeline.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.cv.image_denoise import NAFNetForImageDenoise +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ImageDenoisePipeline'] + + +@PIPELINES.register_module( + Tasks.image_denoising, module_name=Pipelines.image_denoise) +class ImageDenoisePipeline(Pipeline): + + def __init__(self, + model: Union[NAFNetForImageDenoise, str], + preprocessor: Optional[ImageDenoisePreprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a cv image denoise pipeline for prediction + Args: + model: model id on modelscope hub. + """ + model = model if isinstance( + model, NAFNetForImageDenoise) else Model.from_pretrained(model) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.config = model.config + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = model + logger.info('load image denoise model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + test_transforms = transforms.Compose([transforms.ToTensor()]) + img = test_transforms(img) + result = {'img': img.unsqueeze(0).to(self._device)} + return result + + def crop_process(self, input): + output = torch.zeros_like(input) # [1, C, H, W] + # determine crop_h and crop_w + ih, iw = input.shape[-2:] + crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1) + overlap = 16 + + step_h, step_w = ih // crop_rows, iw // crop_cols + for y in range(crop_rows): + for x in range(crop_cols): + crop_y = step_h * y + crop_x = step_w * x + + crop_h = step_h if y < crop_rows - 1 else ih - crop_y + crop_w = step_w if x < crop_cols - 1 else iw - crop_x + + crop_frames = input[:, :, + max(0, crop_y - overlap + ):min(crop_y + crop_h + overlap, ih), + max(0, crop_x - overlap + ):min(crop_x + crop_w + + overlap, iw)].contiguous() + h_start = overlap if max(0, crop_y - overlap) > 0 else 0 + w_start = overlap if max(0, crop_x - overlap) > 0 else 0 + h_end = h_start + crop_h if min(crop_y + crop_h + + overlap, ih) < ih else ih + w_end = w_start + crop_w if min(crop_x + crop_w + + overlap, iw) < iw else iw + + output[:, :, crop_y:crop_y + crop_h, + crop_x:crop_x + crop_w] = self.model._inference_forward( + crop_frames)['outputs'][:, :, h_start:h_end, + w_start:w_end] + return output + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + def set_phase(model, is_train): + if is_train: + model.train() + else: + model.eval() + + is_train = False + set_phase(self.model, is_train) + with torch.no_grad(): + output = self.crop_process(input['img']) # output Tensor + + return {'output_tensor': output} + + def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: + output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute( + 1, 2, 0).numpy().astype('uint8') + return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]} diff --git a/modelscope/pipelines/cv/image_detection_pipeline.py b/modelscope/pipelines/cv/image_detection_pipeline.py new file mode 100644 index 00000000..08633c35 --- /dev/null +++ b/modelscope/pipelines/cv/image_detection_pipeline.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + + +@PIPELINES.register_module( + Tasks.human_detection, module_name=Pipelines.human_detection) +@PIPELINES.register_module( + Tasks.image_object_detection, module_name=Pipelines.object_detection) +class ImageDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float) + img = self.model.preprocess(img) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = {'data': outputs} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + bboxes, scores, labels = self.model.postprocess(inputs['data']) + if bboxes is None: + outputs = { + OutputKeys.SCORES: [], + OutputKeys.LABELS: [], + OutputKeys.BOXES: [] + } + return outputs + outputs = { + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: bboxes + } + return outputs diff --git a/modelscope/pipelines/cv/image_inpainting_pipeline.py b/modelscope/pipelines/cv/image_inpainting_pipeline.py new file mode 100644 index 00000000..aff9788d --- /dev/null +++ b/modelscope/pipelines/cv/image_inpainting_pipeline.py @@ -0,0 +1,147 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn as nn +from torch.utils.data._utils.collate import default_collate + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_inpainting import FFTInpainting +from modelscope.models.cv.image_inpainting.refinement import refine_predict +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_inpainting, module_name=Pipelines.image_inpainting) +class ImageInpaintingPipeline(Pipeline): + + def __init__(self, + model: str, + pad_out_to_modulo=8, + refine=False, + **kwargs): + """ + model: model id on modelscope hub. + """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + self.refine = refine + logger.info(f'loading model from dir {model}') + self.infer_model = FFTInpainting(model, predict_only=True) + if not self.refine: + self.infer_model.to(self.device) + self.infer_model.eval() + logger.info(f'loading model done, refinement is set to {self.refine}') + self.pad_out_to_modulo = pad_out_to_modulo + + def move_to_device(self, obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [self.move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return { + name: self.move_to_device(val, device) + for name, val in obj.items() + } + raise ValueError(f'Unexpected type {type(obj)}') + + def transforms(self, img): + if img.ndim == 3: + img = np.transpose(img, (2, 0, 1)) + out_img = img.astype('float32') / 255 + return out_img + + def ceil_modulo(self, x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + def pad_img_to_modulo(self, img, mod): + channels, height, width = img.shape + out_height = self.ceil_modulo(height, mod) + out_width = self.ceil_modulo(width, mod) + return np.pad( + img, ((0, 0), (0, out_height - height), (0, out_width - width)), + mode='symmetric') + + def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(input['img'], str): + image_name, mask_name = input['img'], input['mask'] + img = LoadImage.convert_to_ndarray(image_name) + img = self.transforms(img) + mask = np.array(LoadImage(mode='L')(mask_name)['img']) + mask = self.transforms(mask) + elif isinstance(input['img'], PIL.Image.Image): + img = input['img'] + img = self.transforms(np.array(img)) + mask = input['mask'].convert('L') + mask = self.transforms(np.array(mask)) + else: + raise TypeError( + 'input should be either str or PIL.Image, and both inputs should have the same type' + ) + result = dict(image=img, mask=mask[None, ...]) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['unpad_to_size'] = result['image'].shape[1:] + result['image'] = self.pad_img_to_modulo(result['image'], + self.pad_out_to_modulo) + result['mask'] = self.pad_img_to_modulo(result['mask'], + self.pad_out_to_modulo) + + # Since Pipeline use default torch.no_grad() for performing forward func. + # We conduct inference here in case of doing training for refinement. + result = self.perform_inference(result) + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.OUTPUT_IMG: input} + + def perform_inference(self, data): + batch = default_collate([data]) + if self.refine: + assert 'unpad_to_size' in batch, 'Unpadded size is required for the refinement' + assert 'cuda' in str(self.device), 'GPU is required for refinement' + gpu_ids = str(self.device).split(':')[-1] + cur_res = refine_predict( + batch, + self.infer_model, + gpu_ids=gpu_ids, + modulo=self.pad_out_to_modulo, + n_iters=15, + lr=0.002, + min_side=512, + max_scales=3, + px_budget=900000) + cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() + else: + with torch.no_grad(): + batch = self.move_to_device(batch, self.device) + batch['mask'] = (batch['mask'] > 0) * 1 + batch = self.infer_model(batch) + cur_res = batch['inpainted'][0].permute( + 1, 2, 0).detach().cpu().numpy() + unpad_to_size = batch.get('unpad_to_size', None) + if unpad_to_size is not None: + orig_height, orig_width = unpad_to_size + cur_res = cur_res[:orig_height, :orig_width] + + cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + return cur_res + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py b/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py new file mode 100644 index 00000000..5a0f0d7e --- /dev/null +++ b/modelscope/pipelines/cv/image_instance_segmentation_pipeline.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, Optional, Union + +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_instance_segmentation import ( + CascadeMaskRCNNSwinModel, get_img_ins_seg_result) +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (ImageInstanceSegmentationPreprocessor, + build_preprocessor, load_image) +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_segmentation, + module_name=Pipelines.image_instance_segmentation) +class ImageInstanceSegmentationPipeline(Pipeline): + + def __init__(self, + model: Union[CascadeMaskRCNNSwinModel, str], + preprocessor: Optional[ + ImageInstanceSegmentationPreprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a image instance segmentation pipeline for prediction + + Args: + model (CascadeMaskRCNNSwinModel | str): a model instance + preprocessor (CascadeMaskRCNNSwinPreprocessor | None): a preprocessor instance + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + if preprocessor is None: + config_path = os.path.join(self.model.model_dir, + ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + self.preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) + else: + self.preprocessor = preprocessor + + self.preprocessor.eval() + self.model.eval() + + def _collate_fn(self, data): + # don't require collating + return data + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + filename = None + img = None + if isinstance(input, str): + filename = input + img = np.array(load_image(input)) + img = img[:, :, ::-1] # convert to bgr + elif isinstance(input, Image.Image): + img = np.array(input.convert('RGB')) + img = img[:, :, ::-1] # convert to bgr + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + result = { + 'img': img, + 'img_shape': img.shape, + 'ori_shape': img.shape, + 'img_fields': ['img'], + 'img_prefix': '', + 'img_info': { + 'filename': filename, + 'ann_file': None, + 'classes': None + }, + } + result = self.preprocessor(result) + + # stacked as a batch + result['img'] = torch.stack([result['img']], dim=0) + result['img_metas'] = [result['img_metas'].data] + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + output = self.model(input) + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + result = get_img_ins_seg_result( + img_seg_result=inputs['eval_result'][0], + class_names=self.model.model.classes) + return result diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py new file mode 100644 index 00000000..fb5d8f8b --- /dev/null +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.portrait_matting, module_name=Pipelines.portrait_matting) +class ImageMattingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image matting pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + import tensorflow as tf + if tf.__version__ >= '2.0': + tf = tf.compat.v1 + model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) + + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + self.output = self._session.graph.get_tensor_by_name( + 'output_png:0') + self.input_name = 'input_image:0' + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + feed_dict = {self.input_name: input['img']} + output_img = self._session.run(self.output, feed_dict=feed_dict) + output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA) + return {OutputKeys.OUTPUT_IMG: output_img} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py b/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py new file mode 100644 index 00000000..b96e709c --- /dev/null +++ b/modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_segmentation, + module_name=Pipelines.image_panoptic_segmentation) +class ImagePanopticSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image panoptic segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + logger.info('panoptic segmentation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmdet.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + from mmdet.datasets import replace_ImageToTensor + + cfg = self.model.cfg + # build the data pipeline + + if isinstance(input, str): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(load_image(input)) + img = img[:, :, ::-1] # convert to bgr + elif isinstance(input, PIL.Image.Image): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + img = input + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + # collect data + data = dict(img=img) + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + test_pipeline = Compose(cfg.data.test.pipeline) + + data = test_pipeline(data) + # copy from mmdet_model collect data + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + data['img'] = [img.data[0] for img in data['img']] + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + # bz=1, tcguo + pan_results = inputs[0]['pan_results'] + INSTANCE_OFFSET = 1000 + + ids = np.unique(pan_results)[::-1] + legal_indices = ids != self.model.num_classes # for VOID label + ids = ids[legal_indices] + labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) + segms = (pan_results[None] == ids[:, None, None]) + masks = [it.astype(np.int) for it in segms] + labels_txt = np.array(self.model.CLASSES)[labels].tolist() + + outputs = { + OutputKeys.MASKS: masks, + OutputKeys.LABELS: labels_txt, + OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] + } + return outputs diff --git a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py new file mode 100644 index 00000000..3eec6526 --- /dev/null +++ b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py @@ -0,0 +1,239 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from scipy.ndimage import gaussian_filter +from scipy.spatial.distance import pdist, squareform + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_portrait_enhancement import gpen +from modelscope.models.cv.image_portrait_enhancement.align_faces import ( + get_reference_facial_points, warp_and_crop_face) +from modelscope.models.cv.image_portrait_enhancement.eqface import fqa +from modelscope.models.cv.image_portrait_enhancement.retinaface import \ + detection +from modelscope.models.cv.super_resolution import rrdbnet_arch +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage, load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_portrait_enhancement, + module_name=Pipelines.image_portrait_enhancement) +class ImagePortraitEnhancementPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + self.use_sr = True + + self.size = 512 + self.n_mlp = 8 + self.channel_multiplier = 2 + self.narrow = 1 + self.face_enhancer = gpen.FullGenerator( + self.size, + 512, + self.n_mlp, + self.channel_multiplier, + narrow=self.narrow).to(self.device) + + gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' + self.face_enhancer.load_state_dict( + torch.load(gpen_model_path, map_location=torch.device('cpu')), + strict=True) + + logger.info('load face enhancer model done') + + self.threshold = 0.9 + detector_model_path = f'{model}/face_detection/RetinaFace-R50.pth' + self.face_detector = detection.RetinaFaceDetection( + detector_model_path, self.device) + + logger.info('load face detector model done') + + self.num_feat = 32 + self.num_block = 23 + self.scale = 2 + self.sr_model = rrdbnet_arch.RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=self.num_feat, + num_block=self.num_block, + num_grow_ch=32, + scale=self.scale).to(self.device) + + sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' + self.sr_model.load_state_dict( + torch.load(sr_model_path, + map_location=torch.device('cpu'))['params_ema'], + strict=True) + + logger.info('load sr model done') + + self.fqa_thres = 0.1 + self.id_thres = 0.15 + self.alpha = 1.0 + backbone_model_path = f'{model}/face_quality/eqface_backbone.pth' + fqa_model_path = f'{model}/face_quality/eqface_quality.pth' + self.eqface = fqa.FQA(backbone_model_path, fqa_model_path, self.device) + + logger.info('load fqa model done') + + # the mask for pasting restored faces back + self.mask = np.zeros((512, 512, 3), np.float32) + cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, + cv2.LINE_AA) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) + + def enhance_face(self, img): + img = cv2.resize(img, (self.size, self.size)) + img_t = self.img2tensor(img) + + self.face_enhancer.eval() + with torch.no_grad(): + out, __ = self.face_enhancer(img_t) + del img_t + + out = self.tensor2img(out) + + return out + + def img2tensor(self, img, is_norm=True): + img_t = torch.from_numpy(img).to(self.device) / 255. + if is_norm: + img_t = (img_t - 0.5) / 0.5 + img_t = img_t.permute(2, 0, 1).unsqueeze(0) + return img_t + + def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8): + if is_denorm: + img_t = img_t * 0.5 + 0.5 + img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax + + return img_np.astype(imtype) + + def sr_process(self, img): + img = img.astype(np.float32) / 255. + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + img = img.unsqueeze(0).to(self.device) + + if self.scale == 2: + mod_scale = 2 + elif self.scale == 1: + mod_scale = 4 + else: + mod_scale = None + if mod_scale is not None: + h_pad, w_pad = 0, 0 + _, _, h, w = img.size() + if (h % mod_scale != 0): + h_pad = (mod_scale - h % mod_scale) + if (w % mod_scale != 0): + w_pad = (mod_scale - w % mod_scale) + img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') + + self.sr_model.eval() + with torch.no_grad(): + output = self.sr_model(img) + del img + # remove extra pad + if mod_scale is not None: + _, _, h, w = output.size() + output = output[:, :, 0:h - h_pad, 0:w - w_pad] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + + return output + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + + img_sr = img + if self.use_sr: + img_sr = self.sr_process(img) + + img = cv2.resize(img, img_sr.shape[:2][::-1]) + + result = {'img': img, 'img_sr': img_sr} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + img, img_sr = input['img'], input['img_sr'] + img, img_sr = img.cpu().numpy(), img_sr.cpu().numpy() + facebs, landms = self.face_detector.detect(img) + + height, width = img.shape[:2] + full_mask = np.zeros(img.shape, dtype=np.float32) + full_img = np.zeros(img.shape, dtype=np.uint8) + + for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): + if faceb[4] < self.threshold: + continue + + facial5points = np.reshape(facial5points, (2, 5)) + + of, of_112, tfm_inv = warp_and_crop_face( + img, facial5points, crop_size=(self.size, self.size)) + + # detect orig face quality + fq_o, fea_o = self.eqface.get_face_quality(of_112) + if fq_o < self.fqa_thres: + continue + + # enhance the face + ef = self.enhance_face(of) + + # detect enhanced face quality + ss = self.size // 256 + ef_112 = cv2.resize(ef[35 * ss:-33 * ss, 32 * ss:-36 * ss], + (112, 112)) # crop roi + fq_e, fea_e = self.eqface.get_face_quality(ef_112) + dist = squareform(pdist([fea_o, fea_e], 'cosine')).mean() + if dist > self.id_thres: + continue + + tmp_mask = self.mask + tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) + tmp_mask = cv2.warpAffine( + tmp_mask, tfm_inv, (width, height), flags=3) + + tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) + + mask = np.clip(tmp_mask - full_mask, 0, 1) + full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] + full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] + + if self.use_sr and img_sr is not None: + out_img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + + full_img * full_mask) + else: + out_img = cv2.convertScaleAbs(img * (1 - full_mask) + + full_img * full_mask) + + return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_reid_person_pipeline.py b/modelscope/pipelines/cv/image_reid_person_pipeline.py new file mode 100644 index 00000000..9f60142a --- /dev/null +++ b/modelscope/pipelines/cv/image_reid_person_pipeline.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +from typing import Any, Dict + +import torch +import torchvision.transforms as T +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_reid_person, module_name=Pipelines.image_reid_person) +class ImageReidPersonPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model config from dir {model}') + + cfg_path = os.path.join(model, ModelFile.CONFIGURATION) + cfg = Config.from_file(cfg_path) + cfg = cfg.model.cfg + self.model = self.model.to(self.device) + self.model.eval() + + self.val_transforms = T.Compose([ + T.Resize(cfg.INPUT.SIZE_TEST), + T.ToTensor(), + T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + img = self.val_transforms(img) + img = img.unsqueeze(0) + img = img.to(self.device) + return {'img': img} + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + img = input['img'] + img_embedding = self.model(img) + img_embedding = img_embedding.detach().cpu().numpy() + return {OutputKeys.IMG_EMBEDDING: img_embedding} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py new file mode 100644 index 00000000..4a3eaa65 --- /dev/null +++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.semantic_segmentation, module_name=Pipelines.salient_detection) +class ImageSalientDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_ndarray(input) + img_h, img_w, _ = img.shape + img = self.model.preprocess(img) + result = {'img': img, 'img_w': img_w, 'img_h': img_h} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = { + 'data': outputs, + 'img_w': input['img_w'], + 'img_h': input['img_h'] + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + data = self.model.postprocess(inputs) + outputs = {OutputKeys.MASKS: data} + return outputs diff --git a/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py b/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py new file mode 100644 index 00000000..023d9712 --- /dev/null +++ b/modelscope/pipelines/cv/image_semantic_segmentation_pipeline.py @@ -0,0 +1,90 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_segmentation, + module_name=Pipelines.image_semantic_segmentation) +class ImageSemanticSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image semantic segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + logger.info('semantic segmentation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmdet.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + from mmdet.datasets import replace_ImageToTensor + + cfg = self.model.cfg + # build the data pipeline + + if isinstance(input, str): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(load_image(input)) + img = img[:, :, ::-1] # convert to bgr + elif isinstance(input, PIL.Image.Image): # BGR + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + img = np.array(input)[:, :, ::-1] + elif isinstance(input, np.ndarray): + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + img = input + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + # collect data + data = dict(img=img) + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + test_pipeline = Compose(cfg.data.test.pipeline) + + data = test_pipeline(data) + # copy from mmdet_model collect data + data = collate([data], samples_per_gpu=1) + data['img_metas'] = [ + img_metas.data[0] for img_metas in data['img_metas'] + ] + data['img'] = [img.data[0] for img in data['img']] + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.postprocess(inputs) + outputs = { + OutputKeys.MASKS: results[OutputKeys.MASKS], + OutputKeys.LABELS: results[OutputKeys.LABELS], + OutputKeys.SCORES: results[OutputKeys.SCORES] + } + + return outputs diff --git a/modelscope/pipelines/cv/image_style_transfer_pipeline.py b/modelscope/pipelines/cv/image_style_transfer_pipeline.py new file mode 100644 index 00000000..e5fd0d48 --- /dev/null +++ b/modelscope/pipelines/cv/image_style_transfer_pipeline.py @@ -0,0 +1,117 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_style_transfer, module_name=Pipelines.image_style_transfer) +class ImageStyleTransferPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a style transfer pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + import tensorflow as tf + if tf.__version__ >= '2.0': + tf = tf.compat.v1 + model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) + + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + self.max_length = 800 + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + self.content = tf.get_default_graph().get_tensor_by_name( + 'content:0') + self.style = tf.get_default_graph().get_tensor_by_name( + 'style:0') + self.output = tf.get_default_graph().get_tensor_by_name( + 'stylized_output:0') + self.attention = tf.get_default_graph().get_tensor_by_name( + 'attention_map:0') + self.inter_weight = tf.get_default_graph( + ).get_tensor_by_name('inter_weight:0') + self.centroids = tf.get_default_graph().get_tensor_by_name( + 'centroids:0') + logger.info('load model done') + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, {}, {} + + def preprocess(self, + content: Input, + style: Input = None) -> Dict[str, Any]: + if type(content) is dict: # for demo service + style = content['style'] + content = content['content'] + + content = LoadImage.convert_to_ndarray(content) + if len(content.shape) == 2: + content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) + content_img = content.astype(np.float) + + style_img = LoadImage.convert_to_ndarray(style) + if len(style_img.shape) == 2: + style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR) + style_img = style_img.astype(np.float) + + result = {'content': content_img, 'style': style_img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + content_feed, style_feed = input['content'], input['style'] + h = np.shape(content_feed)[0] + w = np.shape(content_feed)[1] + if h > self.max_length or w > self.max_length: + if h > w: + content_feed = cv2.resize( + content_feed, + (int(self.max_length * w / h), self.max_length)) + else: + content_feed = cv2.resize( + content_feed, + (self.max_length, int(self.max_length * h / w))) + + with self._session.as_default(): + feed_dict = { + self.content: content_feed, + self.style: style_feed, + self.inter_weight: 1.0 + } + output_img = self._session.run(self.output, feed_dict=feed_dict) + + # print('out_img shape:{}'.format(output_img.shape)) + output_img = cv2.cvtColor(output_img[0], cv2.COLOR_RGB2BGR) + output_img = np.clip(output_img, 0, 255).astype(np.uint8) + + output_img = cv2.resize(output_img, (w, h)) + + return {OutputKeys.OUTPUT_IMG: output_img} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py new file mode 100644 index 00000000..ca8f3209 --- /dev/null +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn.functional as F + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.super_resolution import RRDBNet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_super_resolution, module_name=Pipelines.image_super_resolution) +class ImageSuperResolutionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image super resolution pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + self.num_feat = 64 + self.num_block = 23 + self.scale = 4 + self.sr_model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=self.num_feat, + num_block=self.num_block, + num_grow_ch=32, + scale=self.scale).to(self.device) + + model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}' + self.sr_model.load_state_dict(torch.load(model_path), strict=True) + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = torch.from_numpy(img).to(self.device).permute( + 2, 0, 1).unsqueeze(0) / 255. + result = {'img': img} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + self.sr_model.eval() + + img = input['img'] + if self.scale == 2: + mod_scale = 2 + elif self.scale == 1: + mod_scale = 4 + else: + mod_scale = None + if mod_scale is not None: + h_pad, w_pad = 0, 0 + _, _, h, w = img.size() + if (h % mod_scale != 0): + h_pad = (mod_scale - h % mod_scale) + if (w % mod_scale != 0): + w_pad = (mod_scale - w % mod_scale) + img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') + + with torch.no_grad(): + output = self.sr_model(img) + del img + # remove extra pad + if mod_scale is not None: + _, _, h, w = output.size() + output = output[:, :, 0:h - h_pad, 0:w - w_pad] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + + return {OutputKeys.OUTPUT_IMG: output} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py new file mode 100644 index 00000000..4f0121dd --- /dev/null +++ b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py @@ -0,0 +1,251 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +import torch.nn.functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from PIL import Image +from torchvision.utils import save_image + +import modelscope.models.cv.image_to_image_generation.data as data +import modelscope.models.cv.image_to_image_generation.models as models +import modelscope.models.cv.image_to_image_generation.ops as ops +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_to_image_generation.model import UNet +from modelscope.models.cv.image_to_image_generation.models.clip import \ + VisionTransformer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_to_image_generation, + module_name=Pipelines.image_to_image_generation) +class Image2ImageGenerationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image-to-image generation pipeline + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.repetition = 4 + # load vit model + vit_model_path = osp.join(self.model, + self.cfg.ModelPath.vit_model_path) + logger.info(f'loading vit model from {vit_model_path}') + self.vit = VisionTransformer( + image_size=self.cfg.Params.vit.vit_image_size, + patch_size=self.cfg.Params.vit.vit_patch_size, + dim=self.cfg.Params.vit.vit_dim, + out_dim=self.cfg.Params.vit.vit_out_dim, + num_heads=self.cfg.Params.vit.vit_num_heads, + num_layers=self.cfg.Params.vit.vit_num_layers).eval( + ).requires_grad_(False).to(self._device) # noqa E123 + state = torch.load(vit_model_path) + state = { + k[len('visual.'):]: v + for k, v in state.items() if k.startswith('visual.') + } + self.vit.load_state_dict(state) + logger.info('load vit model done') + + # load autoencoder model + ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path) + logger.info(f'loading autoencoder model from {ae_model_path}') + self.autoencoder = models.VQAutoencoder( + dim=self.cfg.Params.ae.ae_dim, + z_dim=self.cfg.Params.ae.ae_z_dim, + dim_mult=self.cfg.Params.ae.ae_dim_mult, + attn_scales=self.cfg.Params.ae.ae_attn_scales, + codebook_size=self.cfg.Params.ae.ae_codebook_size).eval( + ).requires_grad_(False).to(self._device) # noqa E123 + self.autoencoder.load_state_dict( + torch.load(ae_model_path, map_location=self._device)) + logger.info('load autoencoder model done') + + # load decoder model + decoder_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading decoder model from {decoder_model_path}') + self.decoder = UNet( + resolution=self.cfg.Params.unet.unet_resolution, + in_dim=self.cfg.Params.unet.unet_in_dim, + dim=self.cfg.Params.unet.unet_dim, + label_dim=self.cfg.Params.vit.vit_out_dim, + context_dim=self.cfg.Params.unet.unet_context_dim, + out_dim=self.cfg.Params.unet.unet_out_dim, + dim_mult=self.cfg.Params.unet.unet_dim_mult, + num_heads=self.cfg.Params.unet.unet_num_heads, + head_dim=None, + num_res_blocks=self.cfg.Params.unet.unet_res_blocks, + attn_scales=self.cfg.Params.unet.unet_attn_scales, + dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_( + False).to(self._device) + self.decoder.load_state_dict( + torch.load(decoder_model_path, map_location=self._device)) + logger.info('load decoder model done') + + # diffusion + logger.info('Initialization diffusion ...') + betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule, + self.cfg.Params.diffusion.num_timesteps) + self.diffusion = ops.GaussianDiffusion( + betas=betas, + mean_type=self.cfg.Params.diffusion.mean_type, + var_type=self.cfg.Params.diffusion.var_type, + loss_type=self.cfg.Params.diffusion.loss_type, + rescale_timesteps=False) + + def preprocess(self, input: Input) -> Dict[str, Any]: + input_img_list = [] + if isinstance(input, str): + input_img_list = [input] + input_type = 0 + elif isinstance(input, tuple) and len(input) == 2: + input_img_list = list(input) + input_type = 1 + else: + raise TypeError( + 'modelscope error: Only support "str" or "tuple (img1, img2)" , but got {type(input)}' + ) + + if input_type == 0: + logger.info('Processing Similar Image Generation mode') + if input_type == 1: + logger.info('Processing Interpolation mode') + + img_list = [] + for i, input_img in enumerate(input_img_list): + img = LoadImage.convert_to_img(input_img) + logger.info(f'Load {i}-th image done') + img_list.append(img) + + transforms = T.Compose([ + data.PadToSquare(), + T.Resize( + self.cfg.DATA.scale_size, + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) + ]) + + y_list = [] + for img in img_list: + img = transforms(img) + imgs = torch.unsqueeze(img, 0) + imgs = imgs.to(self._device) + imgs_x0 = self.autoencoder.encode(imgs) + b, c, h, w = imgs_x0.shape + aug_imgs = TF.normalize( + F.interpolate( + imgs.add(1).div(2), (self.cfg.Params.vit.vit_image_size, + self.cfg.Params.vit.vit_image_size), + mode='bilinear', + align_corners=True), self.cfg.Params.vit.vit_mean, + self.cfg.Params.vit.vit_std) + uy = self.vit(aug_imgs) + y = F.normalize(uy, p=2, dim=1) + y_list.append(y) + + if input_type == 0: + result = { + 'image_data': y_list[0], + 'c': c, + 'h': h, + 'w': w, + 'type': input_type + } + elif input_type == 1: + result = { + 'image_data': y_list[0], + 'image_data_s': y_list[1], + 'c': c, + 'h': h, + 'w': w, + 'type': input_type + } + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + type_ = input['type'] + if type_ == 0: + # Similar Image Generation # + y = input['image_data'] + + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + i_y = y.repeat(self.repetition, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn(self.repetition, input['c'], input['h'], + input['w']).to(self._device), + model=self.decoder, + model_kwargs=[{ + 'y': i_y + }, { + 'y': torch.zeros_like(i_y) + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=1.0) + i_gen_imgs = self.autoencoder.decode(x0) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + else: + # Interpolation # + # get content-style pairs + y = input['image_data'] + y_s = input['image_data_s'] + + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + noise = torch.randn(self.repetition, input['c'], input['h'], + input['w']).to(self._device) + + # interpolation between y_cid and y_sid + factors = torch.linspace(0, 1, self.repetition).unsqueeze(1).to( + self._device) + i_y = (1 - factors) * y + factors * y_s + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=noise, + model=self.decoder, + model_kwargs=[{ + 'y': i_y + }, { + 'y': torch.zeros_like(i_y) + }], + guide_scale=3.0, + clamp=None, + ddim_timesteps=50, + eta=0.0) + i_gen_imgs = self.autoencoder.decode(x0) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_to_image_translation_pipeline.py b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py new file mode 100644 index 00000000..e5f853ca --- /dev/null +++ b/modelscope/pipelines/cv/image_to_image_translation_pipeline.py @@ -0,0 +1,327 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import io +import os.path as osp +import sys +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.utils import save_image + +import modelscope.models.cv.image_to_image_translation.data as data +import modelscope.models.cv.image_to_image_translation.models as models +import modelscope.models.cv.image_to_image_translation.ops as ops +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_to_image_translation.model_translation import \ + UNet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def save_grid(imgs, filename, nrow=5): + save_image( + imgs.clamp(-1, 1), filename, range=(-1, 1), normalize=True, nrow=nrow) + + +@PIPELINES.register_module( + Tasks.image_to_image_translation, + module_name=Pipelines.image2image_translation) +class Image2ImageTranslationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.repetition = 4 + # load autoencoder model + ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path) + logger.info(f'loading autoencoder model from {ae_model_path}') + self.autoencoder = models.VQAutoencoder( + dim=self.cfg.Params.ae.ae_dim, + z_dim=self.cfg.Params.ae.ae_z_dim, + dim_mult=self.cfg.Params.ae.ae_dim_mult, + attn_scales=self.cfg.Params.ae.ae_attn_scales, + codebook_size=self.cfg.Params.ae.ae_codebook_size).eval( + ).requires_grad_(False).to(self._device) # noqa E123 + self.autoencoder.load_state_dict( + torch.load(ae_model_path, map_location=self._device)) + logger.info('load autoencoder model done') + + # load palette model + palette_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading palette model from {palette_model_path}') + self.palette = UNet( + resolution=self.cfg.Params.unet.unet_resolution, + in_dim=self.cfg.Params.unet.unet_in_dim, + dim=self.cfg.Params.unet.unet_dim, + context_dim=self.cfg.Params.unet.unet_context_dim, + out_dim=self.cfg.Params.unet.unet_out_dim, + dim_mult=self.cfg.Params.unet.unet_dim_mult, + num_heads=self.cfg.Params.unet.unet_num_heads, + head_dim=None, + num_res_blocks=self.cfg.Params.unet.unet_res_blocks, + attn_scales=self.cfg.Params.unet.unet_attn_scales, + num_classes=self.cfg.Params.unet.unet_num_classes + 1, + dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_( + False).to(self._device) + self.palette.load_state_dict( + torch.load(palette_model_path, map_location=self._device)) + logger.info('load palette model done') + + # diffusion + logger.info('Initialization diffusion ...') + betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule, + self.cfg.Params.diffusion.num_timesteps) + self.diffusion = ops.GaussianDiffusion( + betas=betas, + mean_type=self.cfg.Params.diffusion.mean_type, + var_type=self.cfg.Params.diffusion.var_type, + loss_type=self.cfg.Params.diffusion.loss_type, + rescale_timesteps=False) + + self.transforms = T.Compose([ + data.PadToSquare(), + T.Resize( + self.cfg.DATA.scale_size, + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + if len(input) == 3: # colorization + _, input_type, save_path = input + elif len(input) == 4: # uncropping or in-painting + _, meta, input_type, save_path = input + if input_type == 0: # uncropping + assert meta in ['up', 'down', 'left', 'right'] + direction = meta + + list_ = [] + for i in range(len(input) - 2): + input_img = input[i] + if input_img in ['up', 'down', 'left', 'right']: + continue + if isinstance(input_img, str): + if input_type == 2 and i == 0: + logger.info('Loading image by origin way ... ') + bytes = File.read(input_img) + img = Image.open(io.BytesIO(bytes)) + assert len(img.split()) == 4 + else: + img = load_image(input_img) + elif isinstance(input_img, PIL.Image.Image): + img = input_img.convert('RGB') + elif isinstance(input_img, np.ndarray): + if len(input_img.shape) == 2: + input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR) + img = input_img[:, :, ::-1] + img = Image.fromarray(img.astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + list_.append(img) + img_list = [] + if input_type != 2: + for img in list_: + img = self.transforms(img) + imgs = torch.unsqueeze(img, 0) + imgs = imgs.to(self._device) + img_list.append(imgs) + elif input_type == 2: + mask, masked_img = list_[0], list_[1] + img = self.transforms(masked_img.convert('RGB')) + mask = torch.from_numpy( + np.array( + mask.resize((img.shape[2], img.shape[1])), + dtype=np.float32)[:, :, -1] / 255.0).unsqueeze(0) + img = (1 - mask) * img + mask * torch.randn_like(img).clamp_(-1, 1) + imgs = img.unsqueeze(0).to(self._device) + b, c, h, w = imgs.shape + y = torch.LongTensor([self.cfg.Classes.class_id]).to(self._device) + + if input_type == 0: + assert len(img_list) == 1 + result = { + 'image_data': img_list[0], + 'c': c, + 'h': h, + 'w': w, + 'direction': direction, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + elif input_type == 1: + assert len(img_list) == 1 + result = { + 'image_data': img_list[0], + 'c': c, + 'h': h, + 'w': w, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + elif input_type == 2: + result = { + 'image_data': imgs, + # 'image_mask': mask, + 'c': c, + 'h': h, + 'w': w, + 'type': input_type, + 'y': y, + 'save_path': save_path + } + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + type_ = input['type'] + if type_ == 0: + # Uncropping + img = input['image_data'] + direction = input['direction'] + y = input['y'] + + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + + logger.info(f'Processing {direction} uncropping') + img = img.clone() + i_y = y.repeat(self.repetition, 1) + if direction == 'up': + img[:, :, input['h'] // 2:, :] = torch.randn_like( + img[:, :, input['h'] // 2:, :]) + elif direction == 'down': + img[:, :, :input['h'] // 2, :] = torch.randn_like( + img[:, :, :input['h'] // 2, :]) + elif direction == 'left': + img[:, :, :, + input['w'] // 2:] = torch.randn_like(img[:, :, :, + input['w'] // 2:]) + elif direction == 'right': + img[:, :, :, :input['w'] // 2] = torch.randn_like( + img[:, :, :, :input['w'] // 2]) + i_concat = self.autoencoder.encode(img).repeat( + self.repetition, 1, 1, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(i_concat), + model=self.palette, + model_kwargs=[{ + 'y': i_y, + 'concat': i_concat + }, { + 'y': + torch.full_like(i_y, + self.cfg.Params.unet.unet_num_classes), + 'concat': + i_concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=1.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + + elif type_ == 1: + # Colorization # + img = input['image_data'] + y = input['y'] + # fix seed + torch.manual_seed(1 * 8888) + torch.cuda.manual_seed(1 * 8888) + + logger.info('Processing Colorization') + img = img.clone() + img = img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) + i_concat = self.autoencoder.encode(img).repeat( + self.repetition, 1, 1, 1) + i_y = y.repeat(self.repetition, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(i_concat), + model=self.palette, + model_kwargs=[{ + 'y': i_y, + 'concat': i_concat + }, { + 'y': + torch.full_like(i_y, + self.cfg.Params.unet.unet_num_classes), + 'concat': + i_concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=0.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + elif type_ == 2: + # Combination # + logger.info('Processing Combination') + + # prepare inputs + img = input['image_data'] + concat = self.autoencoder.encode(img).repeat( + self.repetition, 1, 1, 1) + y = torch.LongTensor([126]).unsqueeze(0).to(self._device).repeat( + self.repetition, 1) + + # sample images + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn_like(concat), + model=self.palette, + model_kwargs=[{ + 'y': y, + 'concat': concat + }, { + 'y': + torch.full_like(y, self.cfg.Params.unet.unet_num_classes), + 'concat': + concat + }], + guide_scale=1.0, + clamp=None, + ddim_timesteps=50, + eta=1.0) + i_gen_imgs = self.autoencoder.decode(x0) + save_grid(i_gen_imgs, input['save_path'], nrow=4) + return {OutputKeys.OUTPUT_IMG: i_gen_imgs} + else: + raise TypeError( + f'input type should be 0 (Uncropping), 1 (Colorization), 2 (Combation)' + f' but got {type_}') + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/live_category_pipeline.py b/modelscope/pipelines/cv/live_category_pipeline.py new file mode 100644 index 00000000..715998cc --- /dev/null +++ b/modelscope/pipelines/cv/live_category_pipeline.py @@ -0,0 +1,156 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os.path as osp +from typing import Any, Dict + +import decord +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.transforms.functional as TF +from decord import VideoReader, cpu +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.live_category, module_name=Pipelines.live_category) +class LiveCategoryPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a live-category pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + self.infer_model = models.resnet50(pretrained=False) + self.infer_model.fc = nn.Linear(2048, 8613) + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = self.infer_model.to(self.device).eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)) + logger.info('load model done') + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.label_mapping = self.cfg.label_mapping + logger.info('load config done') + self.transforms = VCompose([ + VRescale(size=256), + VCenterCrop(size=224), + VToTensor(), + VNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + decord.bridge.set_bridge('native') + vr = VideoReader(input, ctx=cpu(0)) + indices = np.linspace(0, len(vr) - 1, 4).astype(int) + frames = vr.get_batch(indices).asnumpy() + video_input_data = self.transforms( + [Image.fromarray(f) for f in frames]) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + logits = self.infer_model(input['video_data'].to(self.device)) + softmax_out = F.softmax(logits, dim=1).mean(dim=0).cpu() + scores, ids = softmax_out.topk(3, 0, True, True) + scores = scores.numpy() + ids = ids.numpy() + labels = [] + for i in ids: + label_info = self.label_mapping[str(i)] + label_keys = ['cate_level1_name', 'cate_level2_name', 'cate_name'] + label_str = [] + for label_key in label_keys: + if label_info[label_key] not in label_str: + label_str.append(label_info[label_key]) + labels.append(label_str[-1]) + return {OutputKeys.SCORES: list(scores), OutputKeys.LABELS: labels} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +class VCompose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, item): + for t in self.transforms: + item = t(item) + return item + + +class VRescale(object): + + def __init__(self, size=128): + self.size = size + + def __call__(self, vclip): + vclip = [ + u.resize((self.size, self.size), Image.BILINEAR) for u in vclip + ] + return vclip + + +class VCenterCrop(object): + + def __init__(self, size=112): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + vclip = [ + u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip + ] + return vclip + + +class VToTensor(object): + + def __call__(self, vclip): + vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=0) + return vclip + + +class VNormalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, vclip): + assert vclip.min() > -0.1 and vclip.max() < 1.1, \ + 'vclip values should be in [0, 1]' + vclip = vclip.clone() + if not isinstance(self.mean, torch.Tensor): + self.mean = vclip.new_tensor(self.mean).view(1, -1, 1, 1) + if not isinstance(self.std, torch.Tensor): + self.std = vclip.new_tensor(self.std).view(1, -1, 1, 1) + vclip.sub_(self.mean).div_(self.std) + return vclip diff --git a/modelscope/pipelines/cv/mog_face_detection_pipeline.py b/modelscope/pipelines/cv/mog_face_detection_pipeline.py new file mode 100644 index 00000000..124b605b --- /dev/null +++ b/modelscope/pipelines/cv/mog_face_detection_pipeline.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import MogFaceDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.mog_face_detection) +class MogFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + detector = MogFaceDetector(model_path=ckpt_path, device=self.device) + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + result = self.detector(input) + assert result is not None + bboxes = result[:, :4].tolist() + scores = result[:, 4].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: None, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py new file mode 100644 index 00000000..3fffc546 --- /dev/null +++ b/modelscope/pipelines/cv/movie_scene_segmentation_pipeline.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.movie_scene_segmentation, + module_name=Pipelines.movie_scene_segmentation) +class MovieSceneSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a movie scene segmentation pipeline for prediction + + Args: + model: model id on modelscope hub + """ + _device = kwargs.pop('device', 'gpu') + if torch.cuda.is_available() and _device == 'gpu': + device = 'gpu' + else: + device = 'cpu' + super().__init__(model=model, device=device, **kwargs) + + logger.info('Load model done!') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ use pyscenedetect to detect shot from the input video, and generate key-frame jpg, anno.ndjson, and shot-frame.txt + Then use shot-encoder to encoder feat of the detected key-frame + + Args: + input: path of the input video + + """ + self.input_video_pth = input + if isinstance(input, str): + shot_feat, sid = self.model.preprocess(input) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + + result = {'sid': sid, 'shot_feat': shot_feat} + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + output = self.model.inference(input) + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + data = {'input_video_pth': self.input_video_pth, 'feat': inputs} + scene_num, scene_meta_lst, shot_num, shot_meta_lst = self.model.postprocess( + data) + result = { + OutputKeys.SHOT_NUM: shot_num, + OutputKeys.SHOT_META_LIST: shot_meta_lst, + OutputKeys.SCENE_NUM: scene_num, + OutputKeys.SCENE_META_LIST: scene_meta_lst + } + return result diff --git a/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py b/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py new file mode 100644 index 00000000..bda46a70 --- /dev/null +++ b/modelscope/pipelines/cv/mtcnn_face_detection_pipeline.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import MtcnnFaceDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.mtcnn_face_detection) +class MtcnnFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, './weights') + logger.info(f'loading model from {ckpt_path}') + device = torch.device( + f'cuda:{0}' if torch.cuda.is_available() else 'cpu') + detector = MtcnnFaceDetector(model_path=ckpt_path, device=device) + self.detector = detector + self.device = device + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0][:, :4].tolist() + scores = result[0][:, 4].tolist() + lms = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: lms, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py new file mode 100644 index 00000000..292ec2c5 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -0,0 +1,184 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import tensorflow as tf + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.logger import get_logger +from .ocr_utils import (SegLinkDetector, cal_width, combine_segments_python, + decode_segments_links_python, nms_python, + rboxes_to_polygons) + +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 +tf.compat.v1.disable_eager_execution() + +logger = get_logger() + +# constant +RBOX_DIM = 5 +OFFSET_DIM = 6 +WORD_POLYGON_DIM = 8 +OFFSET_VARIANCE = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1] + +FLAGS = tf.app.flags.FLAGS +tf.app.flags.DEFINE_float('node_threshold', 0.4, + 'Confidence threshold for nodes') +tf.app.flags.DEFINE_float('link_threshold', 0.6, + 'Confidence threshold for links') + + +@PIPELINES.register_module( + Tasks.ocr_detection, module_name=Pipelines.ocr_detection) +class OCRDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a OCR detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + tf.reset_default_graph() + model_path = osp.join( + osp.join(self.model, ModelFile.TF_CHECKPOINT_FOLDER), + 'checkpoint-80000') + self._graph = tf.get_default_graph() + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + + with self._graph.as_default(): + with device_placement(self.framework, self.device_name): + self.input_images = tf.placeholder( + tf.float32, shape=[1, 1024, 1024, 3], name='input_images') + self.output = {} + + with tf.variable_scope('', reuse=tf.AUTO_REUSE): + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), + dtype=tf.int64, + trainable=False) + variable_averages = tf.train.ExponentialMovingAverage( + 0.997, global_step) + + # detector + detector = SegLinkDetector() + all_maps = detector.build_model( + self.input_images, is_training=False) + + # decode local predictions + all_nodes, all_links, all_reg = [], [], [] + for i, maps in enumerate(all_maps): + cls_maps, lnk_maps, reg_maps = maps[0], maps[1], maps[ + 2] + reg_maps = tf.multiply(reg_maps, OFFSET_VARIANCE) + + cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2])) + + lnk_prob_pos = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, :2]) + lnk_prob_mut = tf.nn.softmax( + tf.reshape(lnk_maps, [-1, 4])[:, 2:]) + lnk_prob = tf.concat([lnk_prob_pos, lnk_prob_mut], + axis=1) + + all_nodes.append(cls_prob) + all_links.append(lnk_prob) + all_reg.append(reg_maps) + + # decode segments and links + image_size = tf.shape(self.input_images)[1:3] + segments, group_indices, segment_counts, _ = decode_segments_links_python( + image_size, + all_nodes, + all_links, + all_reg, + anchor_sizes=list(detector.anchor_sizes)) + + # combine segments + combined_rboxes, combined_counts = combine_segments_python( + segments, group_indices, segment_counts) + self.output['combined_rboxes'] = combined_rboxes + self.output['combined_counts'] = combined_counts + + with self._session.as_default() as sess: + logger.info(f'loading model from {model_path}') + # load model + model_loader = tf.train.Saver( + variable_averages.variables_to_restore()) + model_loader.restore(sess, model_path) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + + h, w, c = img.shape + img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32) + img_pad[:h, :w, :] = img + + resize_size = 1024 + img_pad_resize = cv2.resize(img_pad, (resize_size, resize_size)) + img_pad_resize = cv2.cvtColor(img_pad_resize, cv2.COLOR_RGB2BGR) + img_pad_resize = img_pad_resize - np.array([123.68, 116.78, 103.94], + dtype=np.float32) + + with self._graph.as_default(): + resize_size = tf.stack([resize_size, resize_size]) + orig_size = tf.stack([max(h, w), max(h, w)]) + self.output['orig_size'] = orig_size + self.output['resize_size'] = resize_size + + result = {'img': np.expand_dims(img_pad_resize, axis=0)} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._graph.as_default(): + with self._session.as_default(): + feed_dict = {self.input_images: input['img']} + sess_outputs = self._session.run( + self.output, feed_dict=feed_dict) + return sess_outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + rboxes = inputs['combined_rboxes'][0] + count = inputs['combined_counts'][0] + if count == 0 or count < rboxes.shape[0]: + raise Exception('modelscope error: No text detected') + rboxes = rboxes[:count, :] + + # convert rboxes to polygons and find its coordinates on the original image + orig_h, orig_w = inputs['orig_size'] + resize_h, resize_w = inputs['resize_size'] + polygons = rboxes_to_polygons(rboxes) + scale_y = float(orig_h) / float(resize_h) + scale_x = float(orig_w) / float(resize_w) + + # confine polygons inside image + polygons[:, ::2] = np.maximum( + 0, np.minimum(polygons[:, ::2] * scale_x, orig_w - 1)) + polygons[:, 1::2] = np.maximum( + 0, np.minimum(polygons[:, 1::2] * scale_y, orig_h - 1)) + polygons = np.round(polygons).astype(np.int32) + + # nms + dt_n9 = [o + [cal_width(o)] for o in polygons.tolist()] + dt_nms = nms_python(dt_n9) + dt_polygons = np.array([o[:8] for o in dt_nms]) + + result = {OutputKeys.POLYGONS: dt_polygons} + return result diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py new file mode 100644 index 00000000..e81467a1 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.ocr_utils.model_convnext_transformer import \ + OCRRecModel +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +# constant +NUM_CLASSES = 7644 +IMG_HEIGHT = 32 +IMG_WIDTH = 300 +PRED_LENTH = 75 +PRED_PAD = 6 + + +@PIPELINES.register_module( + Tasks.ocr_recognition, module_name=Pipelines.ocr_recognition) +class OCRRecognitionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + label_path = osp.join(self.model, 'label_dict.txt') + logger.info(f'loading model from {model_path}') + + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = OCRRecModel(NUM_CLASSES).to(self.device) + self.infer_model.eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)) + self.labelMapping = dict() + with open(label_path, 'r') as f: + lines = f.readlines() + cnt = 2 + for line in lines: + line = line.strip('\n') + self.labelMapping[cnt] = line + cnt += 1 + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + img = np.array(load_image(input).convert('L')) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('L')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 3: + img = cv2.cvtColor(input, cv2.COLOR_RGB2GRAY) + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + data = [] + img_h, img_w = img.shape + wh_ratio = img_w / img_h + true_w = int(IMG_HEIGHT * wh_ratio) + split_batch_cnt = 1 + if true_w < IMG_WIDTH * 1.2: + img = cv2.resize(img, (min(true_w, IMG_WIDTH), IMG_HEIGHT)) + else: + split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252) + img = cv2.resize(img, (true_w, IMG_HEIGHT)) + + if split_batch_cnt == 1: + mask = np.zeros((IMG_HEIGHT, IMG_WIDTH)) + mask[:, :img.shape[1]] = img + data.append(mask) + else: + for idx in range(split_batch_cnt): + mask = np.zeros((IMG_HEIGHT, IMG_WIDTH)) + left = (PRED_LENTH * 4 - PRED_PAD * 4) * idx + trunk_img = img[:, left:min(left + PRED_LENTH * 4, true_w)] + mask[:, :trunk_img.shape[1]] = trunk_img + data.append(mask) + + data = torch.FloatTensor(data).view( + len(data), 1, IMG_HEIGHT, IMG_WIDTH) / 255. + data = data.to(self.device) + + result = {'img': data} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred = self.infer_model(input['img']) + return {'results': pred} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + preds = inputs['results'] + batchSize, length = preds.shape + pred_idx = [] + if batchSize == 1: + pred_idx = preds[0].cpu().data.tolist() + else: + for idx in range(batchSize): + if idx == 0: + pred_idx.extend(preds[idx].cpu().data[:PRED_LENTH + - PRED_PAD].tolist()) + elif idx == batchSize - 1: + pred_idx.extend(preds[idx].cpu().data[PRED_PAD:].tolist()) + else: + pred_idx.extend(preds[idx].cpu().data[PRED_PAD:PRED_LENTH + - PRED_PAD].tolist()) + + # ctc decoder + last_p = 0 + str_pred = [] + for p in pred_idx: + if p != last_p and p != 0: + str_pred.append(self.labelMapping[p]) + last_p = p + + final_str = ''.join(str_pred) + result = {OutputKeys.TEXT: final_str} + return result diff --git a/modelscope/pipelines/cv/ocr_utils/__init__.py b/modelscope/pipelines/cv/ocr_utils/__init__.py new file mode 100644 index 00000000..312445a9 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model_resnet_mutex_v4_linewithchar import SegLinkDetector + from .ops import decode_segments_links_python, combine_segments_python + from .utils import rboxes_to_polygons, cal_width, nms_python +else: + _import_structure = { + 'model_resnet_mutex_v4_linewithchar': ['SegLinkDetector'], + 'ops': ['decode_segments_links_python', 'combine_segments_python'], + 'utils': ['rboxes_to_polygons', 'cal_width', 'nms_python'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py new file mode 100644 index 00000000..6ecff7ef --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/model_convnext_transformer.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn + +from .ocr_modules.convnext import convnext_tiny +from .ocr_modules.vitstr import vitstr_tiny + + +class OCRRecModel(nn.Module): + + def __init__(self, num_classes): + super(OCRRecModel, self).__init__() + self.cnn_model = convnext_tiny() + self.num_classes = num_classes + self.vitstr = vitstr_tiny(num_tokens=num_classes) + + def forward(self, input): + """ Transformation stage """ + features = self.cnn_model(input) + prediction = self.vitstr(features) + prediction = torch.nn.functional.softmax(prediction, dim=-1) + + output = torch.argmax(prediction, -1) + return output diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py new file mode 100644 index 00000000..2c2d5b00 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/model_resnet_mutex_v4_linewithchar.py @@ -0,0 +1,164 @@ +# Part of the implementation is borrowed and modified from SegLink, +# publicly available at https://github.com/bgshih/seglink +import tensorflow as tf + +from . import ops, resnet18_v1, resnet_utils + +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +# constants +OFFSET_DIM = 6 + +N_LOCAL_LINKS = 8 +N_CROSS_LINKS = 4 +N_SEG_CLASSES = 2 +N_LNK_CLASSES = 4 + +POS_LABEL = 1 +NEG_LABEL = 0 + + +class SegLinkDetector(): + + def __init__(self): + self.anchor_sizes = [6., 11.84210526, 23.68421053, 45., 90., 150.] + + def _detection_classifier(self, + maps, + ksize, + weight_decay, + cross_links=False, + scope=None): + + with tf.variable_scope(scope): + seg_depth = N_SEG_CLASSES + if cross_links: + lnk_depth = N_LNK_CLASSES * (N_LOCAL_LINKS + N_CROSS_LINKS) + else: + lnk_depth = N_LNK_CLASSES * N_LOCAL_LINKS + reg_depth = OFFSET_DIM + map_depth = maps.get_shape()[3] + inter_maps, inter_relu = ops.conv2d( + maps, map_depth, 256, 1, 1, 'SAME', scope='conv_inter') + + dir_maps, dir_relu = ops.conv2d( + inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_dir') + cen_maps, cen_relu = ops.conv2d( + inter_relu, 256, 2, ksize, 1, 'SAME', scope='conv_cen') + pol_maps, pol_relu = ops.conv2d( + inter_relu, 256, 8, ksize, 1, 'SAME', scope='conv_pol') + concat_relu = tf.concat([dir_relu, cen_relu, pol_relu], axis=-1) + _, lnk_embedding = ops.conv_relu( + concat_relu, 12, 256, 1, 1, scope='lnk_embedding') + lnk_maps, lnk_relu = ops.conv2d( + inter_relu + lnk_embedding, + 256, + lnk_depth, + ksize, + 1, + 'SAME', + scope='conv_lnk') + + char_seg_maps, char_seg_relu = ops.conv2d( + inter_relu, + 256, + seg_depth, + ksize, + 1, + 'SAME', + scope='conv_char_cls') + char_reg_maps, char_reg_relu = ops.conv2d( + inter_relu, + 256, + reg_depth, + ksize, + 1, + 'SAME', + scope='conv_char_reg') + concat_char_relu = tf.concat([char_seg_relu, char_reg_relu], + axis=-1) + _, char_embedding = ops.conv_relu( + concat_char_relu, 8, 256, 1, 1, scope='conv_char_embedding') + seg_maps, seg_relu = ops.conv2d( + inter_relu + char_embedding, + 256, + seg_depth, + ksize, + 1, + 'SAME', + scope='conv_cls') + reg_maps, reg_relu = ops.conv2d( + inter_relu + char_embedding, + 256, + reg_depth, + ksize, + 1, + 'SAME', + scope='conv_reg') + + return seg_relu, lnk_relu, reg_relu + + def _build_cnn(self, images, weight_decay, is_training): + with slim.arg_scope( + resnet18_v1.resnet_arg_scope(weight_decay=weight_decay)): + logits, end_points = resnet18_v1.resnet_v1_18( + images, is_training=is_training, scope='resnet_v1_18') + + outputs = { + 'conv3_3': end_points['pool1'], + 'conv4_3': end_points['pool2'], + 'fc7': end_points['pool3'], + 'conv8_2': end_points['pool4'], + 'conv9_2': end_points['pool5'], + 'conv10_2': end_points['pool6'], + } + return outputs + + def build_model(self, images, is_training=True, scope=None): + + weight_decay = 5e-4 # FLAGS.weight_decay + cnn_outputs = self._build_cnn(images, weight_decay, is_training) + det_0 = self._detection_classifier( + cnn_outputs['conv3_3'], + 3, + weight_decay, + cross_links=False, + scope='dete_0') + det_1 = self._detection_classifier( + cnn_outputs['conv4_3'], + 3, + weight_decay, + cross_links=True, + scope='dete_1') + det_2 = self._detection_classifier( + cnn_outputs['fc7'], + 3, + weight_decay, + cross_links=True, + scope='dete_2') + det_3 = self._detection_classifier( + cnn_outputs['conv8_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_3') + det_4 = self._detection_classifier( + cnn_outputs['conv9_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_4') + det_5 = self._detection_classifier( + cnn_outputs['conv10_2'], + 3, + weight_decay, + cross_links=True, + scope='dete_5') + outputs = [det_0, det_1, det_2, det_3, det_4, det_5] + return outputs diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py new file mode 100644 index 00000000..7799c34f --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .convnext import convnext_tiny + from .vitstr import vitstr_tiny +else: + _import_structure = { + 'convnext': ['convnext_tiny'], + 'vitstr': ['vitstr_tiny'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py new file mode 100644 index 00000000..c0e30616 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/convnext.py @@ -0,0 +1,163 @@ +# Part of the implementation is borrowed and modified from ConvNext, +# publicly available at https://github.com/facebookresearch/ConvNeXt +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .timm_tinyc import DropPath + + +class Block(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, + 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=1, + num_classes=1000, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + drop_path_rate=0., + layer_scale_init_value=1e-6, + head_init_scale=1., + ): + super().__init__() + + self.downsample_layers = nn.ModuleList( + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format='channels_first')) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv2d( + dims[i], dims[i + 1], kernel_size=(2, 1), stride=(2, 1)), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList( + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + cur = 0 + for i in range(4): + stage = nn.Sequential(*[ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) + for j in range(depths[i]) + ]) + self.stages.append(stage) + cur += depths[i] + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x.contiguous()) + x = self.stages[i](x.contiguous()) + return x # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x.contiguous()) + + return x.contiguous() + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format='channels_last'): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_last', 'channels_first']: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == 'channels_last': + return F.layer_norm(x, self.normalized_shape, self.weight, + self.bias, self.eps) + elif self.data_format == 'channels_first': + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def convnext_tiny(): + model = ConvNeXt(depths=[3, 3, 8, 3], dims=[96, 192, 256, 512]) + return model diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py new file mode 100644 index 00000000..555b1e42 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/timm_tinyc.py @@ -0,0 +1,332 @@ +# Part of the implementation is borrowed and modified from timm, +# publicly available at https://github.com/rwightman/pytorch-image-models +import collections.abc +import logging +import math +from collections import OrderedDict +from copy import deepcopy +from functools import partial +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + img_size = (1, 75) + to_2tuple = _ntuple(2) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = x.permute(0, 1, 3, 2) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0.1, + proj_drop=0.1): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + representation_size=None, + distilled=False, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path_rate=0., + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + weight_init=''): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros( + 1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer) for i in range(depth) + ]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential( + OrderedDict([('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh())])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear( + self.num_features, + num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear( + self.embed_dim, + self.num_classes) if num_classes > 0 else nn.Identity() + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear( + self.embed_dim, + self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand( + x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat( + (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), + dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist( + x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x diff --git a/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py new file mode 100644 index 00000000..5ce3aeca --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ocr_modules/vitstr.py @@ -0,0 +1,58 @@ +# Part of the implementation is borrowed and modified from ViTSTR, +# publicly available at https://github.com/roatienza/deep-text-recognition-benchmark +from __future__ import absolute_import, division, print_function +import logging +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +from .timm_tinyc import VisionTransformer + + +class ViTSTR(VisionTransformer): + ''' + ViTSTR is basically a ViT that uses DeiT weights. + Modified head to support a sequence of characters prediction for STR. + ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + b, s, e = x.size() + x = x.reshape(b * s, e) + x = self.head(x).view(b, s, self.num_classes) + return x + + +def vitstr_tiny(num_tokens): + vitstr = ViTSTR( + patch_size=1, + in_chans=512, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True) + vitstr.reset_classifier(num_classes=num_tokens) + return vitstr diff --git a/modelscope/pipelines/cv/ocr_utils/ops.py b/modelscope/pipelines/cv/ocr_utils/ops.py new file mode 100644 index 00000000..a36838a6 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/ops.py @@ -0,0 +1,1106 @@ +# Part of the implementation is borrowed and modified from SegLink, +# publicly available at https://github.com/bgshih/seglink +import math +import os +import shutil +import sys +import uuid + +import absl.flags as absl_flags +import cv2 +import numpy as np +import tensorflow as tf + +from . import utils + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +# skip parse sys.argv in tf, so fix bug: +# absl.flags._exceptions.UnrecognizedFlagError: +# Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag +absl_flags.FLAGS(sys.argv, known_only=True) +FLAGS = tf.app.flags.FLAGS +tf.app.flags.DEFINE_string('weight_init_method', 'xavier', + 'Weight initialization method') + +# constants +OFFSET_DIM = 6 +RBOX_DIM = 5 + +N_LOCAL_LINKS = 8 +N_CROSS_LINKS = 4 +N_SEG_CLASSES = 2 +N_LNK_CLASSES = 4 + +MATCH_STATUS_POS = 1 +MATCH_STATUS_NEG = -1 +MATCH_STATUS_IGNORE = 0 +MUT_LABEL = 3 +POS_LABEL = 1 +NEG_LABEL = 0 + +N_DET_LAYERS = 6 + + +def load_oplib(lib_name): + """ + Load TensorFlow operator library. + """ + # use absolute path so that ops.py can be called from other directory + lib_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'lib{0}.so'.format(lib_name)) + # duplicate library with a random new name so that + # a running program will not be interrupted when the original library is updated + lib_copy_path = '/tmp/lib{0}_{1}.so'.format( + str(uuid.uuid4())[:8], LIB_NAME) + shutil.copyfile(lib_path, lib_copy_path) + oplib = tf.load_op_library(lib_copy_path) + return oplib + + +def _nn_variable(name, shape, init_method, collection=None, **kwargs): + """ + Create or reuse a variable + ARGS + name: variable name + shape: variable shape + init_method: 'zero', 'kaiming', 'xavier', or (mean, std) + collection: if not none, add variable to this collection + kwargs: extra paramters passed to tf.get_variable + RETURN + var: a new or existing variable + """ + if init_method == 'zero': + initializer = tf.constant_initializer(0.0) + elif init_method == 'kaiming': + if len(shape) == 4: # convolutional filters + kh, kw, n_in = shape[:3] + init_std = math.sqrt(2.0 / (kh * kw * n_in)) + elif len(shape) == 2: # linear weights + n_in, n_out = shape + init_std = math.sqrt(1.0 / n_out) + else: + raise 'Unsupported shape' + initializer = tf.truncated_normal_initializer(0.0, init_std) + elif init_method == 'xavier': + if len(shape) == 4: + initializer = tf.keras.initializers.glorot_normal() + else: + initializer = tf.keras.initializers.glorot_normal() + elif isinstance(init_method, tuple): + assert (len(init_method) == 2) + initializer = tf.truncated_normal_initializer(init_method[0], + init_method[1]) + else: + raise 'Unsupported weight initialization method: ' + init_method + + var = tf.get_variable(name, shape=shape, initializer=initializer) + if collection is not None: + tf.add_to_collection(collection, var) + + return var + + +def conv2d(x, + n_in, + n_out, + ksize, + stride=1, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + # input_dim = n_in + if (padding == 'SAME'): + in_height = x.get_shape()[1] + in_width = x.get_shape()[2] + if (in_height % stride == 0): + pad_along_height = max(ksize - stride, 0) + else: + pad_along_height = max(ksize - (in_height % stride), 0) + if (in_width % stride == 0): + pad_along_width = max(ksize - stride, 0) + else: + pad_along_width = max(ksize - (in_width % stride), 0) + pad_bottom = pad_along_height // 2 + pad_top = pad_along_height - pad_bottom + pad_right = pad_along_width // 2 + pad_left = pad_along_width - pad_right + paddings = tf.constant([[0, 0], [pad_top, pad_bottom], + [pad_left, pad_right], [0, 0]]) + input_padded = tf.pad(x, paddings, 'CONSTANT') + else: + input_padded = x + + with tf.variable_scope(scope or 'conv2d'): + # convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + yc = tf.nn.conv2d( + input_padded, kernel, [1, stride, stride, 1], padding='VALID') + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + yb = tf.nn.bias_add(yc, bias) + # apply ReLU + y = yb + if relu is True: + y = tf.nn.relu(yb) + return yb, y + + +def group_conv2d_relu(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='group_conv2d', + **kwargs): + group_axis = len(x.get_shape()) - 1 + splits = tf.split(x, [int(n_in / group)] * group, group_axis) + + conv_list = [] + for i in range(group): + conv_split, relu_split = conv2d( + splits[i], + n_in / group, + n_out / group, + ksize=ksize, + stride=stride, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope='%s_%d' % (name, i)) + conv_list.append(conv_split) + conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat') + relu = tf.nn.relu(conv) + return conv, relu + + +def group_conv2d_bn_relu(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='group_conv2d', + **kwargs): + group_axis = len(x.get_shape()) - 1 + splits = tf.split(x, [int(n_in / group)] * group, group_axis) + + conv_list = [] + for i in range(group): + conv_split, relu_split = conv2d( + splits[i], + n_in / group, + n_out / group, + ksize=ksize, + stride=stride, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope='%s_%d' % (name, i)) + conv_list.append(conv_split) + conv = tf.concat(values=conv_list, axis=group_axis, name=name + '_concat') + with tf.variable_scope(name + '_bn'): + bn = tf.layers.batch_normalization( + conv, momentum=0.9, epsilon=1e-5, scale=True, training=True) + relu = tf.nn.relu(bn) + return conv, relu + + +def next_conv(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='next_conv2d', + **kwargs): + conv_a, relu_a = conv_relu( + x, + n_in, + n_in / 2, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_a', + **kwargs) + + conv_b, relu_b = group_conv2d_relu( + relu_a, + n_in / 2, + n_out / 2, + ksize=ksize, + stride=stride, + group=group, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + name=name + '_b', + **kwargs) + + conv_c, relu_c = conv_relu( + relu_b, + n_out / 2, + n_out, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_c', + **kwargs) + + return conv_c, relu_c + + +def next_conv_bn(x, + n_in, + n_out, + ksize, + stride=1, + group=4, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + name='next_conv2d', + **kwargs): + conv_a, relu_a = conv_bn_relu( + x, + n_in, + n_in / 2, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_a', + **kwargs) + + conv_b, relu_b = group_conv2d_bn_relu( + relu_a, + n_in / 2, + n_out / 2, + ksize=ksize, + stride=stride, + group=group, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + name=name + '_b', + **kwargs) + + conv_c, relu_c = conv_bn_relu( + relu_b, + n_out / 2, + n_out, + ksize=1, + stride=1, + padding=padding, + weight_init=weight_init, + bias=bias, + relu=relu, + scope=name + '_c', + **kwargs) + + return conv_c, relu_c + + +def conv2d_ori(x, + n_in, + n_out, + ksize, + stride=1, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + + with tf.variable_scope(scope or 'conv2d'): + # convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + y = tf.nn.conv2d(x, kernel, [1, stride, stride, 1], padding=padding) + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + y = tf.nn.bias_add(y, bias) + # apply ReLU + if relu is True: + y = tf.nn.relu(y) + return y + + +def conv_relu(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + return conv2d(*args, **kwargs) + + +def conv_bn_relu(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + conv, relu = conv2d(*args, **kwargs) + with tf.variable_scope(kwargs['scope'] + '_bn'): + bn = tf.layers.batch_normalization( + conv, momentum=0.9, epsilon=1e-5, scale=True, training=True) + bn_relu = tf.nn.relu(bn) + return bn, bn_relu + + +def conv_relu_ori(*args, **kwargs): + kwargs['relu'] = True + if 'scope' not in kwargs: + kwargs['scope'] = 'conv_relu' + return conv2d_ori(*args, **kwargs) + + +def atrous_conv2d(x, + n_in, + n_out, + ksize, + dilation, + padding='SAME', + weight_init=None, + bias=True, + relu=False, + scope=None, + **kwargs): + weight_init = weight_init or FLAGS.weight_init_method + trainable = kwargs.get('trainable', True) + with tf.variable_scope(scope or 'atrous_conv2d'): + # atrous convolution + kernel = _nn_variable( + 'weight', [ksize, ksize, n_in, n_out], + weight_init, + collection='weights' if trainable else None, + **kwargs) + y = tf.nn.atrous_conv2d(x, kernel, dilation, padding=padding) + # add bias + if bias is True: + bias = _nn_variable( + 'bias', [n_out], + 'zero', + collection='biases' if trainable else None, + **kwargs) + y = tf.nn.bias_add(y, bias) + # apply ReLU + if relu is True: + y = tf.nn.relu(y) + return y + + +def avg_pool(x, ksize, stride, padding='SAME', scope=None): + with tf.variable_scope(scope or 'avg_pool'): + y = tf.nn.avg_pool(x, [1, ksize, ksize, 1], [1, stride, stride, 1], + padding) + return y + + +def max_pool(x, ksize, stride, padding='SAME', scope=None): + with tf.variable_scope(scope or 'max_pool'): + y = tf.nn.max_pool(x, [1, ksize, ksize, 1], [1, stride, stride, 1], + padding) + return y + + +def score_loss(gt_labels, match_scores, n_classes): + """ + Classification loss + ARGS + gt_labels: int32 [n] + match_scores: [n, n_classes] + RETURN + loss + """ + embeddings = tf.one_hot(tf.cast(gt_labels, tf.int64), n_classes, 1.0, 0.0) + losses = tf.nn.softmax_cross_entropy_with_logits(match_scores, embeddings) + return tf.reduce_sum(losses) + + +def smooth_l1_loss(offsets, gt_offsets, scope=None): + """ + Smooth L1 loss between offsets and encoded_gt + ARGS + offsets: [m?, 5], predicted offsets for one example + gt_offsets: [m?, 5], correponding groundtruth offsets + RETURN + loss: scalar + """ + with tf.variable_scope(scope or 'smooth_l1_loss'): + gt_offsets = tf.stop_gradient(gt_offsets) + diff = tf.abs(offsets - gt_offsets) + lesser_mask = tf.cast(tf.less(diff, 1.0), tf.float32) + larger_mask = 1.0 - lesser_mask + losses1 = (0.5 * tf.square(diff)) * lesser_mask + losses2 = (diff - 0.5) * larger_mask + return tf.reduce_sum(losses1 + losses2, 1) + + +def polygon_to_rboxe(polygon): + x1 = polygon[0] + y1 = polygon[1] + x2 = polygon[2] + y2 = polygon[3] + x3 = polygon[4] + y3 = polygon[5] + x4 = polygon[6] + y4 = polygon[7] + c_x = (x1 + x2 + x3 + x4) / 4 + c_y = (y1 + y2 + y3 + y4) / 4 + w1 = point_dist(x1, y1, x2, y2) + w2 = point_dist(x3, y3, x4, y4) + h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) + h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) + h = h1 + h2 + w = (w1 + w2) / 2 + theta1 = np.arctan2(y2 - y1, x2 - x1) + theta2 = np.arctan2(y3 - y4, x3 - x4) + theta = (theta1 + theta2) / 2 + return np.array([c_x, c_y, w, h, theta]) + + +def point_dist(x1, y1, x2, y2): + return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) + + +def point_line_dist(px, py, x1, y1, x2, y2): + eps = 1e-6 + dx = x2 - x1 + dy = y2 - y1 + div = np.sqrt(dx * dx + dy * dy) + eps + dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div + return dist + + +def get_combined_polygon(rboxes, resize_size): + image_w = resize_size[1] + image_h = resize_size[0] + img = np.zeros((image_h, image_w, 3), np.uint8) + for i in range(rboxes.shape[0]): + segment = np.reshape( + np.array(utils.rboxes_to_polygons(rboxes)[i, :], np.int32), + (-1, 1, 2)) + cv2.drawContours(img, [segment], 0, (255, 255, 255), -1) + img2gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + ret, thresh = cv2.threshold(img2gray, 127, 255, cv2.THRESH_BINARY) + im2, contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + if len(contours) > 0: + cnt = contours[0] + max_area = cv2.contourArea(cnt) + # get max_area + for cont in contours: + if cv2.contourArea(cont) > max_area: + cnt = cont + max_area = cv2.contourArea(cont) + rect = cv2.minAreaRect(cnt) + combined_polygon = np.array(cv2.boxPoints(rect)).reshape(-1) + else: + combined_polygon = np.array([0, 0, 0, 0, 0, 0, 0, 0]) + + return combined_polygon + + +def combine_segs(segs): + segs = np.asarray(segs) + assert segs.ndim == 2, 'invalid segs ndim' + assert segs.shape[-1] == 6, 'invalid segs shape' + + if len(segs) == 1: + cx = segs[0, 0] + cy = segs[0, 1] + w = segs[0, 2] + h = segs[0, 3] + theta_sin = segs[0, 4] + theta_cos = segs[0, 5] + theta = np.arctan2(theta_sin, theta_cos) + return np.array([cx, cy, w, h, theta]) + + # find the best straight line fitting all center points: y = kx + b + cxs = segs[:, 0] + cys = segs[:, 1] + + theta_coss = segs[:, 4] + theta_sins = segs[:, 5] + + bar_theta = np.arctan2(theta_sins.sum(), theta_coss.sum()) + k = np.tan(bar_theta) + b = np.mean(cys - k * cxs) + + proj_xs = (k * cys + cxs - k * b) / (k**2 + 1) + proj_ys = (k * k * cys + k * cxs + b) / (k**2 + 1) + proj_points = np.stack((proj_xs, proj_ys), -1) + + # find the max distance + max_dist = -1 + idx1 = -1 + idx2 = -1 + + for i in range(len(proj_points)): + point1 = proj_points[i, :] + for j in range(i + 1, len(proj_points)): + point2 = proj_points[j, :] + dist = np.sqrt(np.sum((point1 - point2)**2)) + if dist > max_dist: + idx1 = i + idx2 = j + max_dist = dist + assert idx1 >= 0 and idx2 >= 0 + # the bbox: bcx, bcy, bw, bh, average_theta + seg1 = segs[idx1, :] + seg2 = segs[idx2, :] + bcx, bcy = (seg1[:2] + seg2[:2]) / 2.0 + bh = np.mean(segs[:, 3]) + bw = max_dist + (seg1[2] + seg2[2]) / 2.0 + return bcx, bcy, bw, bh, bar_theta + + +def combine_segments_batch(segments_batch, group_indices_batch, + segment_counts_batch): + batch_size = 1 + combined_rboxes_batch = [] + combined_counts_batch = [] + for image_id in range(batch_size): + group_count = segment_counts_batch[image_id] + segments = segments_batch[image_id, :, :] + group_indices = group_indices_batch[image_id, :] + combined_rboxes = [] + for i in range(group_count): + segments_group = segments[np.where(group_indices == i)[0], :] + if segments_group.shape[0] > 0: + combined_rbox = combine_segs(segments_group) + combined_rboxes.append(combined_rbox) + combined_rboxes_batch.append(combined_rboxes) + combined_counts_batch.append(len(combined_rboxes)) + + max_count = np.max(combined_counts_batch) + for image_id in range(batch_size): + if not combined_counts_batch[image_id] == max_count: + combined_rboxes_pad = (max_count - combined_counts_batch[image_id] + ) * [RBOX_DIM * [0.0]] + combined_rboxes_batch[image_id] = np.vstack( + (combined_rboxes_batch[image_id], + np.array(combined_rboxes_pad))) + + return np.asarray(combined_rboxes_batch, + np.float32), np.asarray(combined_counts_batch, np.int32) + + +# combine_segments rewrite in python version +def combine_segments_python(segments, group_indices, segment_counts): + combined_rboxes, combined_counts = tf.py_func( + combine_segments_batch, [segments, group_indices, segment_counts], + [tf.float32, tf.int32]) + return combined_rboxes, combined_counts + + +# decode_segments_links rewrite in python version +def get_coord(offsets, map_size, offsets_defaults): + if offsets < offsets_defaults[1][0]: + l_idx = 0 + x = offsets % map_size[0][1] + y = offsets // map_size[0][1] + elif offsets < offsets_defaults[2][0]: + l_idx = 1 + x = (offsets - offsets_defaults[1][0]) % map_size[1][1] + y = (offsets - offsets_defaults[1][0]) // map_size[1][1] + elif offsets < offsets_defaults[3][0]: + l_idx = 2 + x = (offsets - offsets_defaults[2][0]) % map_size[2][1] + y = (offsets - offsets_defaults[2][0]) // map_size[2][1] + elif offsets < offsets_defaults[4][0]: + l_idx = 3 + x = (offsets - offsets_defaults[3][0]) % map_size[3][1] + y = (offsets - offsets_defaults[3][0]) // map_size[3][1] + elif offsets < offsets_defaults[5][0]: + l_idx = 4 + x = (offsets - offsets_defaults[4][0]) % map_size[4][1] + y = (offsets - offsets_defaults[4][0]) // map_size[4][1] + else: + l_idx = 5 + x = (offsets - offsets_defaults[5][0]) % map_size[5][1] + y = (offsets - offsets_defaults[5][0]) // map_size[5][1] + + return l_idx, x, y + + +def get_coord_link(offsets, map_size, offsets_defaults): + if offsets < offsets_defaults[1][1]: + offsets_node = offsets // N_LOCAL_LINKS + link_idx = offsets % N_LOCAL_LINKS + else: + offsets_node = (offsets - offsets_defaults[1][1]) // ( + N_LOCAL_LINKS + N_CROSS_LINKS) + offsets_defaults[1][0] + link_idx = (offsets - offsets_defaults[1][1]) % ( + N_LOCAL_LINKS + N_CROSS_LINKS) + l_idx, x, y = get_coord(offsets_node, map_size, offsets_defaults) + return l_idx, x, y, link_idx + + +def is_valid_coord(l_idx, x, y, map_size): + w = map_size[l_idx][1] + h = map_size[l_idx][0] + return x >= 0 and x < w and y >= 0 and y < h + + +def get_neighbours(l_idx, x, y, map_size, offsets_defaults): + if l_idx == 0: + coord = [(0, x - 1, y - 1), (0, x, y - 1), (0, x + 1, y - 1), + (0, x - 1, y), (0, x + 1, y), (0, x - 1, y + 1), + (0, x, y + 1), (0, x + 1, y + 1)] + else: + coord = [(l_idx, x - 1, y - 1), + (l_idx, x, y - 1), (l_idx, x + 1, y - 1), (l_idx, x - 1, y), + (l_idx, x + 1, y), (l_idx, x - 1, y + 1), (l_idx, x, y + 1), + (l_idx, x + 1, y + 1), (l_idx - 1, 2 * x, 2 * y), + (l_idx - 1, 2 * x + 1, 2 * y), (l_idx - 1, 2 * x, 2 * y + 1), + (l_idx - 1, 2 * x + 1, 2 * y + 1)] + neighbours_offsets = [] + link_idx = 0 + for nl_idx, nx, ny in coord: + if is_valid_coord(nl_idx, nx, ny, map_size): + neighbours_offset_node = offsets_defaults[nl_idx][ + 0] + map_size[nl_idx][1] * ny + nx + if l_idx == 0: + neighbours_offset_link = offsets_defaults[l_idx][1] + ( + map_size[l_idx][1] * y + x) * N_LOCAL_LINKS + link_idx + else: + off_tmp = (map_size[l_idx][1] * y + x) * ( + N_LOCAL_LINKS + N_CROSS_LINKS) + neighbours_offset_link = offsets_defaults[l_idx][ + 1] + off_tmp + link_idx + neighbours_offsets.append( + [neighbours_offset_node, neighbours_offset_link, link_idx]) + link_idx += 1 + # [node_offsets, link_offsets, link_idx(0-7/11)] + return neighbours_offsets + + +def decode_segments_links_python(image_size, all_nodes, all_links, all_reg, + anchor_sizes): + batch_size = 1 # FLAGS.test_batch_size + # offsets = 12285 #768 + all_nodes_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_SEG_CLASSES]) for o in all_nodes], + axis=1) + all_links_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_LNK_CLASSES]) for o in all_links], + axis=1) + all_reg_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, OFFSET_DIM]) for o in all_reg], axis=1) + segments, group_indices, segment_counts, group_indices_all = tf.py_func( + decode_batch, [ + all_nodes_flat, all_links_flat, all_reg_flat, image_size, + tf.constant(anchor_sizes) + ], [tf.float32, tf.int32, tf.int32, tf.int32]) + return segments, group_indices, segment_counts, group_indices_all + + +def decode_segments_links_train(image_size, all_nodes, all_links, all_reg, + anchor_sizes): + batch_size = FLAGS.train_batch_size + # offsets = 12285 #768 + all_nodes_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_SEG_CLASSES]) for o in all_nodes], + axis=1) + all_links_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, N_LNK_CLASSES]) for o in all_links], + axis=1) + all_reg_flat = tf.concat( + [tf.reshape(o, [batch_size, -1, OFFSET_DIM]) for o in all_reg], axis=1) + segments, group_indices, segment_counts, group_indices_all = tf.py_func( + decode_batch, [ + all_nodes_flat, all_links_flat, all_reg_flat, image_size, + tf.constant(anchor_sizes) + ], [tf.float32, tf.int32, tf.int32, tf.int32]) + return segments, group_indices, segment_counts, group_indices_all + + +def decode_batch(all_nodes, all_links, all_reg, image_size, anchor_sizes): + batch_size = all_nodes.shape[0] + batch_segments = [] + batch_group_indices = [] + batch_segments_counts = [] + batch_group_indices_all = [] + for image_id in range(batch_size): + image_node_scores = all_nodes[image_id, :, :] + image_link_scores = all_links[image_id, :, :] + image_reg = all_reg[image_id, :, :] + image_segments, image_group_indices, image_segments_counts, image_group_indices_all = decode_image( + image_node_scores, image_link_scores, image_reg, image_size, + anchor_sizes) + batch_segments.append(image_segments) + batch_group_indices.append(image_group_indices) + batch_segments_counts.append(image_segments_counts) + batch_group_indices_all.append(image_group_indices_all) + max_count = np.max(batch_segments_counts) + for image_id in range(batch_size): + if not batch_segments_counts[image_id] == max_count: + batch_segments_pad = (max_count - batch_segments_counts[image_id] + ) * [OFFSET_DIM * [0.0]] + batch_segments[image_id] = np.vstack( + (batch_segments[image_id], np.array(batch_segments_pad))) + batch_group_indices[image_id] = np.hstack( + (batch_group_indices[image_id], + np.array( + (max_count - batch_segments_counts[image_id]) * [-1]))) + return np.asarray(batch_segments, np.float32), np.asarray( + batch_group_indices, + np.int32), np.asarray(batch_segments_counts, + np.int32), np.asarray(batch_group_indices_all, + np.int32) + + +def decode_image(image_node_scores, image_link_scores, image_reg, image_size, + anchor_sizes): + map_size = [] + offsets_defaults = [] + offsets_default_node = 0 + offsets_default_link = 0 + for i in range(N_DET_LAYERS): + offsets_defaults.append([offsets_default_node, offsets_default_link]) + map_size.append(image_size // (2**(2 + i))) + offsets_default_node += map_size[i][0] * map_size[i][1] + if i == 0: + offsets_default_link += map_size[i][0] * map_size[i][ + 1] * N_LOCAL_LINKS + else: + offsets_default_link += map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + N_CROSS_LINKS) + + image_group_indices_all = decode_image_by_join(image_node_scores, + image_link_scores, + FLAGS.node_threshold, + FLAGS.link_threshold, + map_size, offsets_defaults) + image_group_indices_all -= 1 + image_group_indices = image_group_indices_all[np.where( + image_group_indices_all >= 0)[0]] + image_segments_counts = len(image_group_indices) + # convert image_reg to segments with scores(OFFSET_DIM+1) + image_segments = np.zeros((image_segments_counts, OFFSET_DIM), + dtype=np.float32) + for i, offsets in enumerate(np.where(image_group_indices_all >= 0)[0]): + encoded_cx = image_reg[offsets, 0] + encoded_cy = image_reg[offsets, 1] + encoded_width = image_reg[offsets, 2] + encoded_height = image_reg[offsets, 3] + encoded_theta_cos = image_reg[offsets, 4] + encoded_theta_sin = image_reg[offsets, 5] + + l_idx, x, y = get_coord(offsets, map_size, offsets_defaults) + rs = anchor_sizes[l_idx] + eps = 1e-6 + image_segments[i, 0] = encoded_cx * rs + (2**(2 + l_idx)) * (x + 0.5) + image_segments[i, 1] = encoded_cy * rs + (2**(2 + l_idx)) * (y + 0.5) + image_segments[i, 2] = np.exp(encoded_width) * rs - eps + image_segments[i, 3] = np.exp(encoded_height) * rs - eps + image_segments[i, 4] = encoded_theta_cos + image_segments[i, 5] = encoded_theta_sin + + return image_segments, image_group_indices, image_segments_counts, image_group_indices_all + + +def decode_image_by_join(node_scores, link_scores, node_threshold, + link_threshold, map_size, offsets_defaults): + node_mask = node_scores[:, POS_LABEL] >= node_threshold + link_mask = link_scores[:, POS_LABEL] >= link_threshold + group_mask = np.zeros_like(node_mask, np.int32) - 1 + offsets_pos = np.where(node_mask == 1)[0] + + def find_parent(point): + return group_mask[point] + + def set_parent(point, parent): + group_mask[point] = parent + + def is_root(point): + return find_parent(point) == -1 + + def find_root(point): + root = point + update_parent = False + while not is_root(root): + root = find_parent(root) + update_parent = True + + # for acceleration of find_root + if update_parent: + set_parent(point, root) + + return root + + def join(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + if root1 != root2: + set_parent(root1, root2) + + def get_all(): + root_map = {} + + def get_index(root): + if root not in root_map: + root_map[root] = len(root_map) + 1 + return root_map[root] + + mask = np.zeros_like(node_mask, dtype=np.int32) + for i, point in enumerate(offsets_pos): + point_root = find_root(point) + bbox_idx = get_index(point_root) + mask[point] = bbox_idx + return mask + + # join by link + pos_link = 0 + for i, offsets in enumerate(offsets_pos): + l_idx, x, y = get_coord(offsets, map_size, offsets_defaults) + neighbours = get_neighbours(l_idx, x, y, map_size, offsets_defaults) + for n_idx, noffsets in enumerate(neighbours): + link_value = link_mask[noffsets[1]] + node_cls = node_mask[noffsets[0]] + if link_value and node_cls: + pos_link += 1 + join(offsets, noffsets[0]) + # print(pos_link) + mask = get_all() + return mask + + +def get_link_mask(node_mask, offsets_defaults, link_max): + link_mask = np.zeros_like(link_max) + link_mask[0:offsets_defaults[1][1]] = np.tile( + node_mask[0:offsets_defaults[1][0]], + (N_LOCAL_LINKS, 1)).transpose().reshape(offsets_defaults[1][1]) + link_mask[offsets_defaults[1][1]:offsets_defaults[2][1]] = np.tile( + node_mask[offsets_defaults[1][0]:offsets_defaults[2][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[2][1] - offsets_defaults[1][1])) + link_mask[offsets_defaults[2][1]:offsets_defaults[3][1]] = np.tile( + node_mask[offsets_defaults[2][0]:offsets_defaults[3][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[3][1] - offsets_defaults[2][1])) + link_mask[offsets_defaults[3][1]:offsets_defaults[4][1]] = np.tile( + node_mask[offsets_defaults[3][0]:offsets_defaults[4][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[4][1] - offsets_defaults[3][1])) + link_mask[offsets_defaults[4][1]:offsets_defaults[5][1]] = np.tile( + node_mask[offsets_defaults[4][0]:offsets_defaults[5][0]], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (offsets_defaults[5][1] - offsets_defaults[4][1])) + link_mask[offsets_defaults[5][1]:] = np.tile( + node_mask[offsets_defaults[5][0]:], + (N_LOCAL_LINKS + N_CROSS_LINKS, 1)).transpose().reshape( + (len(link_mask) - offsets_defaults[5][1])) + + return link_mask + + +def get_link8(link_scores_raw, map_size): + # link[i-1] -local- start -16- end -cross- link[i] + link8_mask = np.zeros((link_scores_raw.shape[0])) + for i in range(N_DET_LAYERS): + if i == 0: + offsets_start = map_size[i][0] * map_size[i][1] * N_LOCAL_LINKS + offsets_end = map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + offsets_link = map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + link8_mask[:offsets_start] = 1 + else: + offsets_start = offsets_link + map_size[i][0] * map_size[i][ + 1] * N_LOCAL_LINKS + offsets_end = offsets_link + map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16) + offsets_link_pre = offsets_link + offsets_link += map_size[i][0] * map_size[i][1] * ( + N_LOCAL_LINKS + 16 + N_CROSS_LINKS) + link8_mask[offsets_link_pre:offsets_start] = 1 + link8_mask[offsets_end:offsets_link] = 1 + return link_scores_raw[np.where(link8_mask > 0)[0], :] + + +def decode_image_by_mutex(node_scores, link_scores, node_threshold, + link_threshold, map_size, offsets_defaults): + node_mask = node_scores[:, POS_LABEL] >= node_threshold + link_pos = link_scores[:, POS_LABEL] + link_mut = link_scores[:, MUT_LABEL] + link_max = np.max(np.vstack((link_pos, link_mut)), axis=0) + + offsets_pos_list = np.where(node_mask == 1)[0].tolist() + + link_mask_th = link_max >= link_threshold + link_mask = get_link_mask(node_mask, offsets_defaults, link_max) + offsets_link_max = np.argsort(-(link_max * link_mask * link_mask_th)) + offsets_link_max = offsets_link_max[:len(offsets_pos_list) * 8] + + group_mask = np.zeros_like(node_mask, dtype=np.int32) - 1 + mutex_mask = len(node_mask) * [[]] + + def find_parent(point): + return group_mask[point] + + def set_parent(point, parent): + group_mask[point] = parent + + def set_mutex_constraint(point, mutex_point_list): + mutex_mask[point] = mutex_point_list + + def find_mutex_constraint(point): + mutex_point_list = mutex_mask[point] + # update mutex_point_list + mutex_point_list_new = [] + if not mutex_point_list == []: + for mutex_point in mutex_point_list: + if not is_root(mutex_point): + mutex_point = find_root(mutex_point) + if mutex_point not in mutex_point_list_new: + mutex_point_list_new.append(mutex_point) + set_mutex_constraint(point, mutex_point_list_new) + return mutex_point_list_new + + def combine_mutex_constraint(point, parent): + mutex_point_list = find_mutex_constraint(point) + mutex_parent_list = find_mutex_constraint(parent) + for mutex_point in mutex_point_list: + if not is_root(mutex_point): + mutex_point = find_root(mutex_point) + if mutex_point not in mutex_parent_list: + mutex_parent_list.append(mutex_point) + set_mutex_constraint(parent, mutex_parent_list) + + def add_mutex_constraint(p1, p2): + mutex_point_list1 = find_mutex_constraint(p1) + mutex_point_list2 = find_mutex_constraint(p2) + + if p1 not in mutex_point_list2: + mutex_point_list2.append(p1) + if p2 not in mutex_point_list1: + mutex_point_list1.append(p2) + set_mutex_constraint(p1, mutex_point_list1) + set_mutex_constraint(p2, mutex_point_list2) + + def is_root(point): + return find_parent(point) == -1 + + def find_root(point): + root = point + update_parent = False + while not is_root(root): + root = find_parent(root) + update_parent = True + + # for acceleration of find_root + if update_parent: + set_parent(point, root) + + return root + + def join(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + if root1 != root2 and (root1 not in find_mutex_constraint(root2)): + set_parent(root1, root2) + combine_mutex_constraint(root1, root2) + + def disjoin(p1, p2): + root1 = find_root(p1) + root2 = find_root(p2) + + if root1 != root2: + add_mutex_constraint(root1, root2) + + def get_all(): + root_map = {} + + def get_index(root): + if root not in root_map: + root_map[root] = len(root_map) + 1 + return root_map[root] + + mask = np.zeros_like(node_mask, dtype=np.int32) + for _, point in enumerate(offsets_pos_list): + point_root = find_root(point) + bbox_idx = get_index(point_root) + mask[point] = bbox_idx + return mask + + # join by link + pos_link = 0 + mut_link = 0 + for _, offsets_link in enumerate(offsets_link_max): + l_idx, x, y, link_idx = get_coord_link(offsets_link, map_size, + offsets_defaults) + offsets = offsets_defaults[l_idx][0] + map_size[l_idx][1] * y + x + if offsets in offsets_pos_list: + neighbours = get_neighbours(l_idx, x, y, map_size, + offsets_defaults) + if not len(np.where(np.array(neighbours)[:, + 2] == link_idx)[0]) == 0: + noffsets = neighbours[np.where( + np.array(neighbours)[:, 2] == link_idx)[0][0]] + link_pos_value = link_pos[noffsets[1]] + link_mut_value = link_mut[noffsets[1]] + node_cls = node_mask[noffsets[0]] + if node_cls and (link_pos_value > link_mut_value): + pos_link += 1 + join(offsets, noffsets[0]) + elif node_cls and (link_pos_value < link_mut_value): + mut_link += 1 + disjoin(offsets, noffsets[0]) + + mask = get_all() + return mask diff --git a/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py new file mode 100644 index 00000000..85f9faca --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/resnet18_v1.py @@ -0,0 +1,450 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains definitions for the original form of Residual Networks. +The 'v1' residual networks (ResNets) implemented in this module were proposed +by: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +Other variants were introduced in: +[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 +The networks defined in this module utilize the bottleneck building block of +[1] with projection shortcuts only for increasing depths. They employ batch +normalization *after* every weight layer. This is the architecture used by +MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and +ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' +architecture and the alternative 'v2' architecture of [2] which uses batch +normalization *before* every weight layer in the so-called full pre-activation +units. +Typical use: + from tensorflow.contrib.slim.nets import resnet_v1 +ResNet-101 for image classification into 1000 classes: + # inputs has shape [batch, 224, 224, 3] + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) +ResNet-101 for semantic segmentation into 21 classes: + # inputs has shape [batch, 513, 513, 3] + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, end_points = resnet_v1.resnet_v1_101(inputs, + 21, + is_training=False, + global_pool=False, + output_stride=16) +""" +import tensorflow as tf + +from . import resnet_utils + +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +resnet_arg_scope = resnet_utils.resnet_arg_scope + + +@slim.add_arg_scope +def basicblock(inputs, + depth, + depth_bottleneck, + stride, + rate=1, + outputs_collections=None, + scope=None): + """Bottleneck residual unit variant with BN after convolutions. + This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for + its definition. Note that we use here the bottleneck variant which has an + extra bottleneck layer. + When putting together two consecutive ResNet blocks that use this unit, one + should use stride = 2 in the last unit of the first block. + Args: + inputs: A tensor of size [batch, height, width, channels]. + depth: The depth of the ResNet unit output. + depth_bottleneck: The depth of the bottleneck layers. + stride: The ResNet unit's stride. Determines the amount of downsampling of + the units output compared to its input. + rate: An integer, rate for atrous convolution. + outputs_collections: Collection to add the ResNet unit output. + scope: Optional variable_scope. + Returns: + The ResNet unit's output. + """ + with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: + depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) + if depth == depth_in: + shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') + else: + shortcut = slim.conv2d( + inputs, + depth, [1, 1], + stride=stride, + activation_fn=None, + scope='shortcut') + + residual = resnet_utils.conv2d_same( + inputs, depth, 3, stride, rate=rate, scope='conv1') + residual = resnet_utils.conv2d_same( + residual, depth, 3, 1, rate=rate, scope='conv2') + + output = tf.nn.relu(residual + shortcut) + + return slim.utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, output) + + +@slim.add_arg_scope +def bottleneck(inputs, + depth, + depth_bottleneck, + stride, + rate=1, + outputs_collections=None, + scope=None): + """Bottleneck residual unit variant with BN after convolutions. + This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for + its definition. Note that we use here the bottleneck variant which has an + extra bottleneck layer. + When putting together two consecutive ResNet blocks that use this unit, one + should use stride = 2 in the last unit of the first block. + Args: + inputs: A tensor of size [batch, height, width, channels]. + depth: The depth of the ResNet unit output. + depth_bottleneck: The depth of the bottleneck layers. + stride: The ResNet unit's stride. Determines the amount of downsampling of + the units output compared to its input. + rate: An integer, rate for atrous convolution. + outputs_collections: Collection to add the ResNet unit output. + scope: Optional variable_scope. + Returns: + The ResNet unit's output. + """ + with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: + depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) + if depth == depth_in: + shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') + else: + shortcut = slim.conv2d( + inputs, + depth, [1, 1], + stride=stride, + activation_fn=None, + scope='shortcut') + + residual = slim.conv2d( + inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1') + residual = resnet_utils.conv2d_same( + residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2') + residual = slim.conv2d( + residual, + depth, [1, 1], + stride=1, + activation_fn=None, + scope='conv3') + + output = tf.nn.relu(shortcut + residual) + + return slim.utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, output) + + +def resnet_v1(inputs, + blocks, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + include_root_block=True, + spatial_squeeze=True, + reuse=None, + scope=None): + """Generator for v1 ResNet models. + This function generates a family of ResNet v1 models. See the resnet_v1_*() + methods for specific model instantiations, obtained by selecting different + block instantiations that produce ResNets of various depths. + Training for image classification on Imagenet is usually done with [224, 224] + inputs, resulting in [7, 7] feature maps at the output of the last ResNet + block for the ResNets defined in [1] that have nominal stride equal to 32. + However, for dense prediction tasks we advise that one uses inputs with + spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In + this case the feature maps at the ResNet output will have spatial shape + [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] + and corners exactly aligned with the input image corners, which greatly + facilitates alignment of the features to the image. Using as input [225, 225] + images results in [8, 8] feature maps at the output of the last ResNet block. + For dense prediction tasks, the ResNet needs to run in fully-convolutional + (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all + have nominal stride equal to 32 and a good choice in FCN mode is to use + output_stride=16 in order to increase the density of the computed features at + small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. + Args: + inputs: A tensor of size [batch, height_in, width_in, channels]. + blocks: A list of length equal to the number of ResNet blocks. Each element + is a resnet_utils.Block object describing the units in the block. + num_classes: Number of predicted classes for classification tasks. If None + we return the features before the logit layer. + is_training: whether is training or not. + global_pool: If True, we perform global average pooling before computing the + logits. Set to True for image classification, False for dense prediction. + output_stride: If None, then the output will be computed at the nominal + network stride. If output_stride is not None, it specifies the requested + ratio of input to output spatial resolution. + include_root_block: If True, include the initial convolution followed by + max-pooling, if False excludes it. + spatial_squeeze: if True, logits is of shape [B, C], if false logits is + of shape [B, 1, 1, C], where B is batch_size and C is number of classes. + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + scope: Optional variable_scope. + Returns: + net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. + If global_pool is False, then height_out and width_out are reduced by a + factor of output_stride compared to the respective height_in and width_in, + else both height_out and width_out equal one. If num_classes is None, then + net is the output of the last ResNet block, potentially after global + average pooling. If num_classes is not None, net contains the pre-softmax + activations. + end_points: A dictionary from components of the network to the corresponding + activation. + Raises: + ValueError: If the target output_stride is not valid. + """ + with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: + end_points_collection = sc.name + '_end_points' + with slim.arg_scope( + [slim.conv2d, bottleneck, resnet_utils.stack_blocks_dense], + outputs_collections=end_points_collection): + with slim.arg_scope([slim.batch_norm], is_training=is_training): + net = inputs + if include_root_block: + if output_stride is not None: + if output_stride % 4 != 0: + raise ValueError( + 'The output_stride needs to be a multiple of 4.' + ) + output_stride /= 4 + net = resnet_utils.conv2d_same( + net, 64, 7, stride=2, scope='conv1') + net = tf.pad(net, [[0, 0], [1, 1], [1, 1], [0, 0]]) + net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') + + net = slim.utils.collect_named_outputs( + end_points_collection, 'pool2', net) + + net = resnet_utils.stack_blocks_dense(net, blocks, + output_stride) + + end_points = slim.utils.convert_collection_to_dict( + end_points_collection) + + end_points['pool1'] = end_points['resnet_v1_18/block2/unit_2'] + end_points['pool2'] = end_points['resnet_v1_18/block3/unit_2'] + end_points['pool3'] = end_points['resnet_v1_18/block4/unit_2'] + end_points['pool4'] = end_points['resnet_v1_18/block5/unit_2'] + end_points['pool5'] = end_points['resnet_v1_18/block6/unit_2'] + end_points['pool6'] = net + + return net, end_points + + +resnet_v1.default_image_size = 224 + + +def resnet_v1_18(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_18'): + """ResNet-18 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', basicblock, + [(64, 64, 1)] + [(64, 64, 1)]), + resnet_utils.Block('block2', basicblock, + [(128, 128, 1)] + [(128, 128, 1)]), + resnet_utils.Block('block3', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block4', basicblock, + [(512, 512, 2)] + [(512, 512, 1)]), + resnet_utils.Block('block5', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block6', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + resnet_utils.Block('block7', basicblock, + [(256, 256, 2)] + [(256, 256, 1)]), + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_18.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_50(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_50'): + """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 3 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, + [(2048, 512, 1)] * 3 + [(2048, 512, 2)]), + resnet_utils.Block('block5', bottleneck, + [(1024, 256, 1)] * 2 + [(1024, 256, 2)]), + resnet_utils.Block('block6', bottleneck, [(1024, 256, 1)] * 2), + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_50.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_101(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_101'): + """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 3 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_101.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_152(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_152'): + """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 7 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_152.default_image_size = resnet_v1.default_image_size + + +def resnet_v1_200(inputs, + num_classes=None, + is_training=True, + global_pool=True, + output_stride=None, + spatial_squeeze=True, + reuse=None, + scope='resnet_v1_200'): + """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" + blocks = [ + resnet_utils.Block('block1', bottleneck, + [(256, 64, 1)] * 2 + [(256, 64, 2)]), + resnet_utils.Block('block2', bottleneck, + [(512, 128, 1)] * 23 + [(512, 128, 2)]), + resnet_utils.Block('block3', bottleneck, + [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), + resnet_utils.Block('block4', bottleneck, [(2048, 512, 1)] * 3) + ] + return resnet_v1( + inputs, + blocks, + num_classes, + is_training, + global_pool=global_pool, + output_stride=output_stride, + include_root_block=True, + spatial_squeeze=spatial_squeeze, + reuse=reuse, + scope=scope) + + +resnet_v1_200.default_image_size = resnet_v1.default_image_size + +if __name__ == '__main__': + input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') + with slim.arg_scope(resnet_arg_scope()) as sc: + logits = resnet_v1_50(input) diff --git a/modelscope/pipelines/cv/ocr_utils/resnet_utils.py b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py new file mode 100644 index 00000000..2ccbd038 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/resnet_utils.py @@ -0,0 +1,249 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains building blocks for various versions of Residual Networks. +Residual networks (ResNets) were proposed in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 +More variants were introduced in: + Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 +We can obtain different ResNet variants by changing the network depth, width, +and form of residual unit. This module implements the infrastructure for +building them. Concrete ResNet units and full ResNet networks are implemented in +the accompanying resnet_v1.py and resnet_v2.py modules. +Compared to https://github.com/KaimingHe/deep-residual-networks, in the current +implementation we subsample the output activations in the last residual unit of +each block, instead of subsampling the input activations in the first residual +unit of each block. The two implementations give identical results but our +implementation is more memory efficient. +""" + +import collections + +import tensorflow as tf + +if tf.__version__ >= '2.0': + import tf_slim as slim +else: + from tensorflow.contrib import slim + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): + """A named tuple describing a ResNet block. + Its parts are: + scope: The scope of the `Block`. + unit_fn: The ResNet unit function which takes as input a `Tensor` and + returns another `Tensor` with the output of the ResNet unit. + args: A list of length equal to the number of units in the `Block`. The list + contains one (depth, depth_bottleneck, stride) tuple for each unit in the + block to serve as argument to unit_fn. + """ + + +def subsample(inputs, factor, scope=None): + """Subsamples the input along the spatial dimensions. + Args: + inputs: A `Tensor` of size [batch, height_in, width_in, channels]. + factor: The subsampling factor. + scope: Optional variable_scope. + Returns: + output: A `Tensor` of size [batch, height_out, width_out, channels] with the + input, either intact (if factor == 1) or subsampled (if factor > 1). + """ + if factor == 1: + return inputs + else: + return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) + + +def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): + """Strided 2-D convolution with 'SAME' padding. + When stride > 1, then we do explicit zero-padding, followed by conv2d with + 'VALID' padding. + Note that + net = conv2d_same(inputs, num_outputs, 3, stride=stride) + is equivalent to + net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') + net = subsample(net, factor=stride) + whereas + net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') + is different when the input's height or width is even, which is why we add the + current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). + Args: + inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. + num_outputs: An integer, the number of output filters. + kernel_size: An int with the kernel_size of the filters. + stride: An integer, the output stride. + rate: An integer, rate for atrous convolution. + scope: Scope. + Returns: + output: A 4-D tensor of size [batch, height_out, width_out, channels] with + the convolution output. + """ + if stride == 1: + return slim.conv2d( + inputs, + num_outputs, + kernel_size, + stride=1, + rate=rate, + padding='SAME', + scope=scope) + else: + kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + inputs = tf.pad( + inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) + return slim.conv2d( + inputs, + num_outputs, + kernel_size, + stride=stride, + rate=rate, + padding='VALID', + scope=scope) + + +@slim.add_arg_scope +def stack_blocks_dense(net, + blocks, + output_stride=None, + outputs_collections=None): + """Stacks ResNet `Blocks` and controls output feature density. + First, this function creates scopes for the ResNet in the form of + 'block_name/unit_1', 'block_name/unit_2', etc. + Second, this function allows the user to explicitly control the ResNet + output_stride, which is the ratio of the input to output spatial resolution. + This is useful for dense prediction tasks such as semantic segmentation or + object detection. + Most ResNets consist of 4 ResNet blocks and subsample the activations by a + factor of 2 when transitioning between consecutive ResNet blocks. This results + to a nominal ResNet output_stride equal to 8. If we set the output_stride to + half the nominal network stride (e.g., output_stride=4), then we compute + responses twice. + Control of the output feature density is implemented by atrous convolution. + Args: + net: A `Tensor` of size [batch, height, width, channels]. + blocks: A list of length equal to the number of ResNet `Blocks`. Each + element is a ResNet `Block` object describing the units in the `Block`. + output_stride: If `None`, then the output will be computed at the nominal + network stride. If output_stride is not `None`, it specifies the requested + ratio of input to output spatial resolution, which needs to be equal to + the product of unit strides from the start up to some level of the ResNet. + For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, + then valid values for the output_stride are 1, 2, 6, 24 or None (which + is equivalent to output_stride=24). + outputs_collections: Collection to add the ResNet block outputs. + Returns: + net: Output tensor with stride equal to the specified output_stride. + Raises: + ValueError: If the target output_stride is not valid. + """ + # The current_stride variable keeps track of the effective stride of the + # activations. This allows us to invoke atrous convolution whenever applying + # the next residual unit would result in the activations having stride larger + # than the target output_stride. + current_stride = 1 + + # The atrous convolution rate parameter. + rate = 1 + + for block in blocks: + with tf.variable_scope(block.scope, 'block', [net]): + for i, unit in enumerate(block.args): + if output_stride is not None and current_stride > output_stride: + raise ValueError( + 'The target output_stride cannot be reached.') + + with tf.variable_scope( + 'unit_%d' % (i + 1), values=[net]) as sc: + unit_depth, unit_depth_bottleneck, unit_stride = unit + # If we have reached the target output_stride, then we need to employ + # atrous convolution with stride=1 and multiply the atrous rate by the + # current unit's stride for use in subsequent layers. + if output_stride is not None and current_stride == output_stride: + net = block.unit_fn( + net, + depth=unit_depth, + depth_bottleneck=unit_depth_bottleneck, + stride=1, + rate=rate) + rate *= unit_stride + + else: + net = block.unit_fn( + net, + depth=unit_depth, + depth_bottleneck=unit_depth_bottleneck, + stride=unit_stride, + rate=1) + current_stride *= unit_stride + net = slim.utils.collect_named_outputs( + outputs_collections, sc.name, net) + + if output_stride is not None and current_stride != output_stride: + raise ValueError('The target output_stride cannot be reached.') + + return net + + +def resnet_arg_scope(weight_decay=0.0001, + batch_norm_decay=0.997, + batch_norm_epsilon=1e-5, + batch_norm_scale=True): + """Defines the default ResNet arg scope. + TODO(gpapan): The batch-normalization related default values above are + appropriate for use in conjunction with the reference ResNet models + released at https://github.com/KaimingHe/deep-residual-networks. When + training ResNets from scratch, they might need to be tuned. + Args: + weight_decay: The weight decay to use for regularizing the model. + batch_norm_decay: The moving average decay when estimating layer activation + statistics in batch normalization. + batch_norm_epsilon: Small constant to prevent division by zero when + normalizing activations by their variance in batch normalization. + batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the + activations in the batch normalization layer. + Returns: + An `arg_scope` to use for the resnet models. + """ + batch_norm_params = { + 'decay': batch_norm_decay, + 'epsilon': batch_norm_epsilon, + 'scale': batch_norm_scale, + 'updates_collections': tf.GraphKeys.UPDATE_OPS, + } + + with slim.arg_scope( + [slim.conv2d], + weights_regularizer=slim.l2_regularizer(weight_decay), + weights_initializer=slim.variance_scaling_initializer(), + activation_fn=tf.nn.relu, + normalizer_fn=slim.batch_norm, + normalizer_params=batch_norm_params): + with slim.arg_scope([slim.batch_norm], **batch_norm_params): + # The following implies padding='SAME' for pool1, which makes feature + # alignment easier for dense prediction tasks. This is also used in + # https://github.com/facebook/fb.resnet.torch. However the accompanying + # code of 'Deep Residual Learning for Image Recognition' uses + # padding='VALID' for pool1. You can switch to that choice by setting + # slim.arg_scope([slim.max_pool2d], padding='VALID'). + with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: + return arg_sc diff --git a/modelscope/pipelines/cv/ocr_utils/utils.py b/modelscope/pipelines/cv/ocr_utils/utils.py new file mode 100644 index 00000000..1d0fb297 --- /dev/null +++ b/modelscope/pipelines/cv/ocr_utils/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import cv2 +import numpy as np + + +def rboxes_to_polygons(rboxes): + """ + Convert rboxes to polygons + ARGS + `rboxes`: [n, 5] + RETURN + `polygons`: [n, 8] + """ + + theta = rboxes[:, 4:5] + cxcy = rboxes[:, :2] + half_w = rboxes[:, 2:3] / 2. + half_h = rboxes[:, 3:4] / 2. + v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w]) + v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h]) + p1 = cxcy - v1 - v2 + p2 = cxcy + v1 - v2 + p3 = cxcy + v1 + v2 + p4 = cxcy - v1 + v2 + polygons = np.hstack([p1, p2, p3, p4]) + return polygons + + +def cal_width(box): + pd1 = point_dist(box[0], box[1], box[2], box[3]) + pd2 = point_dist(box[4], box[5], box[6], box[7]) + return (pd1 + pd2) / 2 + + +def point_dist(x1, y1, x2, y2): + return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) + + +def draw_polygons(img, polygons): + for p in polygons.tolist(): + p = [int(o) for o in p] + cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1) + cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1) + cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1) + cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1) + return img + + +def nms_python(boxes): + boxes = sorted(boxes, key=lambda x: -x[8]) + nms_flag = [True] * len(boxes) + for i, a in enumerate(boxes): + if not nms_flag[i]: + continue + else: + for j, b in enumerate(boxes): + if not j > i: + continue + if not nms_flag[j]: + continue + score_a = a[8] + score_b = b[8] + rbox_a = polygon2rbox(a[:8]) + rbox_b = polygon2rbox(b[:8]) + if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox( + rbox_b[:2], rbox_a): + if score_a > score_b: + nms_flag[j] = False + boxes_nms = [] + for i, box in enumerate(boxes): + if nms_flag[i]: + boxes_nms.append(box) + return boxes_nms + + +def point_in_rbox(c, rbox): + cx0, cy0 = c[0], c[1] + cx1, cy1 = rbox[0], rbox[1] + w, h = rbox[2], rbox[3] + theta = rbox[4] + dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta)) + dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta)) + return ((dist_x < w / 2.0) and (dist_y < h / 2.0)) + + +def polygon2rbox(polygon): + x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6] + y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7] + c_x = (x1 + x2 + x3 + x4) / 4 + c_y = (y1 + y2 + y3 + y4) / 4 + w1 = point_dist(x1, y1, x2, y2) + w2 = point_dist(x3, y3, x4, y4) + h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) + h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) + h = h1 + h2 + w = (w1 + w2) / 2 + theta1 = np.arctan2(y2 - y1, x2 - x1) + theta2 = np.arctan2(y3 - y4, x3 - x4) + theta = (theta1 + theta2) / 2.0 + return [c_x, c_y, w, h, theta] + + +def point_line_dist(px, py, x1, y1, x2, y2): + eps = 1e-6 + dx = x2 - x1 + dy = y2 - y1 + div = np.sqrt(dx * dx + dy * dy) + eps + dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div + return dist diff --git a/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py new file mode 100644 index 00000000..0164a998 --- /dev/null +++ b/modelscope/pipelines/cv/product_retrieval_embedding_pipeline.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.product_retrieval_embedding, + module_name=Pipelines.product_retrieval_embedding) +class ProductRetrievalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + preprocess the input image to cv2-bgr style + """ + img = LoadImage.convert_to_ndarray(input) # array with rgb + img = np.ascontiguousarray(img[:, :, ::-1]) # array with bgr + result = {'img': img} # only for detection + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/product_segmentation_pipeline.py b/modelscope/pipelines/cv/product_segmentation_pipeline.py new file mode 100644 index 00000000..3b1b2381 --- /dev/null +++ b/modelscope/pipelines/cv/product_segmentation_pipeline.py @@ -0,0 +1,44 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import Any, Dict + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.product_segmentation import seg_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.product_segmentation, module_name=Pipelines.product_segmentation) +class F3NetForProductSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create product segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input['input_path']) + img = img.astype(np.float32) + return img + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + mask = seg_infer.inference(self.model, self.device, input) + return {OutputKeys.MASKS: mask} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/realtime_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_object_detection_pipeline.py new file mode 100644 index 00000000..9f558f88 --- /dev/null +++ b/modelscope/pipelines/cv/realtime_object_detection_pipeline.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.realtime_object_detection import RealtimeDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_object_detection, + module_name=Pipelines.realtime_object_detection) +class RealtimeObjectDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + super().__init__(model=model, **kwargs) + self.model = RealtimeDetector(model) + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + output = self.model.preprocess(input) + return {'pre_output': output} + + def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: + pre_output = input['pre_output'] + forward_output = self.model(pre_output) + return {'forward_output': forward_output} + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + forward_output = input['forward_output'] + bboxes, scores, labels = forward_output + return { + OutputKeys.BOXES: bboxes, + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + } diff --git a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py new file mode 100644 index 00000000..073fad66 --- /dev/null +++ b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.realtime_object_detection import \ + RealtimeVideoDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_object_detection, + module_name=Pipelines.realtime_video_object_detection) +class RealtimeVideoObjectDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + super().__init__(model=model, **kwargs) + self.model = RealtimeVideoDetector(model) + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + return input + + def forward(self, input: Input) -> Dict[Tensor, Dict[str, np.ndarray]]: + self.video_path = input + # Processing the whole video and return results for each frame + forward_output = self.model.inference_video(self.video_path) + return {'forward_output': forward_output} + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + forward_output = input['forward_output'] + + scores, boxes, labels, timestamps = [], [], [], [] + for result in forward_output: + box, score, label, timestamp = result + scores.append(score) + boxes.append(box) + labels.append(label) + timestamps.append(timestamp) + + return { + OutputKeys.BOXES: boxes, + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.TIMESTAMPS: timestamps, + } diff --git a/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py new file mode 100644 index 00000000..cfbf2607 --- /dev/null +++ b/modelscope/pipelines/cv/referring_video_object_segmentation_pipeline.py @@ -0,0 +1,199 @@ +# The implementation here is modified based on MTTR, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/mttr2021/MTTR +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as F +from einops import rearrange +from moviepy.editor import AudioFileClip, ImageSequenceClip, VideoFileClip +from PIL import Image, ImageDraw, ImageFont, ImageOps +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.referring_video_object_segmentation, + module_name=Pipelines.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """use `model` to create a referring video object segmentation pipeline for prediction + + Args: + model: model id on modelscope hub + """ + _device = kwargs.pop('device', 'gpu') + if torch.cuda.is_available() and _device == 'gpu': + self.device = 'gpu' + else: + self.device = 'cpu' + super().__init__(model=model, device=self.device, **kwargs) + + logger.info('Load model done!') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + + Args: + input: path of the input video + + """ + assert isinstance(input, tuple) and len( + input + ) == 4, 'error - input type must be tuple and input length must be 4' + self.input_video_pth, text_queries, start_pt, end_pt = input + + assert 0 < end_pt - start_pt <= 10, 'error - the subclip length must be 0-10 seconds long' + assert 1 <= len( + text_queries) <= 2, 'error - 1-2 input text queries are expected' + + # extract the relevant subclip: + self.input_clip_pth = 'input_clip.mp4' + with VideoFileClip(self.input_video_pth) as video: + subclip = video.subclip(start_pt, end_pt) + subclip.write_videofile(self.input_clip_pth) + + self.window_length = 24 # length of window during inference + self.window_overlap = 6 # overlap (in frames) between consecutive windows + + self.video, audio, self.meta = torchvision.io.read_video( + filename=self.input_clip_pth) + self.video = rearrange(self.video, 't h w c -> t c h w') + + input_video = F.resize(self.video, size=360, max_size=640) + if self.device_name == 'gpu': + input_video = input_video.cuda() + + input_video = input_video.to(torch.float).div_(255) + input_video = F.normalize( + input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + video_metadata = { + 'resized_frame_size': input_video.shape[-2:], + 'original_frame_size': self.video.shape[-2:] + } + + # partition the clip into overlapping windows of frames: + windows = [ + input_video[i:i + self.window_length] + for i in range(0, len(input_video), self.window_length + - self.window_overlap) + ] + # clean up the text queries: + self.text_queries = [' '.join(q.lower().split()) for q in text_queries] + + result = { + 'text_queries': self.text_queries, + 'windows': windows, + 'video_metadata': video_metadata + } + + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + pred_masks_per_query = [] + t, _, h, w = self.video.shape + for text_query in tqdm(input['text_queries'], desc='text queries'): + pred_masks = torch.zeros(size=(t, 1, h, w)) + for i, window in enumerate( + tqdm(input['windows'], desc='windows')): + + window_masks = self.model.inference( + window=window, + text_query=text_query, + metadata=input['video_metadata']) + + win_start_idx = i * ( + self.window_length - self.window_overlap) + pred_masks[win_start_idx:win_start_idx + + self.window_length] = window_masks + pred_masks_per_query.append(pred_masks) + return pred_masks_per_query + + def postprocess(self, inputs) -> Dict[str, Any]: + if self.model.cfg.pipeline.save_masked_video: + # RGB colors for instance masks: + light_blue = (41, 171, 226) + purple = (237, 30, 121) + dark_green = (35, 161, 90) + orange = (255, 148, 59) + colors = np.array([light_blue, purple, dark_green, orange]) + + # width (in pixels) of the black strip above the video on which the text queries will be displayed: + text_border_height_per_query = 36 + + video_np = rearrange(self.video, + 't c h w -> t h w c').numpy() / 255.0 + + # del video + pred_masks_per_frame = rearrange( + torch.stack(inputs), 'q t 1 h w -> t q h w').numpy() + masked_video = [] + for vid_frame, frame_masks in tqdm( + zip(video_np, pred_masks_per_frame), + total=len(video_np), + desc='applying masks...'): + # apply the masks: + for inst_mask, color in zip(frame_masks, colors): + vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0) + vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8)) + # visualize the text queries: + vid_frame = ImageOps.expand( + vid_frame, + border=(0, len(self.text_queries) + * text_border_height_per_query, 0, 0)) + W, H = vid_frame.size + draw = ImageDraw.Draw(vid_frame) + + if self.model.cfg.pipeline.output_font: + font = ImageFont.truetype( + font=self.model.cfg.pipeline.output_font, + size=self.model.cfg.pipeline.output_font_size) + else: + font = ImageFont.load_default() + for i, (text_query, color) in enumerate( + zip(self.text_queries, colors), start=1): + w, h = draw.textsize(text_query, font=font) + draw.text(((W - w) / 2, + (text_border_height_per_query * i) - h - 3), + text_query, + fill=tuple(color) + (255, ), + font=font) + masked_video.append(np.array(vid_frame)) + print(type(vid_frame)) + print(type(masked_video[0])) + print(masked_video[0].shape) + # generate and save the output clip: + + assert self.model.cfg.pipeline.output_path + output_clip_path = self.model.cfg.pipeline.output_path + clip = ImageSequenceClip( + sequence=masked_video, fps=self.meta['video_fps']) + clip = clip.set_audio(AudioFileClip(self.input_clip_pth)) + clip.write_videofile( + output_clip_path, fps=self.meta['video_fps'], audio=True) + del masked_video + + result = {OutputKeys.MASKS: inputs} + return result + + +def apply_mask(image, mask, color, transparency=0.7): + mask = mask[..., np.newaxis].repeat(repeats=3, axis=2) + mask = mask * transparency + color_matrix = np.ones(image.shape, dtype=np.float) * color + out_image = color_matrix * mask + image * (1.0 - mask) + return out_image diff --git a/modelscope/pipelines/cv/retina_face_detection_pipeline.py b/modelscope/pipelines/cv/retina_face_detection_pipeline.py new file mode 100644 index 00000000..40f2336a --- /dev/null +++ b/modelscope/pipelines/cv/retina_face_detection_pipeline.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import RetinaFaceDetection +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.retina_face_detection) +class RetinaFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + detector = RetinaFaceDetection( + model_path=ckpt_path, device=self.device) + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0][:, :4].tolist() + scores = result[0][:, 4].tolist() + lms = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: lms, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/shop_segmentation_pipleline.py b/modelscope/pipelines/cv/shop_segmentation_pipleline.py new file mode 100644 index 00000000..d08058c3 --- /dev/null +++ b/modelscope/pipelines/cv/shop_segmentation_pipleline.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.shop_segmentation, module_name=Pipelines.shop_segmentation) +class ShopSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) + result = { + 'img': img_tensor, + 'ori_h': ori_h, + 'ori_w': ori_w, + 'crop_h': crop_h, + 'crop_w': crop_w + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = { + 'data': outputs, + 'ori_h': input['ori_h'], + 'ori_w': input['ori_w'], + 'crop_h': input['crop_h'], + 'crop_w': input['crop_w'], + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + data = self.model.postprocess(inputs['data'], inputs['crop_h'], + inputs['crop_w'], inputs['ori_h'], + inputs['ori_w']) + outputs = {OutputKeys.MASKS: data} + return outputs diff --git a/modelscope/pipelines/cv/skin_retouching_pipeline.py b/modelscope/pipelines/cv/skin_retouching_pipeline.py new file mode 100644 index 00000000..c6571bef --- /dev/null +++ b/modelscope/pipelines/cv/skin_retouching_pipeline.py @@ -0,0 +1,304 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import tensorflow as tf +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in import \ + DetectionUNet +from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \ + RetouchingNet +from modelscope.models.cv.skin_retouching.retinaface.predict_single import \ + Model +from modelscope.models.cv.skin_retouching.unet_deploy import UNet +from modelscope.models.cv.skin_retouching.utils import * # noqa F403 +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import create_device, device_placement +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.skin_retouching, module_name=Pipelines.skin_retouching) +class SkinRetouchingPipeline(Pipeline): + + def __init__(self, model: str, device: str): + """ + use `model` to create a skin retouching pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, device=device) + + device = create_device(self.device_name) + model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) + detector_model_path = os.path.join( + self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth') + local_model_path = os.path.join(self.model, 'joint_20210926.pth') + skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE) + + self.generator = UNet(3, 3).to(device) + self.generator.load_state_dict( + torch.load(model_path, map_location='cpu')['generator']) + self.generator.eval() + + self.detector = Model(max_size=512, device=device) + state_dict = torch.load(detector_model_path, map_location='cpu') + self.detector.load_state_dict(state_dict) + self.detector.eval() + + self.local_model_path = local_model_path + ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu') + self.inpainting_net = RetouchingNet( + in_channels=4, out_channels=3).to(device) + self.detection_net = DetectionUNet( + n_channels=3, n_classes=1).to(device) + + self.inpainting_net.load_state_dict(ckpt_dict_load['inpainting_net']) + self.detection_net.load_state_dict(ckpt_dict_load['detection_net']) + + self.inpainting_net.eval() + self.detection_net.eval() + + self.patch_size = 512 + + self.skin_model_path = skin_model_path + if self.skin_model_path is not None: + with device_placement(self.framework, self.device_name): + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + config.gpu_options.allow_growth = True + self.sess = tf.Session(config=config) + with tf.gfile.FastGFile(self.skin_model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + self.sess.graph.as_default() + tf.import_graph_def(graph_def, name='') + self.sess.run(tf.global_variables_initializer()) + + self.image_files_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + self.diffuse_mask = gen_diffuse_mask() + self.diffuse_mask = torch.from_numpy( + self.diffuse_mask).to(device).float() + self.diffuse_mask = self.diffuse_mask.permute(2, 0, 1)[None, ...] + + self.input_size = 512 + self.device = device + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = img.astype(np.float) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + rgb_image = input['img'].astype(np.uint8) + + retouch_local = True + whitening = True + degree = 1.0 + whitening_degree = 0.8 + return_mg = False + + with torch.no_grad(): + if whitening and whitening_degree > 0 and self.skin_model_path is not None: + rgb_image_small, resize_scale = resize_on_long_side( + rgb_image, 800) + skin_mask = self.sess.run( + self.sess.graph.get_tensor_by_name('output_png:0'), + feed_dict={'input_image:0': rgb_image_small}) + + output_pred = torch.from_numpy(rgb_image).to(self.device) + if return_mg: + output_mg = np.ones( + (rgb_image.shape[0], rgb_image.shape[1], 3), + dtype=np.float32) * 0.5 + + results = self.detector.predict_jsons( + rgb_image + ) # list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...] + + crop_bboxes = get_crop_bbox(results) + + face_num = len(crop_bboxes) + if face_num == 0: + output = { + 'pred': output_pred.cpu().numpy()[:, :, ::-1], + 'face_num': face_num + } + return output + + flag_bigKernal = False + for bbox in crop_bboxes: + roi, expand, crop_tblr = get_roi_without_padding( + rgb_image, bbox) + roi = roi_to_tensor(roi) # bgr -> rgb + + if roi.shape[2] > 0.4 * rgb_image.shape[0]: + flag_bigKernal = True + + roi = roi.to(self.device) + + roi = preprocess_roi(roi) + + if retouch_local and self.local_model_path is not None: + roi = self.retouch_local(roi) + + roi_output = self.predict_roi( + roi, + degree=degree, + smooth_border=True, + return_mg=return_mg) + + roi_pred = roi_output['pred'] + output_pred[crop_tblr[0]:crop_tblr[1], + crop_tblr[2]:crop_tblr[3]] = roi_pred + + if return_mg: + roi_mg = roi_output['pred_mg'] + output_mg[crop_tblr[0]:crop_tblr[1], + crop_tblr[2]:crop_tblr[3]] = roi_mg + + if whitening and whitening_degree > 0 and self.skin_model_path is not None: + output_pred = whiten_img( + output_pred, + skin_mask, + whitening_degree, + flag_bigKernal=flag_bigKernal) + + if not isinstance(output_pred, np.ndarray): + output_pred = output_pred.cpu().numpy() + + output_pred = output_pred[:, :, ::-1] + + return {OutputKeys.OUTPUT_IMG: output_pred} + + def retouch_local(self, image): + """ + image: rgb + """ + with torch.no_grad(): + sub_H, sub_W = image.shape[2:] + + sub_image_standard = F.interpolate( + image, size=(768, 768), mode='bilinear', align_corners=True) + sub_mask_pred = torch.sigmoid( + self.detection_net(sub_image_standard)) + sub_mask_pred = F.interpolate( + sub_mask_pred, size=(sub_H, sub_W), mode='nearest') + + sub_mask_pred_hard_low = (sub_mask_pred >= 0.35).float() + sub_mask_pred_hard_high = (sub_mask_pred >= 0.5).float() + sub_mask_pred = sub_mask_pred * ( + 1 - sub_mask_pred_hard_high) + sub_mask_pred_hard_high + sub_mask_pred = sub_mask_pred * sub_mask_pred_hard_low + sub_mask_pred = 1 - sub_mask_pred + + sub_H_standard = sub_H if sub_H % self.patch_size == 0 else ( + sub_H // self.patch_size + 1) * self.patch_size + sub_W_standard = sub_W if sub_W % self.patch_size == 0 else ( + sub_W // self.patch_size + 1) * self.patch_size + + sub_image_padding = F.pad( + image, + pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0, + 0), + mode='constant', + value=0) + sub_mask_pred_padding = F.pad( + sub_mask_pred, + pad=(0, sub_W_standard - sub_W, 0, sub_H_standard - sub_H, 0, + 0), + mode='constant', + value=0) + + sub_image_padding = patch_partition_overlap( + sub_image_padding, p1=self.patch_size, p2=self.patch_size) + sub_mask_pred_padding = patch_partition_overlap( + sub_mask_pred_padding, p1=self.patch_size, p2=self.patch_size) + B_padding, C_padding, _, _ = sub_image_padding.size() + + sub_comp_padding_list = [] + for window_item in range(B_padding): + sub_image_padding_window = sub_image_padding[ + window_item:window_item + 1] + sub_mask_pred_padding_window = sub_mask_pred_padding[ + window_item:window_item + 1] + + sub_input_image_padding_window = sub_image_padding_window * sub_mask_pred_padding_window + + sub_output_padding_window = self.inpainting_net( + sub_input_image_padding_window, + sub_mask_pred_padding_window) + sub_comp_padding_window = sub_input_image_padding_window + ( + 1 + - sub_mask_pred_padding_window) * sub_output_padding_window + + sub_comp_padding_list.append(sub_comp_padding_window) + + sub_comp_padding = torch.cat(sub_comp_padding_list, dim=0) + sub_comp = patch_aggregation_overlap( + sub_comp_padding, + h=int(round(sub_H_standard / self.patch_size)), + w=int(round(sub_W_standard + / self.patch_size)))[:, :, :sub_H, :sub_W] + + return sub_comp + + def predict_roi(self, + roi, + degree=1.0, + smooth_border=False, + return_mg=False): + with torch.no_grad(): + image = F.interpolate( + roi, (self.input_size, self.input_size), mode='bilinear') + + pred_mg = self.generator(image) # value: 0~1 + pred_mg = (pred_mg - 0.5) * degree + 0.5 + pred_mg = pred_mg.clamp(0.0, 1.0) + pred_mg = F.interpolate(pred_mg, roi.shape[2:], mode='bilinear') + pred_mg = pred_mg[0].permute( + 1, 2, 0) # ndarray, (h, w, 1) or (h0, w0, 3) + if len(pred_mg.shape) == 2: + pred_mg = pred_mg[..., None] + + if smooth_border: + pred_mg = smooth_border_mg(self.diffuse_mask, pred_mg) + + image = (roi[0].permute(1, 2, 0) + 1.0) / 2 + + pred = (1 - 2 * pred_mg + ) * image * image + 2 * pred_mg * image # value: 0~1 + + pred = (pred * 255.0).byte() # ndarray, (h, w, 3), rgb + + output = {'pred': pred} + if return_mg: + output['pred_mg'] = pred_mg.cpu().numpy() + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py b/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py new file mode 100644 index 00000000..c7f9d4c2 --- /dev/null +++ b/modelscope/pipelines/cv/text_driven_segmentation_pipleline.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_driven_segmentation, + module_name=Pipelines.text_driven_segmentation) +class TextDrivenSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + + def preprocess(self, input: Dict) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input['image']) + img_tensor, ori_h, ori_w, crop_h, crop_w = self.model.preprocess(img) + result = { + 'img': img_tensor, + 'ori_h': ori_h, + 'ori_w': ori_w, + 'crop_h': crop_h, + 'crop_w': crop_w, + 'text': input['text'], + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + outputs = self.model.inference(input['img'], input['text']) + result = { + 'data': outputs, + 'ori_h': input['ori_h'], + 'ori_w': input['ori_w'], + 'crop_h': input['crop_h'], + 'crop_w': input['crop_w'], + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + data = self.model.postprocess(inputs['data'], inputs['crop_h'], + inputs['crop_w'], inputs['ori_h'], + inputs['ori_w']) + outputs = {OutputKeys.MASKS: data} + return outputs diff --git a/modelscope/pipelines/cv/tinynas_classification_pipeline.py b/modelscope/pipelines/cv/tinynas_classification_pipeline.py new file mode 100644 index 00000000..a470e58b --- /dev/null +++ b/modelscope/pipelines/cv/tinynas_classification_pipeline.py @@ -0,0 +1,96 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os.path as osp +from typing import Any, Dict + +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.tinynas_classfication import get_zennet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_classification, module_name=Pipelines.tinynas_classification) +class TinynasClassificationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a tinynas classification pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.path = model + self.model = get_zennet() + + model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE) + + checkpoint = torch.load(model_pth_path, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + self.model.load_state_dict(state_dict, strict=True) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + + input_image_size = 224 + crop_image_size = 380 + input_image_crop = 0.875 + resize_image_size = int(math.ceil(crop_image_size / input_image_crop)) + transforms_normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transform_list = [ + transforms.Resize( + resize_image_size, + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(crop_image_size), + transforms.ToTensor(), transforms_normalize + ] + transformer = transforms.Compose(transform_list) + + img = transformer(img) + img = torch.unsqueeze(img, 0) + img = torch.nn.functional.interpolate( + img, input_image_size, mode='bilinear') + result = {'img': img} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + is_train = False + if is_train: + self.model.train() + else: + self.model.eval() + + outputs = self.model(input['img']) + return {'outputs': outputs} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + label_mapping_path = osp.join(self.path, 'label_map.txt') + f = open(label_mapping_path) + content = f.read() + f.close() + label_dict = eval(content) + + output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1) + score = torch.max(output_prob) + output_dict = { + OutputKeys.SCORES: [score.item()], + OutputKeys.LABELS: [label_dict[inputs['outputs'].argmax().item()]] + } + return output_dict diff --git a/modelscope/pipelines/cv/tinynas_detection_pipeline.py b/modelscope/pipelines/cv/tinynas_detection_pipeline.py new file mode 100644 index 00000000..d35d4d36 --- /dev/null +++ b/modelscope/pipelines/cv/tinynas_detection_pipeline.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import \ + show_image_object_detection_auto_result +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_object_detection, module_name=Pipelines.tinynas_detection) +class TinynasDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + if torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + self.model.to(self.device) + self.model.eval() + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_ndarray(input) + self.img = img + img = img.astype(np.float) + img = self.model.preprocess(img) + result = {'img': img.to(self.device)} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + outputs = self.model.inference(input['img']) + result = {'data': outputs} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + bboxes, scores, labels = self.model.postprocess(inputs['data']) + if bboxes is None: + outputs = { + OutputKeys.SCORES: [], + OutputKeys.LABELS: [], + OutputKeys.BOXES: [] + } + else: + outputs = { + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: bboxes + } + return outputs + + def show_result(self, img_path, result, save_path=None): + show_image_object_detection_auto_result(img_path, result, save_path) diff --git a/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py b/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py new file mode 100644 index 00000000..e9901d64 --- /dev/null +++ b/modelscope/pipelines/cv/ulfd_face_detection_pipeline.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.face_detection import UlfdFaceDetector +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.face_detection, module_name=Pipelines.ulfd_face_detection) +class UlfdFaceDetectionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a face detection pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {ckpt_path}') + detector = UlfdFaceDetector(model_path=ckpt_path, device=self.device) + self.detector = detector + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + img = img.astype(np.float32) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + result = self.detector(input) + assert result is not None + bboxes = result[0].tolist() + scores = result[1].tolist() + return { + OutputKeys.SCORES: scores, + OutputKeys.BOXES: bboxes, + OutputKeys.KEYPOINTS: None, + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/video_category_pipeline.py b/modelscope/pipelines/cv/video_category_pipeline.py new file mode 100644 index 00000000..e4c73649 --- /dev/null +++ b/modelscope/pipelines/cv/video_category_pipeline.py @@ -0,0 +1,397 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os.path as osp +from typing import Any, Dict + +import decord +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.transforms.functional as TF +from decord import VideoReader, cpu +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_category, module_name=Pipelines.video_category) +class VideoCategoryPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video-category pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + config_path = osp.join(self.model, ModelFile.CONFIGURATION) + logger.info(f'loading configuration from {config_path}') + with open(config_path, 'r') as f: + config = json.load(f) + self.frame_num = config['frame_num'] + self.level_1_num = config['level_1_num'] + self.level_2_num = config['level_2_num'] + self.resize = config['resize'] + self.crop = config['crop'] + self.mean = config['mean'] + self.std = config['std'] + self.cateproj_v3 = config['cateproj_v3'] + self.class_name = config['class_name'] + self.subclass_name = config['subclass_name'] + logger.info('load configuration done') + + model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) + logger.info(f'loading model from {model_path}') + self.infer_model = ModelWrapper(self.level_1_num, self.level_2_num, + self.frame_num) + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.infer_model = self.infer_model.to(self.device).eval() + self.infer_model.load_state_dict( + torch.load(model_path, map_location=self.device)) + logger.info('load model done') + self.transforms = VCompose([ + VRescale(size=self.resize), + VCenterCrop(size=self.crop), + VToTensor(), + VNormalize(mean=self.mean, std=self.std) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + decord.bridge.set_bridge('native') + vr = VideoReader(input, ctx=cpu(0)) + indices = np.linspace(0, len(vr) - 1, 16).astype(int) + frames = vr.get_batch(indices).asnumpy() + video_input_data = self.transforms( + [Image.fromarray(f) for f in frames]) + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'video_data': video_input_data} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + pred1, pred2 = self.infer_model(input['video_data'].to(self.device)) + + pred1 = F.softmax(pred1, dim=1) + pred2 = F.softmax(pred2, dim=1) + + vals_2, preds_2 = pred2.cpu().topk(10, 1, True, True) + vals_2 = vals_2.detach().numpy() + preds_2 = preds_2.detach().numpy() + + if vals_2[0][0] >= 0.3: + c2 = int(preds_2[0][0]) + c1 = self.cateproj_v3[c2] + + tag1 = self.class_name[c1] + tag2 = self.subclass_name[c2] + + prob = float(vals_2[0][0]) + else: + vals_1, preds_1 = pred1.cpu().topk(10, 1, True, True) + vals_1 = vals_1.detach().numpy() + preds_1 = preds_1.detach().numpy() + + c1 = int(preds_1[0][0]) + + tag1 = self.class_name[c1] + tag2 = '其他' + + prob = float(vals_1[0][0]) + + return { + OutputKeys.SCORES: [prob], + OutputKeys.LABELS: [tag1 + '>>' + tag2] + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + +class TimeFirstBatchNorm1d(nn.Module): + + def __init__(self, dim, groups=None): + super().__init__() + self.groups = groups + self.bn = nn.BatchNorm1d(dim) + + def forward(self, tensor): + _, length, dim = tensor.size() + if self.groups: + dim = dim // self.groups + tensor = tensor.view(-1, dim) + tensor = self.bn(tensor) + if self.groups: + return tensor.view(-1, length, self.groups, dim) + else: + return tensor.view(-1, length, dim) + + +class NeXtVLAD(nn.Module): + """NeXtVLAD layer implementation + Adapted from https://github.com/linrongc/youtube-8m/blob/master/nextvlad.py + """ + + def __init__(self, + num_clusters=64, + dim=128, + alpha=100.0, + groups=8, + expansion=2, + normalize_input=True, + p_drop=0.25, + add_batchnorm=False): + """ + Args: + num_clusters : int + The number of clusters + dim : int + Dimension of descriptors + alpha : float + Parameter of initialization. Larger value is harder assignment. + normalize_input : bool + If true, descriptor-wise L2 normalization is applied to input. + """ + super(NeXtVLAD, self).__init__() + assert dim % groups == 0, '`dim` must be divisible by `groups`' + assert expansion > 1 + self.p_drop = p_drop + self.cluster_dropout = nn.Dropout2d(p_drop) + self.num_clusters = num_clusters + self.dim = dim + self.expansion = expansion + self.grouped_dim = dim * expansion // groups + self.groups = groups + self.alpha = alpha + self.normalize_input = normalize_input + self.add_batchnorm = add_batchnorm + self.expansion_mapper = nn.Linear(dim, dim * expansion) + if add_batchnorm: + self.soft_assignment_mapper = nn.Sequential( + nn.Linear(dim * expansion, num_clusters * groups, bias=False), + TimeFirstBatchNorm1d(num_clusters, groups=groups)) + else: + self.soft_assignment_mapper = nn.Linear( + dim * expansion, num_clusters * groups, bias=True) + self.attention_mapper = nn.Linear(dim * expansion, groups) + self.centroids = nn.Parameter( + torch.rand(num_clusters, self.grouped_dim)) + self.final_bn = nn.BatchNorm1d(num_clusters * self.grouped_dim) + self._init_params() + + def _init_params(self): + for component in (self.soft_assignment_mapper, self.attention_mapper, + self.expansion_mapper): + for module in component.modules(): + self.general_weight_initialization(module) + if self.add_batchnorm: + self.soft_assignment_mapper[0].weight = nn.Parameter( + (2.0 * self.alpha * self.centroids).repeat( + (self.groups, self.groups))) + nn.init.constant_(self.soft_assignment_mapper[1].bn.weight, 1) + nn.init.constant_(self.soft_assignment_mapper[1].bn.bias, 0) + else: + self.soft_assignment_mapper.weight = nn.Parameter( + (2.0 * self.alpha * self.centroids).repeat( + (self.groups, self.groups))) + self.soft_assignment_mapper.bias = nn.Parameter( + (-self.alpha * self.centroids.norm(dim=1)).repeat( + (self.groups, ))) + + def general_weight_initialization(self, module): + if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): + if module.weight is not None: + nn.init.uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def forward(self, x, masks=None): + """NeXtVlad Adaptive Pooling + Arguments: + x {torch.Tensor} -- shape: (n_batch, len, dim) + Returns: + torch.Tensor -- shape (n_batch, n_cluster * dim / groups) + """ + if self.normalize_input: + x = F.normalize(x, p=2, dim=2) # across descriptor dim + + # expansion + # shape: (n_batch, len, dim * expansion) + x = self.expansion_mapper(x) + + # soft-assignment + # shape: (n_batch, len, n_cluster, groups) + soft_assign = self.soft_assignment_mapper(x).view( + x.size(0), x.size(1), self.num_clusters, self.groups) + soft_assign = F.softmax(soft_assign, dim=2) + + # attention + # shape: (n_batch, len, groups) + attention = torch.sigmoid(self.attention_mapper(x)) + if masks is not None: + # shape: (n_batch, len, groups) + attention = attention * masks[:, :, None] + + # (n_batch, len, n_cluster, groups, dim / groups) + activation = ( + attention[:, :, None, :, None] * soft_assign[:, :, :, :, None]) + + # calculate residuals to each clusters + # (n_batch, n_cluster, dim / groups) + second_term = ( + activation.sum(dim=3).sum(dim=1) * self.centroids[None, :, :]) + # (n_batch, n_cluster, dim / groups) + first_term = ( + # (n_batch, len, n_cluster, groups, dim / groups) + activation + * x.view(x.size(0), x.size(1), 1, self.groups, + self.grouped_dim)).sum(dim=3).sum(dim=1) + + # vlad shape (n_batch, n_cluster, dim / groups) + vlad = first_term - second_term + vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization + # flatten shape (n_batch, n_cluster * dim / groups) + vlad = vlad.view(x.size(0), -1) # flatten + # vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize + vlad = self.final_bn(vlad) + if self.p_drop: + vlad = self.cluster_dropout( + vlad.view(x.size(0), self.num_clusters, self.grouped_dim, + 1)).view(x.size(0), -1) + return vlad + + +class ModelWrapper(nn.Module): + + def __init__(self, class_num, subclass_num, frame_num): + super(ModelWrapper, self).__init__() + cnn = models.resnet50(pretrained=False) + cnn.fc = nn.Sequential() + self.model = cnn + # Use NextVlad + # output size: (n_batch, n_cluster * dim / groups) + nv_group = 2 + expand = int(2 * frame_num / nv_group) + self.nextvlad = NeXtVLAD( + num_clusters=frame_num, dim=2048, groups=nv_group) + self.fc = nn.Linear(2048 * expand, 2048) + self.head1_p1 = nn.Sequential( + nn.Linear(2048, 2048), + nn.ReLU(), + nn.Linear(2048, 1024), + ) + self.head1_p2 = nn.Sequential( + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, class_num), + ) + self.head2_p1 = nn.Sequential( + nn.Linear(2048, 2048), + nn.ReLU(), + nn.Linear(2048, 1024), + ) + self.head2_p2 = nn.Sequential( + nn.Linear(2048, 1024), + nn.ReLU(), + nn.Linear(1024, subclass_num), + ) + self.fn = frame_num + + def forward(self, x): + x = x.view(-1, 3, 224, 224) + x = self.model(x) + + x = x.view(-1, self.fn, 2048) + x = self.nextvlad(x) + + x = self.fc(x) + + x1 = self.head1_p1(x) + c1 = self.head1_p2(x1) + + x2 = self.head2_p1(x) + c2 = self.head2_p2(torch.cat((x1, x2), dim=1)) + + return c1, c2 + + +class VCompose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, item): + for t in self.transforms: + item = t(item) + return item + + +class VRescale(object): + + def __init__(self, size=128): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + vclip = [u.resize((out_w, out_h), Image.BILINEAR) for u in vclip] + return vclip + + +class VCenterCrop(object): + + def __init__(self, size=112): + self.size = size + + def __call__(self, vclip): + w, h = vclip[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + vclip = [ + u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in vclip + ] + return vclip + + +class VToTensor(object): + + def __call__(self, vclip): + vclip = torch.stack([TF.to_tensor(u) for u in vclip], dim=0) + return vclip + + +class VNormalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, vclip): + assert vclip.min() > -0.1 and vclip.max() < 1.1, \ + 'vclip values should be in [0, 1]' + vclip = vclip.clone() + if not isinstance(self.mean, torch.Tensor): + self.mean = vclip.new_tensor(self.mean).view(1, -1, 1, 1) + if not isinstance(self.std, torch.Tensor): + self.std = vclip.new_tensor(self.std).view(1, -1, 1, 1) + vclip.sub_(self.mean).div_(self.std) + return vclip diff --git a/modelscope/pipelines/cv/video_inpainting_pipeline.py b/modelscope/pipelines/cv/video_inpainting_pipeline.py new file mode 100644 index 00000000..85133474 --- /dev/null +++ b/modelscope/pipelines/cv/video_inpainting_pipeline.py @@ -0,0 +1,48 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_inpainting import inpainting +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_inpainting, module_name=Pipelines.video_inpainting) +class VideoInpaintingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create video inpainting pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + decode_error, fps, w, h = inpainting.video_process( + input['video_input_path']) + + if decode_error is not None: + return {OutputKeys.OUTPUT: 'decode_error'} + + inpainting.inpainting_by_model_balance(self.model, + input['video_input_path'], + input['mask_path'], + input['video_output_path'], fps, + w, h) + + return {OutputKeys.OUTPUT: 'Done'} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py new file mode 100644 index 00000000..4169def7 --- /dev/null +++ b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py @@ -0,0 +1,88 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import Any, Dict + +import cv2 + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_single_object_tracking.config.ostrack import \ + cfg +from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import \ + OSTrack +from modelscope.models.cv.video_single_object_tracking.utils.utils import ( + check_box, timestamp_format) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_single_object_tracking, + module_name=Pipelines.video_single_object_tracking) +class VideoSingleObjectTrackingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a single object tracking pipeline + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.cfg = cfg + ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) + logger.info(f'loading model from {ckpt_path}') + self.tracker = OSTrack(ckpt_path, self.device) + logger.info('init tracker done') + + def preprocess(self, input) -> Input: + self.video_path = input[0] + self.init_bbox = input[1] + return input + + def forward(self, input: Input) -> Dict[str, Any]: + output_boxes = [] + output_timestamps = [] + cap = cv2.VideoCapture(self.video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_idx = 0 + success, frame = cap.read() + if success is False: + raise Exception( + 'modelscope error: %s can not be decoded by OpenCV.' % + (self.video_path)) + + init_box = self.init_bbox + frame_h, frame_w = frame.shape[0:2] + if not check_box(init_box, frame_h, frame_w): + raise Exception('modelscope error: init_box out of image range ', + init_box) + output_boxes.append(init_box.copy()) + output_timestamps.append(timestamp_format(seconds=frame_idx / fps)) + init_box[2] = init_box[2] - init_box[0] + init_box[3] = init_box[3] - init_box[1] + self.tracker.initialize(frame, {'init_bbox': init_box}) + logger.info('init bbox done') + + while True: + ret, frame = cap.read() + if frame is None: + break + frame_idx += 1 + out = self.tracker.track(frame) + state = [int(s) for s in out['target_bbox']] + output_boxes.append(state) + output_timestamps.append(timestamp_format(seconds=frame_idx / fps)) + cap.release() + logger.info('tracking process done') + + return { + OutputKeys.BOXES: output_boxes, + OutputKeys.TIMESTAMPS: output_timestamps + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/video_summarization_pipeline.py b/modelscope/pipelines/cv/video_summarization_pipeline.py new file mode 100644 index 00000000..e4fe206d --- /dev/null +++ b/modelscope/pipelines/cv/video_summarization_pipeline.py @@ -0,0 +1,117 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +import os.path as osp +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_summarization import (PGLVideoSummarization, + summary_format) +from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet +from modelscope.models.cv.video_summarization.summarizer import ( + generate_summary, get_change_points) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_summarization, module_name=Pipelines.video_summarization) +class VideoSummarizationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video summarization pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from {model}') + googlenet_model_path = osp.join(model, 'bvlc_googlenet.pt') + config_path = osp.join(model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + + self.googlenet_model = bvlc_googlenet() + self.googlenet_model.model.load_state_dict( + torch.load( + googlenet_model_path, map_location=torch.device(self.device))) + self.googlenet_model = self.googlenet_model.to(self.device).eval() + + self.pgl_model = PGLVideoSummarization(model) + self.pgl_model = self.pgl_model.to(self.device).eval() + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if not isinstance(input, str): + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + frames = [] + picks = [] + cap = cv2.VideoCapture(input) + self.fps = cap.get(cv2.CAP_PROP_FPS) + self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + frame_idx = 0 + while (cap.isOpened()): + ret, frame = cap.read() + if not ret: + break + if frame_idx % 15 == 0: + frames.append(frame) + picks.append(frame_idx) + frame_idx += 1 + n_frame = frame_idx + + result = { + 'video_name': input, + 'video_frames': np.array(frames), + 'n_frame': n_frame, + 'picks': np.array(picks) + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + frame_features = [] + for frame in tqdm(input['video_frames']): + feat = self.googlenet_model(frame) + frame_features.append(feat) + + change_points, n_frame_per_seg = get_change_points( + frame_features, input['n_frame']) + + summary = self.inference(frame_features, input['n_frame'], + input['picks'], change_points) + + output = summary_format(summary, self.fps) + + return {OutputKeys.OUTPUT: output} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + def inference(self, frame_features, n_frames, picks, change_points): + frame_features = torch.from_numpy(np.array(frame_features, np.float32)) + picks = np.array(picks, np.int32) + + with torch.no_grad(): + results = self.pgl_model(dict(frame_features=frame_features)) + scores = results['scores'] + if not scores.device.type == 'cpu': + scores = scores.cpu() + scores = scores.squeeze(0).numpy().tolist() + summary = generate_summary([change_points], [scores], [n_frames], + [picks])[0] + + return summary.tolist() diff --git a/modelscope/pipelines/cv/virtual_try_on_pipeline.py b/modelscope/pipelines/cv/virtual_try_on_pipeline.py new file mode 100644 index 00000000..cd6e7046 --- /dev/null +++ b/modelscope/pipelines/cv/virtual_try_on_pipeline.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch +from PIL import Image + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Pipelines +from modelscope.models.cv.virual_tryon import SDAFNet_Tryon +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks + + +@PIPELINES.register_module( + Tasks.virtual_try_on, module_name=Pipelines.virtual_try_on) +class VirtualTryonPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a virtual tryon pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + + def filter_param(src_params, own_state): + copied_keys = [] + for name, param in src_params.items(): + if 'module.' == name[0:7]: + name = name[7:] + if '.module.' not in list(own_state.keys())[0]: + name = name.replace('.module.', '.') + if (name in own_state) and (own_state[name].shape + == param.shape): + own_state[name].copy_(param) + copied_keys.append(name) + + def load_pretrained(model, src_params): + if 'state_dict' in src_params: + src_params = src_params['state_dict'] + own_state = model.state_dict() + filter_param(src_params, own_state) + model.load_state_dict(own_state) + + self.model = SDAFNet_Tryon(ref_in_channel=6).to(self.device) + local_model_dir = model + if osp.exists(model): + local_model_dir = model + else: + local_model_dir = snapshot_download(model) + self.local_path = local_model_dir + src_params = torch.load( + osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu') + load_pretrained(self.model, src_params) + self.model = self.model.eval() + self.size = 192 + from torchvision import transforms + self.test_transforms = transforms.Compose([ + transforms.Resize(self.size, interpolation=2), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + def preprocess(self, input: Union[Dict[str, Any], + tuple]) -> Dict[str, Any]: + if isinstance(input, tuple): + index_model = 0 + index_pose = 1 + index_cloth = 2 + else: + index_model = 'masked_model' + index_pose = 'pose' + index_cloth = 'cloth' + if isinstance(input[index_model], str): + img_agnostic = load_image(input[index_model]) + pose = load_image(input[index_pose]) + cloth_img = load_image(input[index_cloth]) + elif isinstance(input[index_model], PIL.Image.Image): + img_agnostic = input[index_model].convert('RGB') + pose = input[index_pose].convert('RGB') + cloth_img = input[index_cloth].convert('RGB') + elif isinstance(input[index_model], np.ndarray): + if len(input.shape) == 2: + img_agnostic = cv2.cvtColor(input[index_model], + cv2.COLOR_GRAY2BGR) + pose = cv2.cvtColor(input[index_pose], cv2.COLOR_GRAY2BGR) + cloth_img = cv2.cvtColor(input[index_cloth], + cv2.COLOR_GRAY2BGR) + img_agnostic = Image.fromarray( + img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB') + pose = Image.fromarray( + pose[:, :, ::-1].astype('uint8')).convert('RGB') + cloth_img = Image.fromarray( + cloth_img[:, :, ::-1].astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + img_agnostic = self.test_transforms(img_agnostic) + pose = self.test_transforms(pose) + cloth_img = self.test_transforms(cloth_img) + inputs = { + 'masked_model': img_agnostic.unsqueeze(0), + 'pose': pose.unsqueeze(0), + 'cloth': cloth_img.unsqueeze(0) + } + return inputs + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + img_agnostic = inputs['masked_model'].to(self.device) + pose = inputs['pose'].to(self.device) + cloth_img = inputs['cloth'].to(self.device) + ref_input = torch.cat((pose, img_agnostic), dim=1) + tryon_result = self.model(ref_input, cloth_img, img_agnostic) + return {OutputKeys.OUTPUT_IMG: tryon_result} + + def postprocess(self, outputs: Dict[str, Any]) -> Dict[str, Any]: + tryon_result = outputs[OutputKeys.OUTPUT_IMG].permute(0, 2, 3, + 1).squeeze(0) + tryon_result = tryon_result.add(1.).div(2.).mul(255).data.cpu().numpy() + outputs[OutputKeys.OUTPUT_IMG] = tryon_result + return outputs diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py new file mode 100644 index 00000000..55906e43 --- /dev/null +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline + from .image_captioning_pipeline import ImageCaptioningPipeline + from .visual_entailment_pipeline import VisualEntailmentPipeline + from .visual_grounding_pipeline import VisualGroundingPipeline + from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline + from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline + from .video_multi_modal_embedding_pipeline import \ + VideoMultiModalEmbeddingPipeline + from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline + +else: + _import_structure = { + 'image_captioning_pipeline': ['ImageCaptioningPipeline'], + 'visual_entailment_pipeline': ['VisualEntailmentPipeline'], + 'visual_grounding_pipeline': ['VisualGroundingPipeline'], + 'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'], + 'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'], + 'visual_question_answering_pipeline': + ['VisualQuestionAnsweringPipeline'], + 'video_multi_modal_embedding_pipeline': + ['VideoMultiModalEmbeddingPipeline'], + 'generative_multi_modal_embedding_pipeline': + ['GEMMMultiModalEmbeddingPipeline'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..13032314 --- /dev/null +++ b/modelscope/pipelines/multi_modal/generative_multi_modal_embedding_pipeline.py @@ -0,0 +1,34 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.generative_multi_modal_embedding, + module_name=Pipelines.generative_multi_modal_embedding) +class GEMMMultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a generative multimodal embedding pipeline + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py new file mode 100644 index 00000000..81a5f8cd --- /dev/null +++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, + Preprocessor) +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_captioning, module_name=Pipelines.image_captioning) +class ImageCaptioningPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a image captioning pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + if isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(pipe_model.model_dir) + elif isinstance(pipe_model, MPlugForAllTasks): + preprocessor = MPlugPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py new file mode 100644 index 00000000..329d79bf --- /dev/null +++ b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import MPlugPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_text_retrieval, module_name=Pipelines.image_text_retrieval) +class ImageTextRetrievalPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a + image text retrieval pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + f'model must be a single str or Model, but got {type(model)}' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + preprocessor = MPlugPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..18ee1dbf --- /dev/null +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal.clip.model import CLIPForMultiModalEmbedding +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.multi_modal import CLIPPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_text_retrieval, module_name=Pipelines.multi_modal_embedding) +@PIPELINES.register_module( + Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) +class MultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError('model must be a single str') + pipe_model.eval() + if preprocessor is None: + if isinstance(pipe_model, CLIPForMultiModalEmbedding): + preprocessor = CLIPPreprocessor(pipe_model.model_dir) + else: + raise NotImplementedError + + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(self.preprocess(input)) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py new file mode 100644 index 00000000..c61b38f3 --- /dev/null +++ b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.ocr_recognition, module_name=Pipelines.ofa_ocr_recognition) +class OcrRecognitionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a ocr recognition pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None: + if isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py b/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py new file mode 100644 index 00000000..cafd6555 --- /dev/null +++ b/modelscope/pipelines/multi_modal/team_multi_modal_similarity_pipeline.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.multi_modal_similarity, module_name=Pipelines.multi_modal_similarity) +class TEAMMultiModalSimilarityPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a multimodal similarity pipeline + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py new file mode 100644 index 00000000..7516c5be --- /dev/null +++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import ( + MultiStageDiffusionForTextToImageSynthesis, OfaForTextToImageSynthesis) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.text_to_image_synthesis, + module_name=Pipelines.text_to_image_synthesis) +class TextToImageSynthesisPipeline(Pipeline): + + def __init__(self, + model: str, + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + device_id = 0 if torch.cuda.is_available() else -1 + if isinstance(model, str): + pipe_model = Model.from_pretrained(model, device_id=device_id) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError( + f'expecting a Model instance or str, but get {type(model)}.') + if preprocessor is None and isinstance(pipe_model, + OfaForTextToImageSynthesis): + preprocessor = OfaPreprocessor(pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + if self.preprocessor is not None: + return self.preprocessor(input, **preprocess_params) + else: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(self.model, + (OfaForTextToImageSynthesis, + MultiStageDiffusionForTextToImageSynthesis)): + return self.model(input) + return self.model.generate(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.OUTPUT_IMG: inputs} diff --git a/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..3a9284f1 --- /dev/null +++ b/modelscope/pipelines/multi_modal/video_multi_modal_embedding_pipeline.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_multi_modal_embedding, + module_name=Pipelines.video_multi_modal_embedding) +class VideoMultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video_multi_modal_embedding pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: + with device_placement(self.framework, self.device_name): + out = self.forward(input) + + self._check_output(out) + return out + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py b/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py new file mode 100644 index 00000000..2a7bd1d0 --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_entailment_pipeline.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.visual_entailment, module_name=Pipelines.visual_entailment) +class VisualEntailmentPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a visual entailment pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py b/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py new file mode 100644 index 00000000..651109d9 --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_grounding_pipeline.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.visual_grounding, module_name=Pipelines.visual_grounding) +class VisualGroundingPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create a visual grounding pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py new file mode 100644 index 00000000..86177074 --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, + Preprocessor) +from modelscope.utils.constant import Tasks + +__all__ = ['VisualQuestionAnsweringPipeline'] + + +@PIPELINES.register_module( + Tasks.visual_question_answering, + module_name=Pipelines.visual_question_answering) +class VisualQuestionAnsweringPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a visual question answering pipeline for prediction + + Args: + model (MPlugForVisualQuestionAnswering): a model instance + preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance + """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + if isinstance(model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model.model_dir) + elif isinstance(model, MPlugForAllTasks): + preprocessor = MPlugPreprocessor(model.model_dir) + model.model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Tensor], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + return inputs diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py new file mode 100644 index 00000000..1206ae08 --- /dev/null +++ b/modelscope/pipelines/nlp/__init__.py @@ -0,0 +1,90 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .automatic_post_editing_pipeline import AutomaticPostEditingPipeline + from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline + from .table_question_answering_pipeline import TableQuestionAnsweringPipeline + from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline + from .dialog_modeling_pipeline import DialogModelingPipeline + from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline + from .document_segmentation_pipeline import DocumentSegmentationPipeline + from .fasttext_sequence_classification_pipeline import FasttextSequenceClassificationPipeline + from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline + from .feature_extraction_pipeline import FeatureExtractionPipeline + from .fill_mask_pipeline import FillMaskPipeline + from .information_extraction_pipeline import InformationExtractionPipeline + from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline, \ + NamedEntityRecognitionThaiPipeline, \ + NamedEntityRecognitionVietPipeline + from .text_ranking_pipeline import TextRankingPipeline + from .sentence_embedding_pipeline import SentenceEmbeddingPipeline + from .text_classification_pipeline import TextClassificationPipeline + from .summarization_pipeline import SummarizationPipeline + from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline + from .text_error_correction_pipeline import TextErrorCorrectionPipeline + from .text_generation_pipeline import TextGenerationPipeline + from .text2text_generation_pipeline import Text2TextGenerationPipeline + from .token_classification_pipeline import TokenClassificationPipeline + from .translation_pipeline import TranslationPipeline + from .word_segmentation_pipeline import WordSegmentationPipeline + from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline + from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline + from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ + WordSegmentationThaiPipeline + +else: + _import_structure = { + 'automatic_post_editing_pipeline': ['AutomaticPostEditingPipeline'], + 'conversational_text_to_sql_pipeline': + ['ConversationalTextToSqlPipeline'], + 'dialog_intent_prediction_pipeline': + ['DialogIntentPredictionPipeline'], + 'dialog_modeling_pipeline': ['DialogModelingPipeline'], + 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], + 'domain_classification_pipeline': + ['FasttextSequenceClassificationPipeline'], + 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], + 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], + 'feature_extraction_pipeline': ['FeatureExtractionPipeline'], + 'fill_mask_pipeline': ['FillMaskPipeline'], + 'information_extraction_pipeline': ['InformationExtractionPipeline'], + 'named_entity_recognition_pipeline': [ + 'NamedEntityRecognitionPipeline', + 'NamedEntityRecognitionThaiPipeline', + 'NamedEntityRecognitionVietPipeline' + ], + 'text_ranking_pipeline': ['TextRankingPipeline'], + 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], + 'summarization_pipeline': ['SummarizationPipeline'], + 'table_question_answering_pipeline': + ['TableQuestionAnsweringPipeline'], + 'text_classification_pipeline': ['TextClassificationPipeline'], + 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], + 'text_generation_pipeline': ['TextGenerationPipeline'], + 'text2text_generation_pipeline': ['Text2TextGenerationPipeline'], + 'token_classification_pipeline': ['TokenClassificationPipeline'], + 'translation_pipeline': ['TranslationPipeline'], + 'translation_quality_estimation_pipeline': + ['TranslationQualityEstimationPipeline'], + 'word_segmentation_pipeline': ['WordSegmentationPipeline'], + 'zero_shot_classification_pipeline': + ['ZeroShotClassificationPipeline'], + 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], + 'multilingual_word_segmentation_pipeline': [ + 'MultilingualWordSegmentationPipeline', + 'WordSegmentationThaiPipeline' + ], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py new file mode 100644 index 00000000..83968586 --- /dev/null +++ b/modelscope/pipelines/nlp/automatic_post_editing_pipeline.py @@ -0,0 +1,158 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from html import unescape +from typing import Any, Dict + +import jieba +import numpy as np +import tensorflow as tf +from sacremoses import (MosesDetokenizer, MosesDetruecaser, + MosesPunctNormalizer, MosesTokenizer, MosesTruecaser) +from sentencepiece import SentencePieceProcessor +from tensorflow.contrib.seq2seq.python.ops import beam_search_ops + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + +__all__ = ['AutomaticPostEditingPipeline'] + + +@PIPELINES.register_module( + Tasks.translation, module_name=Pipelines.automatic_post_editing) +class AutomaticPostEditingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """Build an automatic post editing pipeline with a model dir. + + @param model: Model path for saved pb file + """ + super().__init__(model=model, **kwargs) + export_dir = model + self.cfg = Config.from_file( + os.path.join(export_dir, ModelFile.CONFIGURATION)) + joint_vocab_file = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['vocab']) + self.vocab = dict([(w.strip(), i) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.vocab_reverse = dict([(i, w.strip()) for i, w in enumerate( + open(joint_vocab_file, 'r', encoding='utf8'))]) + self.unk_id = self.cfg[ConfigFields.preprocessor].get('unk_id', -1) + strip_unk = self.cfg.get(ConfigFields.postprocessor, + {}).get('strip_unk', True) + self.unk_token = '' if strip_unk else self.cfg.get( + ConfigFields.postprocessor, {}).get('unk_token', '') + if self.unk_id == -1: + self.unk_id = len(self.vocab) - 1 + tf.reset_default_graph() + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + export_dir) + default_graph = tf.get_default_graph() + self.input_src_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder:0') + self.input_src_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_1:0') + self.input_mt_id_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_2:0') + self.input_mt_len_placeholder = default_graph.get_tensor_by_name( + 'Placeholder_3:0') + output_id_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/transpose:0') + output_len_beam = default_graph.get_tensor_by_name( + 'enc2enc/decoder/Minimum:0') + output_id = tf.cast( + tf.map_fn(lambda x: x[0], output_id_beam), dtype=tf.int64) + output_len = tf.map_fn(lambda x: x[0], output_len_beam) + self.output = {'output_ids': output_id, 'output_lens': output_len} + init = tf.global_variables_initializer() + local_init = tf.local_variables_initializer() + self._session.run([init, local_init]) + tf.saved_model.loader.load( + self._session, [tf.python.saved_model.tag_constants.SERVING], + export_dir) + + # preprocess + self._src_lang = self.cfg[ConfigFields.preprocessor]['src_lang'] + self._tgt_lang = self.cfg[ConfigFields.preprocessor]['tgt_lang'] + tok_escape = self.cfg[ConfigFields.preprocessor].get( + 'tokenize_escape', False) + src_tokenizer = MosesTokenizer(lang=self._src_lang) + mt_tokenizer = MosesTokenizer(lang=self._tgt_lang) + truecase_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['truecaser']) + truecaser = MosesTruecaser(load_from=truecase_model) + sp_model = os.path.join( + export_dir, self.cfg[ConfigFields.preprocessor]['sentencepiece']) + sp = SentencePieceProcessor() + sp.load(sp_model) + + self.src_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + src_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + self.mt_preprocess = lambda x: ' '.join( + sp.encode_as_pieces( + truecaser.truecase( + mt_tokenizer.tokenize( + x, return_str=True, escape=tok_escape), + return_str=True))) + + # post process, de-bpe, de-truecase, detok + detruecaser = MosesDetruecaser() + detokenizer = MosesDetokenizer(lang=self._tgt_lang) + self.postprocess_fun = lambda x: detokenizer.detokenize( + detruecaser.detruecase( + x.replace(' ▁', '@@').replace(' ', '').replace('@@', ' '). + strip()[1:], + return_str=True).split()) + + def preprocess(self, input: str) -> Dict[str, Any]: + src, mt = input.split('\005', 1) + src_sp, mt_sp = self.src_preprocess(src), self.mt_preprocess(mt) + input_src_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in src_sp.strip().split()]]) + input_mt_ids = np.array( + [[self.vocab.get(w, self.unk_id) for w in mt_sp.strip().split()]]) + input_src_lens = [len(x) for x in input_src_ids] + input_mt_lens = [len(x) for x in input_mt_ids] + feed_dict = { + self.input_src_id_placeholder: input_src_ids, + self.input_mt_id_placeholder: input_mt_ids, + self.input_src_len_placeholder: input_src_lens, + self.input_mt_len_placeholder: input_mt_lens + } + return feed_dict + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + sess_outputs = self._session.run(self.output, feed_dict=input) + return sess_outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_ids, output_len = inputs['output_ids'][0], inputs[ + 'output_lens'][0] + output_ids = output_ids[:output_len - 1] # -1 for + output_tokens = ' '.join([ + self.vocab_reverse.get(wid, self.unk_token) for wid in output_ids + ]) + post_editing_output = self.postprocess_fun(output_tokens) + result = {OutputKeys.TRANSLATION: post_editing_output} + return result diff --git a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py new file mode 100644 index 00000000..48df0c40 --- /dev/null +++ b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import torch +from text2sql_lgesql.utils.example import Example + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import StarForTextToSql +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import ConversationalTextToSqlPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['ConversationalTextToSqlPipeline'] + + +@PIPELINES.register_module( + Tasks.table_question_answering, + module_name=Pipelines.conversational_text_to_sql) +class ConversationalTextToSqlPipeline(Pipeline): + + def __init__(self, + model: Union[StarForTextToSql, str], + preprocessor: ConversationalTextToSqlPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a conversational text-to-sql prediction pipeline + + Args: + model (StarForTextToSql): a model instance + preprocessor (ConversationalTextToSqlPreprocessor): + a preprocessor instance + """ + model = model if isinstance( + model, StarForTextToSql) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir) + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) + result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}} + return result + + def _collate_fn(self, data): + return data diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py new file mode 100644 index 00000000..70374c50 --- /dev/null +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDialogIntent +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import DialogIntentPredictionPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['DialogIntentPredictionPipeline'] + + +@PIPELINES.register_module( + Tasks.task_oriented_conversation, + module_name=Pipelines.dialog_intent_prediction) +class DialogIntentPredictionPipeline(Pipeline): + + def __init__(self, + model: Union[SpaceForDialogIntent, str], + preprocessor: DialogIntentPredictionPreprocessor = None, + **kwargs): + """Use `model` and `preprocessor` to create a dialog intent prediction pipeline + + Args: + model (str or SpaceForDialogIntent): Supply either a local model dir or a model id from the model hub, + or a SpaceForDialogIntent instance. + preprocessor (DialogIntentPredictionPreprocessor): An optional preprocessor instance. + """ + model = model if isinstance( + model, SpaceForDialogIntent) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) + self.model = model + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.categories = preprocessor.categories + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + import numpy as np + pred = inputs['pred'] + pos = np.where(pred == np.max(pred)) + + return { + OutputKeys.OUTPUT: { + OutputKeys.PREDICTION: pred, + OutputKeys.LABEL_POS: pos[0], + OutputKeys.LABEL: self.categories[pos[0][0]] + } + } diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py new file mode 100644 index 00000000..3215d765 --- /dev/null +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDialogModeling +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import DialogModelingPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['DialogModelingPipeline'] + + +@PIPELINES.register_module( + Tasks.task_oriented_conversation, module_name=Pipelines.dialog_modeling) +class DialogModelingPipeline(Pipeline): + + def __init__(self, + model: Union[SpaceForDialogModeling, str], + preprocessor: DialogModelingPreprocessor = None, + **kwargs): + """Use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation + + Args: + model (str or SpaceForDialogModeling): Supply either a local model dir or a model id from the model hub, + or a SpaceForDialogModeling instance. + preprocessor (DialogModelingPreprocessor): An optional preprocessor instance. + """ + model = model if isinstance( + model, SpaceForDialogModeling) else Model.from_pretrained(model) + self.model = model + if preprocessor is None: + preprocessor = DialogModelingPreprocessor(model.model_dir) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.preprocessor = preprocessor + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( + inputs['resp']) + assert len(sys_rsp) > 2 + sys_rsp = sys_rsp[1:len(sys_rsp) - 1] + inputs[OutputKeys.OUTPUT] = sys_rsp + + return inputs diff --git a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py new file mode 100644 index 00000000..9520c06f --- /dev/null +++ b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py @@ -0,0 +1,163 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDST +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import DialogStateTrackingPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['DialogStateTrackingPipeline'] + + +@PIPELINES.register_module( + Tasks.task_oriented_conversation, + module_name=Pipelines.dialog_state_tracking) +class DialogStateTrackingPipeline(Pipeline): + + def __init__(self, + model: Union[SpaceForDST, str], + preprocessor: DialogStateTrackingPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a dialog state tracking pipeline for + observation of dialog states tracking after many turns of open domain dialogue + + Args: + model (str or SpaceForDialogStateTracking): Supply either a local model dir or a model id + from the model hub, or a SpaceForDialogStateTracking instance. + preprocessor (DialogStateTrackingPreprocessor): An optional preprocessor instance. + """ + + model = model if isinstance( + model, SpaceForDST) else Model.from_pretrained(model) + self.model = model + if preprocessor is None: + preprocessor = DialogStateTrackingPreprocessor(model.model_dir) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + self.tokenizer = preprocessor.tokenizer + self.config = preprocessor.config + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + + _inputs = inputs['inputs'] + _outputs = inputs['outputs'] + unique_ids = inputs['unique_ids'] + input_ids_unmasked = inputs['input_ids_unmasked'] + values = inputs['values'] + inform = inputs['inform'] + prefix = inputs['prefix'] + ds = inputs['ds'] + ds = predict_and_format(self.config, self.tokenizer, _inputs, + _outputs[2], _outputs[3], _outputs[4], + _outputs[5], unique_ids, input_ids_unmasked, + values, inform, prefix, ds) + + return {OutputKeys.OUTPUT: ds} + + +def predict_and_format(config, tokenizer, features, per_slot_class_logits, + per_slot_start_logits, per_slot_end_logits, + per_slot_refer_logits, ids, input_ids_unmasked, values, + inform, prefix, ds): + import re + + prediction_list = [] + dialog_state = ds + for i in range(len(ids)): + if int(ids[i].split('-')[2]) == 0: + dialog_state = {slot: 'none' for slot in config.dst_slot_list} + + prediction = {} + prediction_addendum = {} + for slot in config.dst_slot_list: + class_logits = per_slot_class_logits[slot][i] + start_logits = per_slot_start_logits[slot][i] + end_logits = per_slot_end_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + input_ids = features['input_ids'][i].tolist() + class_label_id = int(features['class_label_id'][slot][i]) + start_pos = int(features['start_pos'][slot][i]) + end_pos = int(features['end_pos'][slot][i]) + refer_id = int(features['refer_id'][slot][i]) + + class_prediction = int(class_logits.argmax()) + start_prediction = int(start_logits.argmax()) + end_prediction = int(end_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + prediction['guid'] = ids[i].split('-') + prediction['class_prediction_%s' % slot] = class_prediction + prediction['class_label_id_%s' % slot] = class_label_id + prediction['start_prediction_%s' % slot] = start_prediction + prediction['start_pos_%s' % slot] = start_pos + prediction['end_prediction_%s' % slot] = end_prediction + prediction['end_pos_%s' % slot] = end_pos + prediction['refer_prediction_%s' % slot] = refer_prediction + prediction['refer_id_%s' % slot] = refer_id + prediction['input_ids_%s' % slot] = input_ids + + if class_prediction == config.dst_class_types.index('dontcare'): + dialog_state[slot] = 'dontcare' + elif class_prediction == config.dst_class_types.index( + 'copy_value'): + input_tokens = tokenizer.convert_ids_to_tokens( + input_ids_unmasked[i]) + dialog_state[slot] = ' '.join( + input_tokens[start_prediction:end_prediction + 1]) + dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot]) + elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'true'): + dialog_state[slot] = 'true' + elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'false'): + dialog_state[slot] = 'false' + elif class_prediction == config.dst_class_types.index('inform'): + # dialog_state[slot] = '§§' + inform[i][slot] + if isinstance(inform[i][slot], str): + dialog_state[slot] = inform[i][slot] + elif isinstance(inform[i][slot], list): + dialog_state[slot] = inform[i][slot][0] + # Referral case is handled below + + prediction_addendum['slot_prediction_%s' + % slot] = dialog_state[slot] + prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot] + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in config.dst_slot_list: + class_logits = per_slot_class_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'refer'): + # Only slots that have been mentioned before can be referred to. + # One can think of a situation where one slot is referred to in the same utterance. + # This phenomenon is however currently not properly covered in the training data + # label generation process. + dialog_state[slot] = dialog_state[config.dst_slot_list[ + refer_prediction - 1]] + prediction_addendum['slot_prediction_%s' % + slot] = dialog_state[slot] # Value update + + prediction.update(prediction_addendum) + prediction_list.append(prediction) + + return dialog_state diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py new file mode 100644 index 00000000..325d3303 --- /dev/null +++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.nlp.gpt3.distributed_gpt3 import DistributedGPT3 +from modelscope.pipelines.base import DistributedPipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextGenerationJiebaPreprocessor +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.gpt3_generation) +class DistributedGPT3Pipeline(DistributedPipeline): + """This class is used to instantiate the gpt3 model. + """ + + model = None + + def __init__(self, model, preprocessor=None, **kwargs): + if preprocessor is None: + preprocessor = TextGenerationJiebaPreprocessor(model) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedGPT3(model_dir, rank, **kwargs) + cls.model.eval() + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + tokens = inputs['inputs']['input_ids'].cuda( + torch.cuda.current_device()) + return cls.model.generate(tokens) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + return { + OutputKeys.TEXT: + self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) + } diff --git a/modelscope/pipelines/nlp/distributed_plug_pipeline.py b/modelscope/pipelines/nlp/distributed_plug_pipeline.py new file mode 100644 index 00000000..8499f7ff --- /dev/null +++ b/modelscope/pipelines/nlp/distributed_plug_pipeline.py @@ -0,0 +1,110 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.nlp.plug import DistributedPlug +from modelscope.pipelines.base import DistributedPipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextGenerationPreprocessor +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.plug_generation) +class DistributedPlugPipeline(DistributedPipeline): + """This class is used to instantiate the plug model. + """ + + model = None + + def __init__(self, + model, + preprocessor=None, + first_sequence='sentence', + **kwargs): + """Create a plug pipeline instance. + + Args: + model: The model_id of plug(damo/nlp_plug_text-generation_27B). + The default path to damo/nlp_plug_text-generation_27B can be obtained by function + get_cache_dir("damo/nlp_plug_text-generation_27B"), the model should be downloaded to + this path before calling this class by model_id. + The model can be downloaded from the link on + https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. + After downloading, you should have a plug model structure like this: + /your/path/to/damo/nlp_plug_text-generation_27B + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will + be used as default. + first_sequence: The first_sequence key name if the input format is a dict. + kwargs: + sequence_length: The input sequence_length. + """ + if preprocessor is None: + preprocessor = TextGenerationPreprocessor( + model, + first_sequence=first_sequence, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + self.cls_token_id = preprocessor.tokenizer.cls_token_id + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + return cls.model.generate(inputs['inputs'], + **inputs['forward_params']) + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + batch_size = inputs['input_ids'].shape[0] + dec_input_ids = torch.full([batch_size, 1], + self.cls_token_id, + dtype=torch.long) + inputs['dec_input_ids'] = dec_input_ids + res = super().forward(inputs, **forward_params) + return res + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedPlug(model_dir, rank, **kwargs) + cls.model.eval() + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + generate_context = inputs['generate_context'] + generate_context = ''.join( + self.preprocessor.tokenizer.convert_ids_to_tokens( + generate_context)).replace('[UNK]', '“').replace('##', '') + return {OutputKeys.TEXT: generate_context} diff --git a/modelscope/pipelines/nlp/document_segmentation_pipeline.py b/modelscope/pipelines/nlp/document_segmentation_pipeline.py new file mode 100644 index 00000000..00837bf3 --- /dev/null +++ b/modelscope/pipelines/nlp/document_segmentation_pipeline.py @@ -0,0 +1,175 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import re +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from datasets import Dataset +from transformers.models.bert.modeling_bert import BertConfig + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import DocumentSegmentationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['DocumentSegmentationPipeline'] + + +@PIPELINES.register_module( + Tasks.document_segmentation, module_name=Pipelines.document_segmentation) +class DocumentSegmentationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: DocumentSegmentationPreprocessor = None, + **kwargs): + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + + self.model_dir = model.model_dir + config = BertConfig.from_pretrained(model.model_dir, num_labels=2) + + self.document_segmentation_model = model.build_with_config( + config=config) + + if preprocessor is None: + preprocessor = DocumentSegmentationPreprocessor( + self.model_dir, config) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + self.preprocessor = preprocessor + + def __call__(self, documents: Union[List[str], str]) -> Dict[str, Any]: + output = self.predict(documents) + output = self.postprocess(output) + return output + + def predict(self, documents: Union[List[str], str]) -> Dict[str, Any]: + pred_samples = self.cut_documents(documents) + predict_examples = Dataset.from_dict(pred_samples) + + # Predict Feature Creation + predict_dataset = self.preprocessor(predict_examples) + num_examples = len( + predict_examples[self.preprocessor.context_column_name]) + num_samples = len( + predict_dataset[self.preprocessor.context_column_name]) + + predict_dataset.pop('segment_ids') + labels = predict_dataset.pop('labels') + sentences = predict_dataset.pop('sentences') + example_ids = predict_dataset.pop( + self.preprocessor.example_id_column_name) + + with torch.no_grad(): + input = { + key: torch.tensor(val) + for key, val in predict_dataset.items() + } + predictions = self.document_segmentation_model.forward( + **input).logits + + predictions = np.argmax(predictions, axis=2) + assert len(sentences) == len( + predictions), 'sample {} infer_sample {} prediction {}'.format( + num_samples, len(sentences), len(predictions)) + # Remove ignored index (special tokens) + true_predictions = [ + [ + self.preprocessor.label_list[p] + for (p, l) in zip(prediction, label) if l != -100 # noqa * + ] for prediction, label in zip(predictions, labels) + ] + + true_labels = [ + [ + self.preprocessor.label_list[l] + for (p, l) in zip(prediction, label) if l != -100 # noqa * + ] for prediction, label in zip(predictions, labels) + ] + + # Save predictions + out = [] + for i in range(num_examples): + out.append({'sentences': [], 'labels': [], 'predictions': []}) + + for prediction, sentence_list, label, example_id in zip( + true_predictions, sentences, true_labels, example_ids): + if len(label) < len(sentence_list): + label.append('B-EOP') + prediction.append('B-EOP') + assert len(sentence_list) == len(prediction), '{} {}'.format( + len(sentence_list), len(prediction)) + assert len(sentence_list) == len(label), '{} {}'.format( + len(sentence_list), len(label)) + out[example_id]['sentences'].extend(sentence_list) + out[example_id]['labels'].extend(label) + out[example_id]['predictions'].extend(prediction) + + return out + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + result = [] + list_count = len(inputs) + for num in range(list_count): + res = [] + for s, p in zip(inputs[num]['sentences'], + inputs[num]['predictions']): + s = s.strip() + if p == 'B-EOP': + s = ''.join([s, '\n\t']) + res.append(s) + + document = ('\t' + ''.join(res)) + result.append(document) + + if list_count == 1: + return {OutputKeys.TEXT: result[0]} + else: + return {OutputKeys.TEXT: result} + + def cut_documents(self, para: Union[List[str], str]): + document_list = para + if isinstance(para, str): + document_list = [para] + sentences = [] + labels = [] + example_id = [] + id = 0 + for document in document_list: + sentence = self.cut_sentence(document) + label = ['O'] * (len(sentence) - 1) + ['B-EOP'] + sentences.append(sentence) + labels.append(label) + example_id.append(id) + id += 1 + + return { + 'example_id': example_id, + 'sentences': sentences, + 'labels': labels + } + + def cut_sentence(self, para): + para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) # noqa * + para = para.rstrip() + return [_ for _ in para.split('\n') if _] diff --git a/modelscope/pipelines/nlp/faq_question_answering_pipeline.py b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py new file mode 100644 index 00000000..3917f20c --- /dev/null +++ b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['FaqQuestionAnsweringPipeline'] + + +@PIPELINES.register_module( + Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering) +class FaqQuestionAnsweringPipeline(Pipeline): + + def __init__(self, + model: Union[str, Model], + preprocessor: Preprocessor = None, + **kwargs): + model = Model.from_pretrained(model) if isinstance(model, + str) else model + if preprocessor is None: + preprocessor = Preprocessor.from_pretrained( + model.model_dir, **kwargs) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + + def get_sentence_embedding(self, inputs, max_len=None): + inputs = self.preprocessor.batch_encode(inputs, max_length=max_len) + sentence_vecs = self.model.forward_sentence_embedding(inputs) + sentence_vecs = sentence_vecs.detach().tolist() + return sentence_vecs + + def forward(self, inputs: [list, Dict[str, Any]], + **forward_params) -> Dict[str, Any]: + return self.model(inputs) + + def postprocess(self, inputs: [list, Dict[str, Any]], + **postprocess_params) -> Dict[str, Any]: + scores = inputs['scores'] + labels = [] + for item in scores: + tmplabels = [ + self.preprocessor.get_label(label_id) + for label_id in range(len(item)) + ] + labels.append(tmplabels) + + predictions = [] + for tmp_scores, tmp_labels in zip(scores.tolist(), labels): + prediction = [] + for score, label in zip(tmp_scores, tmp_labels): + prediction.append({ + OutputKeys.LABEL: label, + OutputKeys.SCORE: score + }) + predictions.append( + list( + sorted( + prediction, + key=lambda d: d[OutputKeys.SCORE], + reverse=True))) + + return {OutputKeys.OUTPUT: predictions} diff --git a/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py new file mode 100644 index 00000000..f10af88f --- /dev/null +++ b/modelscope/pipelines/nlp/fasttext_sequence_classification_pipeline.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict, Union + +import numpy as np +import sentencepiece +from fasttext import load_model +from fasttext.FastText import _FastText + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['FasttextSequenceClassificationPipeline'] + + +def sentencepiece_tokenize(sp_model, sent): + tokens = [] + for t in sp_model.EncodeAsPieces(sent): + s = t.strip() + if s: + tokens.append(s) + return ' '.join(tokens) + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.domain_classification) +class FasttextSequenceClassificationPipeline(Pipeline): + + def __init__(self, model: Union[str, _FastText], **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model: a model directory including model.bin and spm.model + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + super().__init__(model=model) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_BIN_FILE) + spm_file = os.path.join(model, 'sentencepiece.model') + assert os.path.isdir(model) and os.path.exists(model_file) and os.path.exists(spm_file), \ + '`model` should be a directory contains `model.bin` and `sentencepiece.model`' + self.model = load_model(model_file) + self.spm = sentencepiece.SentencePieceProcessor() + self.spm.Load(spm_file) + + def preprocess(self, inputs: str) -> Dict[str, Any]: + text = inputs.strip() + text_sp = sentencepiece_tokenize(self.spm, text) + return {'text_sp': text_sp, 'text': text} + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + topk = inputs.get('topk', -1) + label, probs = self.model.predict(inputs['text_sp'], k=topk) + label = [x.replace('__label__', '') for x in label] + result = { + OutputKeys.LABEL: label[0], + OutputKeys.SCORE: probs[0], + OutputKeys.LABELS: label, + OutputKeys.SCORES: probs + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/feature_extraction_pipeline.py b/modelscope/pipelines/nlp/feature_extraction_pipeline.py new file mode 100644 index 00000000..e94e4337 --- /dev/null +++ b/modelscope/pipelines/nlp/feature_extraction_pipeline.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import NLPPreprocessor, Preprocessor +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['FeatureExtractionPipeline'] + + +@PIPELINES.register_module( + Tasks.feature_extraction, module_name=Pipelines.feature_extraction) +class FeatureExtractionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + first_sequence='sentence', + **kwargs): + """Use `model` and `preprocessor` to create a nlp feature extraction pipeline for prediction + + Args: + model (str or Model): Supply either a local model dir which supported feature extraction task, or a + no-head model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + first_sequence: The key to read the sentence in. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + + NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' + param will have no effect. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipe_ins = pipeline('feature_extraction', model='damo/nlp_structbert_feature-extraction_english-large') + >>> input = 'Everything you love is treasure' + >>> print(pipe_ins(input)) + + + """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + + if preprocessor is None: + preprocessor = NLPPreprocessor( + model.model_dir, + padding=kwargs.pop('padding', False), + sequence_length=kwargs.pop('sequence_length', 128)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + self.preprocessor = preprocessor + self.config = Config.from_file( + os.path.join(model.model_dir, ModelFile.CONFIGURATION)) + self.tokenizer = preprocessor.tokenizer + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return self.model(**inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + + return { + OutputKeys.TEXT_EMBEDDING: + inputs[OutputKeys.TEXT_EMBEDDING].tolist() + } diff --git a/modelscope/pipelines/nlp/fill_mask_pipeline.py b/modelscope/pipelines/nlp/fill_mask_pipeline.py new file mode 100644 index 00000000..0f3446e6 --- /dev/null +++ b/modelscope/pipelines/nlp/fill_mask_pipeline.py @@ -0,0 +1,103 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['FillMaskPipeline'] + + +@PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) +@PIPELINES.register_module( + Tasks.fill_mask, module_name=Pipelines.fill_mask_ponet) +class FillMaskPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + first_sequence: str = 'sentence', + **kwargs): + """The inference pipeline for all the fill mask sub-tasks. + + Args: + model (`str` or `Model` or module instance): A model instance or a model local dir + or a model id in the model hub. + preprocessor (`Preprocessor`, `optional`): A Preprocessor instance. + first_sequence (`str`, `optional`): The key to read the sentence in. + sequence_length (`int`, `optional`): Max sequence length in the user's custom scenario, default 128. + + NOTE1: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' + param will have no effect. + + Example1: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model='damo/nlp_structbert_fill-mask_english-large') + >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' + >>> print(pipeline_ins(input)) + Example2: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('fill-mask', model='damo/nlp_ponet_fill-mask_english-base') + >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' + >>> print(pipeline_ins(input)) + + NOTE2: Please pay attention to the model's special tokens. + If bert based model(bert, structbert, etc.) is used, the mask token is '[MASK]'. + If the xlm-roberta(xlm-roberta, veco, etc.) based model is used, the mask token is ''. + To view other examples plese check the tests/pipelines/test_fill_mask.py. + """ + + fill_mask_model = Model.from_pretrained(model) if isinstance( + model, str) else model + + if preprocessor is None: + preprocessor = Preprocessor.from_pretrained( + fill_mask_model.model_dir, + first_sequence=first_sequence, + second_sequence=None, + sequence_length=kwargs.pop('sequence_length', 128)) + fill_mask_model.eval() + assert hasattr( + preprocessor, 'mask_id' + ), 'The input preprocessor should have the mask_id attribute.' + super().__init__( + model=fill_mask_model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + return self.model(**inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): The model outputs. + The output should follow some rules: + 1. Values can be retrieved by keys(dict-like, or the __getitem__ method is overriden) + 2. 'logits' and 'input_ids' key exists. + Models in modelscope will return the output dataclass `modelscope.outputs.FillMaskModelOutput`. + Returns: + Dict[str, str]: the prediction results + """ + logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy() + input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy() + pred_ids = np.argmax(logits, axis=-1) + rst_ids = np.where(input_ids == self.preprocessor.mask_id, pred_ids, + input_ids) + + pred_strings = [] + for ids in rst_ids: # batch + pred_string = self.preprocessor.decode( + ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=True) + pred_strings.append(pred_string) + + return {OutputKeys.TEXT: pred_strings} diff --git a/modelscope/pipelines/nlp/information_extraction_pipeline.py b/modelscope/pipelines/nlp/information_extraction_pipeline.py new file mode 100644 index 00000000..8ac85f43 --- /dev/null +++ b/modelscope/pipelines/nlp/information_extraction_pipeline.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + RelationExtractionPreprocessor) +from modelscope.utils.constant import Tasks + +__all__ = ['InformationExtractionPipeline'] + + +@PIPELINES.register_module( + Tasks.information_extraction, module_name=Pipelines.relation_extraction) +@PIPELINES.register_module( + Tasks.relation_extraction, module_name=Pipelines.relation_extraction) +class InformationExtractionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = RelationExtractionPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return self.model(**inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + return inputs diff --git a/modelscope/pipelines/nlp/mglm_text_summarization_pipeline.py b/modelscope/pipelines/nlp/mglm_text_summarization_pipeline.py new file mode 100644 index 00000000..c6d03077 --- /dev/null +++ b/modelscope/pipelines/nlp/mglm_text_summarization_pipeline.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022 Zhipu.AI + +from typing import Any, Dict, Optional, Union + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.nlp import MGLMForTextSummarization +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (MGLMSummarizationPreprocessor, + Preprocessor) +from modelscope.utils.constant import Tasks + +__all__ = ['MGLMTextSummarizationPipeline'] + + +@PIPELINES.register_module( + group_key=Tasks.text_summarization, + module_name=Pipelines.mglm_text_summarization) +class MGLMTextSummarizationPipeline(Pipeline): + + def __init__(self, + model: Union[MGLMForTextSummarization, str], + preprocessor: [Preprocessor] = None, + *args, + **kwargs): + model = MGLMForTextSummarization(model) if isinstance(model, + str) else model + self.model = model + self.model.eval() + if preprocessor is None: + preprocessor = MGLMSummarizationPreprocessor() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + # define the forward pass + def forward(self, inputs: Union[Dict, str], + **forward_params) -> Dict[str, Any]: + inputs = {'text': inputs} if isinstance(inputs, str) else inputs + return self.model.generate(inputs) + + # format the outputs from pipeline + def postprocess(self, input, **kwargs) -> Dict[str, Any]: + return input diff --git a/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py new file mode 100644 index 00000000..56c3a041 --- /dev/null +++ b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py @@ -0,0 +1,125 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + TokenClassificationPreprocessor, + WordSegmentationPreprocessorThai) +from modelscope.utils.constant import Tasks + +__all__ = [ + 'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline' +] + + +@PIPELINES.register_module( + Tasks.word_segmentation, + module_name=Pipelines.multilingual_word_segmentation) +class MultilingualWordSegmentationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction + + Args: + model (str or Model): Supply either a local model dir which supported word segmentation task, or a + model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. + + To view other examples plese check the tests/pipelines/test_multilingual_word_segmentation.py. + """ + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.tokenizer = preprocessor.tokenizer + self.config = model.config + assert len(self.config.id2label) > 0 + self.id2label = self.config.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **super().forward(inputs, **forward_params), OutputKeys.TEXT: + text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + text = inputs['text'] + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + labels = [ + self.id2label[x] + for x in inputs['predictions'].squeeze(0).cpu().numpy() + ] + entities = [] + entity = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if entity: + entity['end'] = offsets[1] + if label[0] in 'ES': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = {} + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + + word_segments = [entity['span'] for entity in entities] + outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} + + return outputs + + +@PIPELINES.register_module( + Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) +class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = WordSegmentationPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + outputs = super().postprocess(inputs, **postprocess_params) + word_segments = outputs[OutputKeys.OUTPUT] + word_segments = [seg.replace(' ', '') for seg in word_segments] + + return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py new file mode 100644 index 00000000..fdcf9e0f --- /dev/null +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -0,0 +1,168 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, + Preprocessor, + TokenClassificationPreprocessor) +from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) + +__all__ = [ + 'NamedEntityRecognitionPipeline', 'NamedEntityRecognitionThaiPipeline', + 'NamedEntityRecognitionVietPipeline' +] + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition) +class NamedEntityRecognitionPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp NER pipeline for prediction + + Args: + model (str or Model): Supply either a local model dir which supported NER task, or a + model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='named-entity-recognition', + >>> model='damo/nlp_raner_named-entity-recognition_chinese-base-news') + >>> input = '这与温岭市新河镇的一个神秘的传说有关。' + >>> print(pipeline_ins(input)) + + To view other examples plese check the tests/pipelines/test_named_entity_recognition.py. + """ + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.tokenizer = preprocessor.tokenizer + self.config = model.config + assert len(self.config.id2label) > 0 + self.id2label = self.config.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **self.model(**inputs, **forward_params), OutputKeys.TEXT: text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): should be tensors from model + + Returns: + Dict[str, str]: the prediction results + """ + text = inputs['text'] + if OutputKeys.PREDICTIONS not in inputs: + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + else: + predictions = inputs[OutputKeys.PREDICTIONS].squeeze( + 0).cpu().numpy() + predictions = torch_nested_numpify(torch_nested_detach(predictions)) + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + + labels = [self.id2label[x] for x in predictions] + chunks = [] + chunk = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if chunk: + chunk['end'] = offsets[1] + + if label[0] in 'ES': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + + # for cws output + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []} + + # for ner outpus + else: + outputs = {OutputKeys.OUTPUT: chunks} + return outputs + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_thai) +class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_viet) +class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorViet( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) diff --git a/modelscope/pipelines/nlp/sentence_embedding_pipeline.py b/modelscope/pipelines/nlp/sentence_embedding_pipeline.py new file mode 100644 index 00000000..cfa5c2f1 --- /dev/null +++ b/modelscope/pipelines/nlp/sentence_embedding_pipeline.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['SentenceEmbeddingPipeline'] + + +@PIPELINES.register_module( + Tasks.sentence_embedding, module_name=Pipelines.sentence_embedding) +class SentenceEmbeddingPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + first_sequence='first_sequence', + **kwargs): + """Use `model` and `preprocessor` to create a nlp text dual encoder then generates the text representation. + Args: + model (str or Model): Supply either a local model dir which supported the WS task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + """ + model = Model.from_pretrained(model) if isinstance(model, + str) else model + if preprocessor is None: + preprocessor = Preprocessor.from_pretrained( + model.model_dir if isinstance(model, Model) else model, + first_sequence=first_sequence, + sequence_length=kwargs.pop('sequence_length', 128)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + return self.model(**inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, Any]: the predicted text representation + """ + embs = inputs['last_hidden_state'][:, 0].cpu().numpy() + num_sent = embs.shape[0] + if num_sent >= 2: + scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], + (1, 0))).tolist()[0] + else: + scores = [] + return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores} diff --git a/modelscope/pipelines/nlp/summarization_pipeline.py b/modelscope/pipelines/nlp/summarization_pipeline.py new file mode 100644 index 00000000..30dd4b30 --- /dev/null +++ b/modelscope/pipelines/nlp/summarization_pipeline.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal import OfaForAllTasks +from modelscope.pipelines.base import Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import OfaPreprocessor, Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.text_summarization, module_name=Pipelines.text_generation) +class SummarizationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: [Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a Summarization pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the summarization task, + or a model id from the model hub, or a model instance. + preprocessor (Preprocessor): An optional preprocessor instance. + """ + super().__init__(model=model) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or OfaForAllTasks' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.model.eval() + if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): + preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py new file mode 100644 index 00000000..b75a8153 --- /dev/null +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -0,0 +1,373 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, Union + +import json +import torch +from transformers import BertTokenizer + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import TableQuestionAnswering +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TableQuestionAnsweringPreprocessor +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database +from modelscope.preprocessors.nlp.space_T_cn.fields.struct import (Constant, + SQLQuery) +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['TableQuestionAnsweringPipeline'] + + +@PIPELINES.register_module( + Tasks.table_question_answering, + module_name=Pipelines.table_question_answering_pipeline) +class TableQuestionAnsweringPipeline(Pipeline): + + def __init__(self, + model: Union[TableQuestionAnswering, str], + preprocessor: TableQuestionAnsweringPreprocessor = None, + db: Database = None, + **kwargs): + """use `model` and `preprocessor` to create a table question answering prediction pipeline + + Args: + model (TableQuestionAnswering): a model instance + preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance + db (Database): a database to store tables in the database + """ + model = model if isinstance( + model, TableQuestionAnswering) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir) + + # initilize tokenizer + self.tokenizer = BertTokenizer( + os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) + + # initialize database + if db is None: + self.db = Database( + tokenizer=self.tokenizer, + table_file_path=os.path.join(model.model_dir, 'table.json'), + syn_dict_file_path=os.path.join(model.model_dir, + 'synonym.txt')) + else: + self.db = db + + constant = Constant() + self.agg_ops = constant.agg_ops + self.cond_ops = constant.cond_ops + self.cond_conn_ops = constant.cond_conn_ops + self.action_ops = constant.action_ops + self.max_select_num = constant.max_select_num + self.max_where_num = constant.max_where_num + self.col_type_dict = constant.col_type_dict + self.schema_link_dict = constant.schema_link_dict + self.limit_dict = constant.limit_dict + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def post_process_multi_turn(self, history_sql, result, table): + action = self.action_ops[result['action']] + headers = table['header_name'] + current_sql = result['sql'] + current_sql['from'] = [table['table_id']] + + if history_sql is None: + return current_sql + + if action == 'out_of_scripts': + return history_sql + + elif action == 'switch_table': + return current_sql + + elif action == 'restart': + return current_sql + + elif action == 'firstTurn': + return current_sql + + elif action == 'del_focus': + pre_final_sql = history_sql + pre_sels = [] + pre_aggs = [] + for idx, seli in enumerate(pre_final_sql['sel']): + if seli not in current_sql['sel']: + pre_sels.append(seli) + pre_aggs.append(pre_final_sql['agg'][idx]) + + if len(pre_sels) < 1: + pre_sels.append(len(headers)) + pre_aggs.append(0) + pre_final_sql['sel'] = pre_sels + pre_final_sql['agg'] = pre_aggs + + final_conds = [] + for condi in pre_final_sql['conds']: + if condi[0] < len(headers): + final_conds.append(condi) + if len(final_conds) < 1: + final_conds.append([len(headers), 2, 'Null']) + pre_final_sql['conds'] = final_conds + + return pre_final_sql + + elif action == 'change_agg_only': + pre_final_sql = history_sql + pre_sels = [] + pre_aggs = [] + for idx, seli in enumerate(pre_final_sql['sel']): + if seli in current_sql['sel']: + pre_sels.append(seli) + changed_aggi = -1 + for idx_single, aggi in enumerate(current_sql['agg']): + if current_sql['sel'][idx_single] == seli: + changed_aggi = aggi + pre_aggs.append(changed_aggi) + else: + pre_sels.append(seli) + pre_aggs.append(pre_final_sql['agg'][idx]) + pre_final_sql['sel'] = pre_sels + pre_final_sql['agg'] = pre_aggs + + return pre_final_sql + + elif action == 'change_focus_total': + pre_final_sql = history_sql + pre_sels = current_sql['sel'] + pre_aggs = current_sql['agg'] + + pre_final_sql['sel'] = pre_sels + pre_final_sql['agg'] = pre_aggs + for pre_condi in current_sql['conds']: + if pre_condi[0] < len(headers): + in_flag = False + for history_condi in history_sql['conds']: + if pre_condi[0] == history_condi[0]: + in_flag = True + if not in_flag: + pre_final_sql['conds'].append(pre_condi) + + return pre_final_sql + + elif action == 'del_cond': + pre_final_sql = history_sql + + final_conds = [] + + for idx, condi in enumerate(pre_final_sql['conds']): + if condi[0] not in current_sql['sel']: + final_conds.append(condi) + pre_final_sql['conds'] = final_conds + + final_conds = [] + for condi in pre_final_sql['conds']: + if condi[0] < len(headers): + final_conds.append(condi) + if len(final_conds) < 1: + final_conds.append([len(headers), 2, 'Null']) + pre_final_sql['conds'] = final_conds + + return pre_final_sql + + elif action == 'change_cond': + pre_final_sql = history_sql + final_conds = [] + + for idx, condi in enumerate(pre_final_sql['conds']): + in_single_flag = False + for single_condi in current_sql['conds']: + if condi[0] == single_condi[0]: + in_single_flag = True + final_conds.append(single_condi) + if not in_single_flag: + final_conds.append(condi) + pre_final_sql['conds'] = final_conds + + final_conds = [] + for condi in pre_final_sql['conds']: + if condi[0] < len(headers): + final_conds.append(condi) + if len(final_conds) < 1: + final_conds.append([len(headers), 2, 'Null', 'Null']) + pre_final_sql['conds'] = final_conds + + return pre_final_sql + + elif action == 'add_cond': + pre_final_sql = history_sql + final_conds = pre_final_sql['conds'] + for idx, condi in enumerate(current_sql['conds']): + if condi[0] < len(headers): + final_conds.append(condi) + pre_final_sql['conds'] = final_conds + + final_conds = [] + for condi in pre_final_sql['conds']: + if condi[0] < len(headers): + final_conds.append(condi) + if len(final_conds) < 1: + final_conds.append([len(headers), 2, 'Null']) + pre_final_sql['conds'] = final_conds + + return pre_final_sql + + else: + return current_sql + + def sql_dict_to_str(self, result, tables): + """ + convert sql struct to string + """ + table = tables[result['sql']['from'][0]] + header_names = table['header_name'] + ['空列'] + header_ids = table['header_id'] + ['null'] + sql = result['sql'] + + str_sel_list, sql_sel_list = [], [] + for idx, sel in enumerate(sql['sel']): + header_name = header_names[sel] + header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) + if sql['agg'][idx] == 0: + str_sel_list.append(header_name) + sql_sel_list.append(header_id) + else: + str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_name + ')') + sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_id + ')') + + str_cond_list, sql_cond_list = [], [] + where_conds, orderby_conds = [], [] + for cond in sql['conds']: + if cond[1] in [4, 5]: + orderby_conds.append(cond) + else: + where_conds.append(cond) + for cond in where_conds: + header_name = header_names[cond[0]] + if header_name == '空列': + continue + header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) + op = self.cond_ops[cond[1]] + value = cond[2] + str_cond_list.append('( ' + header_name + ' ' + op + ' "' + value + + '" )') + sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value + + '" )') + cond_str = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' + str_where_conds = cond_str.join(str_cond_list) + sql_where_conds = cond_str.join(sql_cond_list) + if len(orderby_conds) != 0: + str_orderby_column = ', '.join( + [header_names[cond[0]] for cond in orderby_conds]) + sql_orderby_column = ', '.join([ + '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) + for cond in orderby_conds + ]) + str_orderby_op = self.cond_ops[orderby_conds[0][1]] + str_orderby = '%s %s' % (str_orderby_column, str_orderby_op) + sql_orderby = '%s %s' % (sql_orderby_column, str_orderby_op) + limit_key = orderby_conds[0][2] + is_in, limit_num = False, -1 + for key in self.limit_dict: + if key in limit_key: + is_in = True + limit_num = self.limit_dict[key] + break + if is_in: + str_orderby += ' LIMIT %d' % (limit_num) + sql_orderby += ' LIMIT %d' % (limit_num) + else: + str_orderby = '' + + if len(str_cond_list) != 0 and len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds, + str_orderby) + final_sql = 'SELECT %s FROM `%s` WHERE %s ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds, + sql_orderby) + elif len(str_cond_list) != 0: + final_str = 'SELECT %s FROM %s WHERE %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds) + final_sql = 'SELECT %s FROM `%s` WHERE %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds) + elif len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], str_orderby) + final_sql = 'SELECT %s FROM `%s` ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_orderby) + else: + final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list), + table['table_name']) + final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list), + table['table_id']) + + sql = SQLQuery( + string=final_str, query=final_sql, sql_result=result['sql']) + + return sql + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + result = inputs['result'] + history_sql = inputs['history_sql'] + try: + result['sql'] = self.post_process_multi_turn( + history_sql=history_sql, + result=result, + table=self.db.tables[result['table_id']]) + except Exception: + result['sql'] = history_sql + sql = self.sql_dict_to_str(result=result, tables=self.db.tables) + + # add sqlite + if self.db.is_use_sqlite: + try: + cursor = self.db.connection_obj.cursor().execute(sql.query) + header_ids, header_names = [], [] + for description in cursor.description: + header_names.append(self.db.tables[result['table_id']] + ['headerid2name'].get( + description[0], description[0])) + header_ids.append(description[0]) + rows = [] + for res in cursor.fetchall(): + rows.append(list(res)) + tabledata = { + 'header_id': header_ids, + 'header_name': header_names, + 'rows': rows + } + except Exception as e: + logger.error(e) + tabledata = {'header_id': [], 'header_name': [], 'rows': []} + else: + tabledata = {'header_id': [], 'header_name': [], 'rows': []} + + output = { + OutputKeys.SQL_STRING: sql.string, + OutputKeys.SQL_QUERY: sql.query, + OutputKeys.HISTORY: result['sql'], + OutputKeys.QUERT_RESULT: tabledata, + } + + return {OutputKeys.OUTPUT: output} + + def _collate_fn(self, data): + return data diff --git a/modelscope/pipelines/nlp/text2text_generation_pipeline.py b/modelscope/pipelines/nlp/text2text_generation_pipeline.py new file mode 100644 index 00000000..a739df69 --- /dev/null +++ b/modelscope/pipelines/nlp/text2text_generation_pipeline.py @@ -0,0 +1,115 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Optional, Union + +import torch +from numpy import isin + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Text2TextGenerationPreprocessor +from modelscope.utils.config import use_task_specific_params +from modelscope.utils.constant import Tasks + +__all__ = ['Text2TextGenerationPipeline'] + +TRANSLATE_PIPELINES = [ + Pipelines.translation_en_to_de, + Pipelines.translation_en_to_ro, + Pipelines.translation_en_to_fr, +] + + +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.text2text_generation) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_de) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_ro) +@PIPELINES.register_module( + Tasks.text2text_generation, module_name=Pipelines.translation_en_to_fr) +class Text2TextGenerationPipeline(Pipeline): + + def __init__( + self, + model: Union[Model, str], + preprocessor: Optional[Text2TextGenerationPreprocessor] = None, + first_sequence='sentence', + **kwargs): + """Use `model` and `preprocessor` to create a text to text generation pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the text generation task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + first_sequence: The key to read the first sentence in. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + + NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' + param will have no effect. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='text2text-generation', + >>> model='damo/nlp_t5_text2text-generation_chinese-base') + >>> sentence1 = '中国的首都位于。' + >>> print(pipeline_ins(sentence1)) + >>> # Or use the dict input: + >>> print(pipeline_ins({'sentence': sentence1})) + >>> # 北京 + + To view other examples plese check the tests/pipelines/test_text_generation.py. + """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = Text2TextGenerationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 128)) + self.tokenizer = preprocessor.tokenizer + self.pipeline = model.pipeline.type + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + """ Provide specific preprocess for text2text generation pipeline in order to handl multi tasks + """ + if not isinstance(inputs, str): + raise ValueError(f'Not supported input type: {type(inputs)}') + + if self.pipeline in TRANSLATE_PIPELINES: + use_task_specific_params(self.model, self.pipeline) + inputs = self.model.config.prefix + inputs + + return super().preprocess(inputs, **preprocess_params) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + + forward_params['min_length'] = forward_params.get( + 'min_length', self.model.config.min_length) + forward_params['max_length'] = forward_params.get( + 'max_length', self.model.config.max_length) + + with torch.no_grad(): + output_ids = self.model.generate(**inputs, **forward_params) + return {'output_ids': output_ids} + + def postprocess(self, inputs: Dict[str, Tensor], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + output = self.tokenizer.decode( + inputs['output_ids'][0], + skip_special_tokens=True, + ) + return {OutputKeys.TEXT: output} diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py new file mode 100644 index 00000000..771660a5 --- /dev/null +++ b/modelscope/pipelines/nlp/text_classification_pipeline.py @@ -0,0 +1,126 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import numpy as np + +from modelscope.metainfo import Pipelines, Preprocessors +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Fields, Tasks + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentiment_analysis) +@PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) +@PIPELINES.register_module( + Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.text_classification) +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentiment_classification) +@PIPELINES.register_module( + Tasks.text_classification, module_name=Pipelines.sentence_similarity) +@PIPELINES.register_module( + Tasks.sentiment_classification, + module_name=Pipelines.sentiment_classification) +class TextClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Preprocessor = None, + **kwargs): + """The inference pipeline for all the text classification sub-tasks. + + Args: + model (`str` or `Model` or module instance): A model instance or a model local dir + or a model id in the model hub. + preprocessor (`Preprocessor`, `optional`): A Preprocessor instance. + first_sequence (`str`, `optional`): The key of the first sentence. + second_sequence (`str`, `optional`): The key of the second sentence. + sequence_length (`int`, `optional`): The sequence length. + id2label (`dict`, `optional`): The id-label mapping. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline('text-classification', + model='damo/nlp_structbert_sentence-similarity_chinese-base') + >>> input = ('这是个测试', '这也是个测试') + >>> print(pipeline_ins(input)) + + NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' and 'second_sequence' + param will have no affection. + """ + model = Model.from_pretrained(model) if isinstance(model, + str) else model + + if preprocessor is None: + if model.__class__.__name__ == 'OfaForAllTasks': + preprocessor = Preprocessor.from_pretrained( + model_name_or_path=model.model_dir, + type=Preprocessors.ofa_tasks_preprocessor, + field=Fields.multi_modal) + else: + first_sequence = kwargs.pop('first_sequence', 'first_sequence') + second_sequence = kwargs.pop('second_sequence', None) + preprocessor = Preprocessor.from_pretrained( + model if isinstance(model, str) else model.model_dir, + first_sequence=first_sequence, + second_sequence=second_sequence, + sequence_length=kwargs.pop('sequence_length', 512)) + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + if self.model.__class__.__name__ == 'OfaForAllTasks': + return super().forward(inputs, **forward_params) + return self.model(**inputs, **forward_params) + + def postprocess(self, + inputs: Dict[str, Any], + topk: int = 5) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (`Dict[str, Any]` or `TextClassificationModelOutput`): The model output, please check + the `TextClassificationModelOutput` class for details. + topk (int): The topk probs to take + Returns: + Dict[str, str]: the prediction results. + scores: The probabilities of each label. + labels: The real labels. + Label at index 0 is the smallest probability. + """ + if self.model.__class__.__name__ == 'OfaForAllTasks': + return inputs + else: + assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ + 'as a parameter or make sure the preprocessor has the attribute.' + logits = inputs[OutputKeys.LOGITS].cpu().numpy() + if logits.shape[0] == 1: + logits = logits[0] + + def softmax(logits): + exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) + return exp / exp.sum(axis=-1, keepdims=True) + + probs = softmax(logits) + num_classes = probs.shape[-1] + topk = min(topk, num_classes) + top_indices = np.argpartition(probs, -topk)[-topk:] + probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() + + def map_to_label(id): + return self.id2label[id] + + v_func = np.vectorize(map_to_label) + return { + OutputKeys.SCORES: probs, + OutputKeys.LABELS: v_func(top_indices).tolist() + } diff --git a/modelscope/pipelines/nlp/text_error_correction_pipeline.py b/modelscope/pipelines/nlp/text_error_correction_pipeline.py new file mode 100644 index 00000000..8e9bf85d --- /dev/null +++ b/modelscope/pipelines/nlp/text_error_correction_pipeline.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import BartForTextErrorCorrection +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextErrorCorrectionPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['TextErrorCorrectionPipeline'] + + +@PIPELINES.register_module( + Tasks.text_error_correction, module_name=Pipelines.text_error_correction) +class TextErrorCorrectionPipeline(Pipeline): + + def __init__( + self, + model: Union[BartForTextErrorCorrection, str], + preprocessor: Optional[TextErrorCorrectionPreprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a nlp text correction pipeline. + + Args: + model (BartForTextErrorCorrection): A model instance, or a model local dir, or a model id in the model hub. + preprocessor (TextErrorCorrectionPreprocessor): An optional preprocessor instance. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline( + >>> task='text-error-correction', model='damo/nlp_bart_text-error-correction_chinese') + >>> sentence1 = '随着中国经济突飞猛近,建造工业与日俱增' + >>> print(pipeline_ins(sentence1)) + + To view other examples plese check the tests/pipelines/test_text_error_correction.py. + """ + + model = model if isinstance( + model, + BartForTextErrorCorrection) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TextErrorCorrectionPreprocessor(model.model_dir) + self.vocab = preprocessor.vocab + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Tensor], + **postprocess_params) -> Dict[str, str]: + """ + Args: + inputs (Dict[str, Tensor]) + Example: + { + 'predictions': Tensor([1377, 4959, 2785, 6392...]), # tokens need to be decode by tokenizer + } + Returns: + Dict[str, str] + Example: + { + 'output': '随着中国经济突飞猛进,建造工业与日俱增' + } + + + """ + + pred_str = self.vocab.string( + inputs['predictions'], + '@@', + extra_symbols_to_ignore={self.vocab.pad()}) + + return {OutputKeys.OUTPUT: ''.join(pred_str.split())} diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py new file mode 100644 index 00000000..0490c8e7 --- /dev/null +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.chinese_utils import remove_space_between_chinese_chars +from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.hub import read_config + +__all__ = ['TextGenerationPipeline'] + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.text_generation) +class TextGenerationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + first_sequence='sentence', + **kwargs): + """Use `model` and `preprocessor` to create a generation pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the text generation task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + first_sequence: The key to read the first sentence in. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + + NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' + param will have no effect. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='text-generation', + >>> model='damo/nlp_palm2.0_text-generation_chinese-base') + >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:' + >>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代' + >>> print(pipeline_ins(sentence1)) + >>> # Or use the dict input: + >>> print(pipeline_ins({'sentence': sentence1})) + + To view other examples plese check the tests/pipelines/test_text_generation.py. + """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + cfg = read_config(model.model_dir) + self.postprocessor = cfg.pop('postprocessor', 'decode') + if preprocessor is None: + preprocessor_cfg = cfg.preprocessor + preprocessor_cfg.update({ + 'model_dir': + model.model_dir, + 'first_sequence': + first_sequence, + 'second_sequence': + None, + 'sequence_length': + kwargs.pop('sequence_length', 128) + }) + preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return self.model.generate(inputs, **forward_params) + + def decode(self, inputs) -> str: + tokenizer = self.preprocessor.tokenizer + return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) + + def sentence_piece(self, inputs) -> str: + tokenizer = self.preprocessor.tokenizer + return tokenizer.decode(inputs.tolist()) + + def roberta(self, inputs) -> str: + tokenizer = self.preprocessor.tokenizer + decoded = tokenizer.decode(inputs.tolist()) + return decoded.replace('', '. ').replace('', + '. ').replace('', '') + + def postprocess(self, inputs: Dict[str, Tensor], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + inputs = inputs['sequences'] + if isinstance(inputs, list) or len(inputs.shape) > 1: + inputs = inputs[0] + decoded = getattr(self, self.postprocessor)(inputs) + text = remove_space_between_chinese_chars(decoded) + return {OutputKeys.TEXT: text} diff --git a/modelscope/pipelines/nlp/text_ranking_pipeline.py b/modelscope/pipelines/nlp/text_ranking_pipeline.py new file mode 100644 index 00000000..9cee327b --- /dev/null +++ b/modelscope/pipelines/nlp/text_ranking_pipeline.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import numpy as np + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor, TextRankingPreprocessor +from modelscope.utils.constant import Tasks + +__all__ = ['TextRankingPipeline'] + + +@PIPELINES.register_module( + Tasks.text_ranking, module_name=Pipelines.text_ranking) +class TextRankingPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the WS task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + """ + model = Model.from_pretrained(model) if isinstance(model, + str) else model + + if preprocessor is None: + preprocessor = Preprocessor.from_pretrained( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 128)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + return self.model(**inputs, **forward_params) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the prediction results + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, Any]: the predicted text representation + """ + + def sigmoid(logits): + return np.exp(logits) / (1 + np.exp(logits)) + + logits = inputs[OutputKeys.LOGITS].squeeze(-1).detach().cpu().numpy() + pred_list = sigmoid(logits).tolist() + return {OutputKeys.SCORES: pred_list} diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py new file mode 100644 index 00000000..4af187ee --- /dev/null +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) + +__all__ = ['TokenClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.token_classification) +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.part_of_speech) +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.word_segmentation) +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.named_entity_recognition) +@PIPELINES.register_module( + Tasks.part_of_speech, module_name=Pipelines.part_of_speech) +class TokenClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a token classification pipeline for prediction + + Args: + model (str or Model): A model instance or a model local dir or a model id in the model hub. + preprocessor (Preprocessor): a preprocessor instance, must not be None. + """ + model = Model.from_pretrained(model) if isinstance(model, + str) else model + + if preprocessor is None: + preprocessor = Preprocessor.from_pretrained( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 128)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **self.model(**inputs, **forward_params), OutputKeys.TEXT: text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): should be tensors from model + + Returns: + Dict[str, str]: the prediction results + """ + text = inputs['text'] + if not hasattr(inputs, 'predictions'): + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + else: + predictions = inputs[OutputKeys.PREDICTIONS].squeeze( + 0).cpu().numpy() + predictions = torch_nested_numpify(torch_nested_detach(predictions)) + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + + labels = [self.id2label[x] for x in predictions] + if len(labels) > len(offset_mapping): + labels = labels[1:-1] + chunks = [] + chunk = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if chunk: + chunk['end'] = offsets[1] + + if label[0] in 'ES': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + + # for cws outputs + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result} + + # for ner outputs + else: + outputs = {OutputKeys.OUTPUT: chunks} + return outputs diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py new file mode 100644 index 00000000..68a03631 --- /dev/null +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -0,0 +1,127 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import jieba +import numpy as np +import tensorflow as tf +from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer +from subword_nmt import apply_bpe + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + +__all__ = ['TranslationPipeline'] + + +@PIPELINES.register_module( + Tasks.translation, module_name=Pipelines.csanmt_translation) +class TranslationPipeline(Pipeline): + + def __init__(self, model: Model, **kwargs): + """Build a translation pipeline with a model dir or a model id in the model hub. + + Args: + model: A Model instance. + """ + super().__init__(model=model, **kwargs) + model = self.model.model_dir + tf.reset_default_graph() + + model_path = osp.join( + osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0') + + self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION)) + + self._src_vocab_path = osp.join( + model, self.cfg['dataset']['src_vocab']['file']) + self._src_vocab = dict([ + (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path)) + ]) + self._trg_vocab_path = osp.join( + model, self.cfg['dataset']['trg_vocab']['file']) + self._trg_rvocab = dict([ + (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path)) + ]) + + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + + self.input_wids = tf.placeholder( + dtype=tf.int64, shape=[None, None], name='input_wids') + self.output = {} + + # preprocess + self._src_lang = self.cfg['preprocessor']['src_lang'] + self._tgt_lang = self.cfg['preprocessor']['tgt_lang'] + self._src_bpe_path = osp.join( + model, self.cfg['preprocessor']['src_bpe']['file']) + + if self._src_lang == 'zh': + self._tok = jieba + else: + self._punct_normalizer = MosesPunctNormalizer(lang=self._src_lang) + self._tok = MosesTokenizer(lang=self._src_lang) + self._detok = MosesDetokenizer(lang=self._tgt_lang) + + self._bpe = apply_bpe.BPE(open(self._src_bpe_path)) + + # model + output = self.model(self.input_wids) + self.output.update(output) + + with self._session.as_default() as sess: + logger.info(f'loading model from {model_path}') + # load model + model_loader = tf.train.Saver(tf.global_variables()) + model_loader.restore(sess, model_path) + + def preprocess(self, input: str) -> Dict[str, Any]: + if self._src_lang == 'zh': + input_tok = self._tok.cut(input) + input_tok = ' '.join(list(input_tok)) + else: + input = self._punct_normalizer.normalize(input) + input_tok = self._tok.tokenize( + input, return_str=True, aggressive_dash_splits=True) + + input_bpe = self._bpe.process_line(input_tok) + input_ids = np.array([[ + self._src_vocab[w] + if w in self._src_vocab else self.cfg['model']['src_vocab_size'] + for w in input_bpe.strip().split() + ]]) + result = {'input_ids': input_ids} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + feed_dict = {self.input_wids: input['input_ids']} + sess_outputs = self._session.run(self.output, feed_dict=feed_dict) + return sess_outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_seqs = inputs['output_seqs'][0] + wids = list(output_seqs[0]) + [0] + wids = wids[:wids.index(0)] + translation_out = ' '.join([ + self._trg_rvocab[wid] if wid in self._trg_rvocab else '' + for wid in wids + ]).replace('@@ ', '').replace('@@', '') + translation_out = self._detok.detokenize(translation_out.split()) + result = {OutputKeys.TRANSLATION: translation_out} + return result diff --git a/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py new file mode 100644 index 00000000..6ef203b9 --- /dev/null +++ b/modelscope/pipelines/nlp/translation_quality_estimation_pipeline.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +import os +from typing import Any, Dict, Union + +import numpy as np +import torch +from transformers import XLMRobertaTokenizer + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.nlp import BertForSequenceClassification +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TranslationQualityEstimationPipeline'] + + +@PIPELINES.register_module( + Tasks.sentence_similarity, + module_name=Pipelines.translation_quality_estimation) +class TranslationQualityEstimationPipeline(Pipeline): + + def __init__(self, model: str, device: str = 'gpu', **kwargs): + super().__init__(model=model, device=device) + model_file = os.path.join(model, ModelFile.TORCH_MODEL_FILE) + with open(model_file, 'rb') as f: + buffer = io.BytesIO(f.read()) + self.tokenizer = XLMRobertaTokenizer.from_pretrained(model) + self.model = torch.jit.load( + buffer, map_location=self.device).to(self.device) + + def preprocess(self, inputs: Dict[str, Any]): + src_text = inputs['source_text'].strip() + tgt_text = inputs['target_text'].strip() + encoded_inputs = self.tokenizer.batch_encode_plus( + [[src_text, tgt_text]], + return_tensors='pt', + padding=True, + truncation=True) + input_ids = encoded_inputs['input_ids'].to(self.device) + attention_mask = encoded_inputs['attention_mask'].to(self.device) + inputs.update({ + 'input_ids': input_ids, + 'attention_mask': attention_mask + }) + return inputs + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if 'input_ids' not in inputs: + inputs = self.preprocess(inputs) + res = self.model(inputs['input_ids'], inputs['attention_mask']) + result = { + OutputKeys.LABELS: '-1', + OutputKeys.SCORES: res[0].detach().squeeze().tolist() + } + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): input data dict + + Returns: + Dict[str, str]: the prediction results + """ + return inputs diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py new file mode 100644 index 00000000..c57f6b93 --- /dev/null +++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -0,0 +1,129 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + TokenClassificationPreprocessor) +from modelscope.utils.constant import Tasks +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) + +__all__ = ['WordSegmentationPipeline'] + + +@PIPELINES.register_module( + Tasks.word_segmentation, module_name=Pipelines.word_segmentation) +class WordSegmentationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the WS task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. + + NOTE: The preprocessor will first split the sentence into single characters, + then feed them into the tokenizer with the parameter is_split_into_words=True. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='word-segmentation', + >>> model='damo/nlp_structbert_word-segmentation_chinese-base') + >>> sentence1 = '今天天气不错,适合出去游玩' + >>> print(pipeline_ins(sentence1)) + + To view other examples plese check the tests/pipelines/test_word_segmentation.py. + """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 128)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label + assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ + 'as a parameter or make sure the preprocessor has the attribute.' + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **self.model(**inputs, **forward_params), OutputKeys.TEXT: text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): should be tensors from model + + Returns: + Dict[str, str]: the prediction results + """ + text = inputs['text'] + logits = inputs[OutputKeys.LOGITS] + predictions = torch.argmax(logits[0], dim=-1) + logits = torch_nested_numpify(torch_nested_detach(logits)) + predictions = torch_nested_numpify(torch_nested_detach(predictions)) + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + + labels = [self.id2label[x] for x in predictions] + if len(labels) > len(offset_mapping): + labels = labels[1:-1] + chunks = [] + chunk = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if chunk: + chunk['end'] = offsets[1] + + if label[0] in 'ES': + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + chunk = {} + + if chunk: + chunk['span'] = text[chunk['start']:chunk['end']] + chunks.append(chunk) + + # for cws outputs + if len(chunks) > 0 and chunks[0]['type'] == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = ' '.join(spans) + outputs = {OutputKeys.OUTPUT: seg_result} + + # for ner output + else: + outputs = {OutputKeys.OUTPUT: chunks} + return outputs diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py new file mode 100644 index 00000000..ecd538b9 --- /dev/null +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +import torch +from scipy.special import softmax + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + ZeroShotClassificationPreprocessor) +from modelscope.utils.constant import Tasks + +__all__ = ['ZeroShotClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.zero_shot_classification, + module_name=Pipelines.zero_shot_classification) +class ZeroShotClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Preprocessor = None, + **kwargs): + """Use `model` and `preprocessor` to create a nlp zero shot classifiction for prediction. + + A zero-shot classification task is used to classify texts by prompts. + In a normal classification task, model may produce a positive label by the input text + like 'The ice cream is made of the high quality milk, it is so delicious' + In a zero-shot task, the sentence is converted to: + ['The ice cream is made of the high quality milk, it is so delicious', 'This means it is good'] + And: + ['The ice cream is made of the high quality milk, it is so delicious', 'This means it is bad'] + Then feed these sentences into the model and turn the task to a NLI task(entailment, contradiction), + and compare the output logits to give the original classification label. + + + Args: + model (str or Model): Supply either a local model dir which supported the task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='zero-shot-classification', + >>> model='damo/nlp_structbert_zero-shot-classification_chinese-base') + >>> sentence1 = '全新突破 解放军运20版空中加油机曝光' + >>> labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] + >>> template = '这篇文章的标题是{}' + >>> print(pipeline_ins(sentence1, candidate_labels=labels, hypothesis_template=template)) + + To view other examples plese check the tests/pipelines/test_zero_shot_classification.py. + """ + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or Model' + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + self.entailment_id = 0 + self.contradiction_id = 2 + if preprocessor is None: + preprocessor = ZeroShotClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **kwargs): + preprocess_params = {} + postprocess_params = {} + if 'candidate_labels' in kwargs: + candidate_labels = self._parse_labels( + kwargs.pop('candidate_labels')) + preprocess_params['candidate_labels'] = candidate_labels + postprocess_params['candidate_labels'] = candidate_labels + else: + raise ValueError('You must include at least one label.') + preprocess_params['hypothesis_template'] = kwargs.pop( + 'hypothesis_template', '{}') + postprocess_params['multi_label'] = kwargs.pop('multi_label', False) + return preprocess_params, {}, postprocess_params + + def _parse_labels(self, labels): + if isinstance(labels, str): + labels = labels.replace(',', ',') # replace cn comma to en comma + labels = [ + label.strip() for label in labels.split(',') if label.strip() + ] + return labels + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + return self.model(**inputs, **forward_params) + + def postprocess(self, + inputs: Dict[str, Any], + candidate_labels, + multi_label=False) -> Dict[str, Any]: + """process the prediction results + Args: + inputs (Dict[str, Any]): _description_ + Returns: + Dict[str, Any]: the prediction results + """ + logits = inputs[OutputKeys.LOGITS].cpu().numpy() + if multi_label or len(candidate_labels) == 1: + logits = logits[..., [self.contradiction_id, self.entailment_id]] + scores = softmax(logits, axis=-1)[..., 1] + else: + logits = logits[..., self.entailment_id] + scores = softmax(logits, axis=-1) + reversed_index = list(reversed(scores.argsort())) + result = { + OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index], + OutputKeys.SCORES: [scores[i].item() for i in reversed_index], + } + return result diff --git a/modelscope/pipelines/science/__init__.py b/modelscope/pipelines/science/__init__.py new file mode 100644 index 00000000..1f81809b --- /dev/null +++ b/modelscope/pipelines/science/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .protein_structure_pipeline import ProteinStructurePipeline + +else: + _import_structure = { + 'protein_structure_pipeline': ['ProteinStructurePipeline'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/science/protein_structure_pipeline.py b/modelscope/pipelines/science/protein_structure_pipeline.py new file mode 100644 index 00000000..3dc51c72 --- /dev/null +++ b/modelscope/pipelines/science/protein_structure_pipeline.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import time +from typing import Any, Dict, List, Optional, Union + +import json +import numpy as np +import torch +from unicore.utils import tensor_tree_map + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.science.unifold.config import model_config +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.dataset import (UnifoldDataset, + load_and_process) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.constant import Fields, Frameworks, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.hub import read_config +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ProteinStructurePipeline'] + + +def automatic_chunk_size(seq_len): + if seq_len < 512: + chunk_size = 256 + elif seq_len < 1024: + chunk_size = 128 + elif seq_len < 2048: + chunk_size = 32 + elif seq_len < 3072: + chunk_size = 16 + else: + chunk_size = 1 + return chunk_size + + +def load_feature_for_one_target( + config, + data_folder, + seed=0, + is_multimer=False, + use_uniprot=False, + symmetry_group=None, +): + if not is_multimer: + uniprot_msa_dir = None + sequence_ids = ['A'] + if use_uniprot: + uniprot_msa_dir = data_folder + + else: + uniprot_msa_dir = data_folder + sequence_ids = open(os.path.join(data_folder, + 'chains.txt')).readline().split() + + if symmetry_group is None: + batch, _ = load_and_process( + config=config.data, + mode='predict', + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + monomer_feature_dir=data_folder, + uniprot_msa_dir=uniprot_msa_dir, + ) + else: + raise NotImplementedError + batch = UnifoldDataset.collater([batch]) + return batch + + +@PIPELINES.register_module( + Tasks.protein_structure, module_name=Pipelines.protein_structure) +class ProteinStructurePipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a protein structure pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the protein structure task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. + + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='protein-structure', + >>> model='DPTech/uni-fold-monomer') + >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + >>> print(pipeline_ins(protein)) + + """ + import copy + model_path = copy.deepcopy(model) if isinstance(model, str) else None + cfg = read_config(model_path) # only model is str + self.cfg = cfg + self.config = model_config( + cfg['pipeline']['model_name']) # alphafold config + model = model if isinstance( + model, Model) else Model.from_pretrained(model_path) + self.postprocessor = cfg.pop('postprocessor', None) + if preprocessor is None: + preprocessor_cfg = cfg.preprocessor + preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) + model.eval() + model.model.inference_mode() + model.model_dir = model_path + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + + def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]: + preprocess_params = kwargs.get('preprocess_params', {}) + forward_params = kwargs.get('forward_params', {}) + postprocess_params = kwargs.get('postprocess_params', {}) + out = self.preprocess(input, **preprocess_params) + with device_placement(self.framework, self.device_name): + with torch.no_grad(): + out = self.forward(out, **forward_params) + + out = self.postprocess(out, **postprocess_params) + return out + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + plddts = {} + ptms = {} + + output_dir = os.path.join(self.preprocessor.output_dir_base, + inputs['target_id']) + + pdbs = [] + for seed in range(self.cfg['pipeline']['times']): + cur_seed = hash((42, seed)) % 100000 + batch = load_feature_for_one_target( + self.config, + output_dir, + cur_seed, + is_multimer=inputs['is_multimer'], + use_uniprot=inputs['is_multimer'], + symmetry_group=self.preprocessor.symmetry_group, + ) + seq_len = batch['aatype'].shape[-1] + self.model.model.globals.chunk_size = automatic_chunk_size(seq_len) + + with torch.no_grad(): + batch = { + k: torch.as_tensor(v, device='cuda:0') + for k, v in batch.items() + } + out = self.model(batch) + + def to_float(x): + if x.dtype == torch.bfloat16 or x.dtype == torch.half: + return x.float() + else: + return x + + # Toss out the recycling dimensions --- we don't need them anymore + batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch) + batch = tensor_tree_map(to_float, batch) + out = tensor_tree_map(lambda t: t[0, ...], out[0]) + out = tensor_tree_map(to_float, out) + batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch) + out = tensor_tree_map(lambda x: np.array(x.cpu()), out) + + plddt = out['plddt'] + mean_plddt = np.mean(plddt) + plddt_b_factors = np.repeat( + plddt[..., None], residue_constants.atom_type_num, axis=-1) + # TODO: , may need to reorder chains, based on entity_ids + cur_protein = protein.from_prediction( + features=batch, result=out, b_factors=plddt_b_factors) + cur_save_name = (f'{cur_seed}') + plddts[cur_save_name] = str(mean_plddt) + if inputs[ + 'is_multimer'] and self.preprocessor.symmetry_group is None: + ptms[cur_save_name] = str(np.mean(out['iptm+ptm'])) + with open(os.path.join(output_dir, cur_save_name + '.pdb'), + 'w') as f: + f.write(protein.to_pdb(cur_protein)) + pdbs.append(protein.to_pdb(cur_protein)) + + logger.info('plddts:' + str(plddts)) + model_name = self.cfg['pipeline']['model_name'] + score_name = f'{model_name}' + plddt_fname = score_name + '_plddt.json' + + with open(os.path.join(output_dir, plddt_fname), 'w') as f: + json.dump(plddts, f, indent=4) + if ptms: + logger.info('ptms' + str(ptms)) + ptm_fname = score_name + '_ptm.json' + with open(os.path.join(output_dir, ptm_fname), 'w') as f: + json.dump(ptms, f, indent=4) + + return pdbs + + def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params): + return inputs diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py new file mode 100644 index 00000000..2c2c7751 --- /dev/null +++ b/modelscope/pipelines/util.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from typing import List, Optional, Union + +from modelscope.hub.api import HubApi +from modelscope.hub.file_download import model_file_download +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def is_config_has_model(cfg_file): + try: + cfg = Config.from_file(cfg_file) + return hasattr(cfg, 'model') + except Exception as e: + logger.error(f'parse config file {cfg_file} failed: {e}') + return False + + +def is_official_hub_path(path: Union[str, List], + revision: Optional[str] = DEFAULT_MODEL_REVISION): + """ Whether path is an official hub name or a valid local + path to official hub directory. + """ + + def is_official_hub_impl(path): + if osp.exists(path): + cfg_file = osp.join(path, ModelFile.CONFIGURATION) + return osp.exists(cfg_file) + else: + try: + _ = HubApi().get_model(path, revision=revision) + return True + except Exception as e: + logger.warning(f'get model exception: {e}') + return False + + if isinstance(path, str): + return is_official_hub_impl(path) + else: + results = [is_official_hub_impl(m) for m in path] + all_true = all(results) + any_true = any(results) + if any_true and not all_true: + raise ValueError( + f'some model are hub address, some are not, model list: {path}' + ) + + return all_true + + +def is_model(path: Union[str, List]): + """ whether path is a valid modelhub path and containing model config + """ + + def is_modelhub_path_impl(path): + if osp.exists(path): + cfg_file = osp.join(path, ModelFile.CONFIGURATION) + if osp.exists(cfg_file): + return is_config_has_model(cfg_file) + else: + return False + else: + try: + cfg_file = model_file_download(path, ModelFile.CONFIGURATION) + return is_config_has_model(cfg_file) + except Exception: + return False + + if isinstance(path, str): + return is_modelhub_path_impl(path) + else: + results = [is_modelhub_path_impl(m) for m in path] + all_true = all(results) + any_true = any(results) + if any_true and not all_true: + raise ValueError( + f'some models are hub address, some are not, model list: {path}' + ) + + return all_true diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py new file mode 100644 index 00000000..0db1c7e0 --- /dev/null +++ b/modelscope/preprocessors/__init__.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .base import Preprocessor + from .builder import PREPROCESSORS, build_preprocessor + from .common import Compose, ToTensor, Filter + from .asr import WavToScp + from .audio import LinearAECAndFbank + from .image import (LoadImage, load_image, + ImageColorEnhanceFinetunePreprocessor, + ImageInstanceSegmentationPreprocessor, + ImageDenoisePreprocessor) + from .kws import WavToLists + from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) + from .nlp import ( + DocumentSegmentationPreprocessor, FaqQuestionAnsweringPreprocessor, + FillMaskPoNetPreprocessor, NLPPreprocessor, + NLPTokenizerPreprocessorBase, PassageRankingPreprocessor, + TextRankingPreprocessor, RelationExtractionPreprocessor, + SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, + TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, + TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, + WordSegmentationBlankSetToLabelPreprocessor, + MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, + TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, + DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, + DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, + TableQuestionAnsweringPreprocessor, NERPreprocessorViet, + NERPreprocessorThai, WordSegmentationPreprocessorThai) + from .video import ReadVideoData, MovieSceneSegmentationPreprocessor + +else: + _import_structure = { + 'base': ['Preprocessor'], + 'builder': ['PREPROCESSORS', 'build_preprocessor'], + 'common': ['Compose', 'ToTensor', 'Filter'], + 'audio': ['LinearAECAndFbank'], + 'asr': ['WavToScp'], + 'video': ['ReadVideoData', 'MovieSceneSegmentationPreprocessor'], + 'image': [ + 'LoadImage', 'load_image', 'ImageColorEnhanceFinetunePreprocessor', + 'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor' + ], + 'kws': ['WavToLists'], + 'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'], + 'nlp': [ + 'DocumentSegmentationPreprocessor', + 'FaqQuestionAnsweringPreprocessor', 'FillMaskPoNetPreprocessor', + 'NLPPreprocessor', 'NLPTokenizerPreprocessorBase', + 'TextRankingPreprocessor', 'RelationExtractionPreprocessor', + 'SentenceEmbeddingPreprocessor', + 'SequenceClassificationPreprocessor', + 'TokenClassificationPreprocessor', + 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', + 'Tokenize', 'Text2TextGenerationPreprocessor', + 'WordSegmentationBlankSetToLabelPreprocessor', + 'MGLMSummarizationPreprocessor', + 'ZeroShotClassificationPreprocessor', + 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', + 'NERPreprocessorViet', 'NERPreprocessorThai', + 'WordSegmentationPreprocessorThai', + 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', + 'DialogStateTrackingPreprocessor', + 'ConversationalTextToSqlPreprocessor', + 'TableQuestionAnsweringPreprocessor' + ], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py new file mode 100644 index 00000000..91bf5860 --- /dev/null +++ b/modelscope/preprocessors/asr.py @@ -0,0 +1,264 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict, List, Union + +from modelscope.metainfo import Preprocessors +from modelscope.models.base import Model +from modelscope.utils.constant import Fields, Frameworks +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['WavToScp'] + + +@PREPROCESSORS.register_module( + Fields.audio, module_name=Preprocessors.wav_to_scp) +class WavToScp(Preprocessor): + """generate audio scp from wave or ark + """ + + def __init__(self): + pass + + def __call__(self, + model: Model = None, + recog_type: str = None, + audio_format: str = None, + audio_in: Union[str, bytes] = None, + audio_fs: int = None) -> Dict[str, Any]: + assert model is not None, 'preprocess model is empty' + assert recog_type is not None and len( + recog_type) > 0, 'preprocess recog_type is empty' + assert audio_format is not None, 'preprocess audio_format is empty' + assert audio_in is not None, 'preprocess audio_in is empty' + + self.am_model = model + out = self.forward(self.am_model.forward(), recog_type, audio_format, + audio_in, audio_fs) + return out + + def forward(self, model: Dict[str, + Any], recog_type: str, audio_format: str, + audio_in: Union[str, bytes], audio_fs: int) -> Dict[str, Any]: + assert len(recog_type) > 0, 'preprocess recog_type is empty' + assert len(audio_format) > 0, 'preprocess audio_format is empty' + assert len( + model['am_model']) > 0, 'preprocess model[am_model] is empty' + assert len(model['am_model_path'] + ) > 0, 'preprocess model[am_model_path] is empty' + assert os.path.exists( + model['am_model_path']), 'preprocess am_model_path does not exist' + assert len(model['model_workspace'] + ) > 0, 'preprocess model[model_workspace] is empty' + assert os.path.exists(model['model_workspace'] + ), 'preprocess model_workspace does not exist' + assert len(model['model_config'] + ) > 0, 'preprocess model[model_config] is empty' + + rst = { + # the recognition model dir path + 'model_workspace': model['model_workspace'], + # the am model name + 'am_model': model['am_model'], + # the am model file path + 'am_model_path': model['am_model_path'], + # the asr type setting, eg: test dev train wav + 'recog_type': recog_type, + # the asr audio format setting, eg: wav, pcm, kaldi_ark, tfrecord + 'audio_format': audio_format, + # the recognition model config dict + 'model_config': model['model_config'], + # the sample rate of audio_in + 'audio_fs': audio_fs + } + + if isinstance(audio_in, str): + # wav file path or the dataset path + rst['wav_path'] = audio_in + + out = self.config_checking(rst) + out = self.env_setting(out) + if audio_format == 'wav': + out['audio_lists'] = self.scp_generation_from_wav(out) + elif audio_format == 'kaldi_ark': + out['audio_lists'] = self.scp_generation_from_ark(out) + elif audio_format == 'tfrecord': + out['audio_lists'] = os.path.join(out['wav_path'], 'data.records') + elif audio_format == 'pcm': + out['audio_lists'] = audio_in + + return out + + def config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """config checking + """ + + assert inputs['model_config'].__contains__( + 'type'), 'model type does not exist' + inputs['model_type'] = inputs['model_config']['type'] + + if inputs['model_type'] == Frameworks.torch: + assert inputs['model_config'].__contains__( + 'batch_size'), 'batch_size does not exist' + assert inputs['model_config'].__contains__( + 'am_model_config'), 'am_model_config does not exist' + assert inputs['model_config'].__contains__( + 'asr_model_config'), 'asr_model_config does not exist' + assert inputs['model_config'].__contains__( + 'asr_model_wav_config'), 'asr_model_wav_config does not exist' + + am_model_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['am_model_config']) + assert os.path.exists( + am_model_config), 'am_model_config does not exist' + inputs['am_model_config'] = am_model_config + + asr_model_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['asr_model_config']) + assert os.path.exists( + asr_model_config), 'asr_model_config does not exist' + + asr_model_wav_config: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['asr_model_wav_config']) + assert os.path.exists( + asr_model_wav_config), 'asr_model_wav_config does not exist' + + if inputs['audio_format'] == 'wav' or inputs[ + 'audio_format'] == 'pcm': + inputs['asr_model_config'] = asr_model_wav_config + else: + inputs['asr_model_config'] = asr_model_config + + if inputs['model_config'].__contains__('mvn_file'): + mvn_file = os.path.join(inputs['model_workspace'], + inputs['model_config']['mvn_file']) + assert os.path.exists(mvn_file), 'mvn_file does not exist' + inputs['mvn_file'] = mvn_file + + elif inputs['model_type'] == Frameworks.tf: + assert inputs['model_config'].__contains__( + 'vocab_file'), 'vocab_file does not exist' + vocab_file: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['vocab_file']) + assert os.path.exists(vocab_file), 'vocab file does not exist' + inputs['vocab_file'] = vocab_file + + assert inputs['model_config'].__contains__( + 'am_mvn_file'), 'am_mvn_file does not exist' + am_mvn_file: str = os.path.join( + inputs['model_workspace'], + inputs['model_config']['am_mvn_file']) + assert os.path.exists(am_mvn_file), 'am mvn file does not exist' + inputs['am_mvn_file'] = am_mvn_file + + else: + raise ValueError('model type is mismatched') + + return inputs + + def env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + # run with datasets, should set datasets_path and text_path + if inputs['recog_type'] != 'wav': + inputs['datasets_path'] = inputs['wav_path'] + + # run with datasets, and audio format is waveform + if inputs['audio_format'] == 'wav': + inputs['wav_path'] = os.path.join(inputs['datasets_path'], + 'wav', inputs['recog_type']) + inputs['reference_text'] = os.path.join( + inputs['datasets_path'], 'transcript', 'data.text') + assert os.path.exists( + inputs['reference_text']), 'reference text does not exist' + + # run with datasets, and audio format is kaldi_ark + elif inputs['audio_format'] == 'kaldi_ark': + inputs['wav_path'] = os.path.join(inputs['datasets_path'], + inputs['recog_type']) + inputs['reference_text'] = os.path.join( + inputs['wav_path'], 'data.text') + assert os.path.exists( + inputs['reference_text']), 'reference text does not exist' + + # run with datasets, and audio format is tfrecord + elif inputs['audio_format'] == 'tfrecord': + inputs['wav_path'] = os.path.join(inputs['datasets_path'], + inputs['recog_type']) + inputs['reference_text'] = os.path.join( + inputs['wav_path'], 'data.txt') + assert os.path.exists( + inputs['reference_text']), 'reference text does not exist' + inputs['idx_text'] = os.path.join(inputs['wav_path'], + 'data.idx') + assert os.path.exists( + inputs['idx_text']), 'idx text does not exist' + + # set asr model language + if 'lang' in inputs['model_config']: + inputs['model_lang'] = inputs['model_config']['lang'] + else: + inputs['model_lang'] = 'zh-cn' + + return inputs + + def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]: + """scp generation from waveform files + """ + from easyasr.common import asr_utils + + # find all waveform files + wav_list = [] + if inputs['recog_type'] == 'wav': + file_path = inputs['wav_path'] + if os.path.isfile(file_path): + if file_path.endswith('.wav') or file_path.endswith('.WAV'): + wav_list.append(file_path) + else: + wav_dir: str = inputs['wav_path'] + wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['wav_count'] = list_count + + # store all wav into audio list + audio_lists = [] + j: int = 0 + while j < list_count: + wav_file = wav_list[j] + wave_key: str = os.path.splitext(os.path.basename(wav_file))[0] + item = {'key': wave_key, 'file': wav_file} + audio_lists.append(item) + j += 1 + + return audio_lists + + def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]: + """scp generation from kaldi ark file + """ + + ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp') + ark_file_path = os.path.join(inputs['wav_path'], 'data.ark') + assert os.path.exists(ark_scp_path), 'data.scp does not exist' + assert os.path.exists(ark_file_path), 'data.ark does not exist' + + with open(ark_scp_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # store all ark item into audio list + audio_lists = [] + for line in lines: + outs = line.strip().split(' ') + if len(outs) == 2: + key = outs[0] + sub = outs[1].split(':') + if len(sub) == 2: + nums = sub[1] + content = ark_file_path + ':' + nums + item = {'key': key, 'file': content} + audio_lists.append(item) + + return audio_lists diff --git a/modelscope/preprocessors/audio.py b/modelscope/preprocessors/audio.py new file mode 100644 index 00000000..1e659218 --- /dev/null +++ b/modelscope/preprocessors/audio.py @@ -0,0 +1,217 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +import os +from typing import Any, Dict, Tuple, Union + +import numpy as np +import scipy.io.wavfile as wav +import torch + +from modelscope.fileio import File +from modelscope.preprocessors import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + + +def load_kaldi_feature_transform(filename): + fp = open(filename, 'r') + all_str = fp.read() + pos1 = all_str.find('AddShift') + pos2 = all_str.find('[', pos1) + pos3 = all_str.find(']', pos2) + mean = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ') + pos1 = all_str.find('Rescale') + pos2 = all_str.find('[', pos1) + pos3 = all_str.find(']', pos2) + scale = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ') + fp.close() + return mean, scale + + +class Feature: + r"""Extract feat from one utterance. + """ + + def __init__(self, + fbank_config, + feat_type='spec', + mvn_file=None, + cuda=False): + r""" + + Args: + fbank_config (dict): + feat_type (str): + raw: do nothing + fbank: use kaldi.fbank + spec: Real/Imag + logpow: log(1+|x|^2) + mvn_file (str): the path of data file for mean variance normalization + cuda: + """ + self.fbank_config = fbank_config + self.feat_type = feat_type + self.n_fft = fbank_config['frame_length'] * fbank_config[ + 'sample_frequency'] // 1000 + self.hop_length = fbank_config['frame_shift'] * fbank_config[ + 'sample_frequency'] // 1000 + self.window = torch.hamming_window(self.n_fft, periodic=False) + + self.mvn = False + if mvn_file is not None and os.path.exists(mvn_file): + print(f'loading mvn file: {mvn_file}') + shift, scale = load_kaldi_feature_transform(mvn_file) + self.shift = torch.from_numpy(shift) + self.scale = torch.from_numpy(scale) + self.mvn = True + if cuda: + self.window = self.window.cuda() + if self.mvn: + self.shift = self.shift.cuda() + self.scale = self.scale.cuda() + + def compute(self, utt): + r""" + + Args: + utt: in [-32768, 32767] range + + Returns: + [..., T, F] + """ + if self.feat_type == 'raw': + return utt + elif self.feat_type == 'fbank': + # have to use local import before modelscope framework supoort lazy loading + import torchaudio.compliance.kaldi as kaldi + if len(utt.shape) == 1: + utt = utt.unsqueeze(0) + feat = kaldi.fbank(utt, **self.fbank_config) + elif self.feat_type == 'spec': + spec = torch.stft( + utt / 32768, + self.n_fft, + self.hop_length, + self.n_fft, + self.window, + center=False, + return_complex=True) + feat = torch.cat([spec.real, spec.imag], dim=-2).permute(-1, -2) + elif self.feat_type == 'logpow': + spec = torch.stft( + utt, + self.n_fft, + self.hop_length, + self.n_fft, + self.window, + center=False, + return_complex=True) + abspow = torch.abs(spec)**2 + feat = torch.log(1 + abspow).permute(-1, -2) + return feat + + def normalize(self, feat): + if self.mvn: + feat = feat + self.shift + feat = feat * self.scale + return feat + + +@PREPROCESSORS.register_module(Fields.audio) +class LinearAECAndFbank(Preprocessor): + SAMPLE_RATE = 16000 + + def __init__(self, io_config): + import MinDAEC + self.trunc_length = 7200 * self.SAMPLE_RATE + self.linear_aec_delay = io_config['linear_aec_delay'] + self.feature = Feature(io_config['fbank_config'], + io_config['feat_type'], io_config['mvn']) + self.mitaec = MinDAEC.load() + self.mask_on_mic = io_config['mask_on'] == 'nearend_mic' + + def __call__(self, data: Union[Tuple, Dict[str, Any]]) -> Dict[str, Any]: + """ Linear filtering the near end mic and far end audio, then extract the feature. + + Args: + data: Dict with two keys and correspond audios: "nearend_mic" and "farend_speech". + + Returns: + Dict with two keys and Tensor values: "base" linear filtered audio,and "feature" + """ + if isinstance(data, tuple): + nearend_mic, fs = self.load_wav(data[0]) + farend_speech, fs = self.load_wav(data[1]) + nearend_speech = np.zeros_like(nearend_mic) + else: + # read files + nearend_mic, fs = self.load_wav(data['nearend_mic']) + farend_speech, fs = self.load_wav(data['farend_speech']) + if 'nearend_speech' in data: + nearend_speech, fs = self.load_wav(data['nearend_speech']) + else: + nearend_speech = np.zeros_like(nearend_mic) + + out_mic, out_ref, out_linear, out_echo = self.mitaec.do_linear_aec( + nearend_mic, farend_speech) + # fix 20ms linear aec delay by delaying the target speech + extra_zeros = np.zeros([int(self.linear_aec_delay * fs)]) + nearend_speech = np.concatenate([extra_zeros, nearend_speech]) + # truncate files to the same length + flen = min( + len(out_mic), len(out_ref), len(out_linear), len(out_echo), + len(nearend_speech)) + fstart = 0 + flen = min(flen, self.trunc_length) + nearend_mic, out_ref, out_linear, out_echo, nearend_speech = ( + out_mic[fstart:flen], out_ref[fstart:flen], + out_linear[fstart:flen], out_echo[fstart:flen], + nearend_speech[fstart:flen]) + + # extract features (frames, [mic, linear, ref, aes?]) + feat = torch.FloatTensor() + + nearend_mic = torch.from_numpy(np.float32(nearend_mic)) + fbank_nearend_mic = self.feature.compute(nearend_mic) + feat = torch.cat([feat, fbank_nearend_mic], dim=1) + + out_linear = torch.from_numpy(np.float32(out_linear)) + fbank_out_linear = self.feature.compute(out_linear) + feat = torch.cat([feat, fbank_out_linear], dim=1) + + out_echo = torch.from_numpy(np.float32(out_echo)) + fbank_out_echo = self.feature.compute(out_echo) + feat = torch.cat([feat, fbank_out_echo], dim=1) + + # feature transform + feat = self.feature.normalize(feat) + + # prepare target + if nearend_speech is not None: + nearend_speech = torch.from_numpy(np.float32(nearend_speech)) + + if self.mask_on_mic: + base = nearend_mic + else: + base = out_linear + out_data = {'base': base, 'target': nearend_speech, 'feature': feat} + return out_data + + @staticmethod + def load_wav(inputs): + import librosa + if isinstance(inputs, bytes): + inputs = io.BytesIO(inputs) + elif isinstance(inputs, str): + file_bytes = File.read(inputs) + inputs = io.BytesIO(file_bytes) + else: + raise TypeError(f'Unsupported input type: {type(inputs)}.') + sample_rate, data = wav.read(inputs) + if len(data.shape) > 1: + raise ValueError('modelscope error:The audio must be mono.') + if sample_rate != LinearAECAndFbank.SAMPLE_RATE: + data = librosa.resample(data, sample_rate, + LinearAECAndFbank.SAMPLE_RATE) + return data.astype(np.float32), LinearAECAndFbank.SAMPLE_RATE diff --git a/modelscope/preprocessors/base.py b/modelscope/preprocessors/base.py new file mode 100644 index 00000000..38500561 --- /dev/null +++ b/modelscope/preprocessors/base.py @@ -0,0 +1,264 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Dict, Optional, Sequence + +from modelscope.metainfo import Models, Preprocessors +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks +from modelscope.utils.hub import read_config, snapshot_download +from modelscope.utils.logger import get_logger +from .builder import build_preprocessor + +logger = get_logger(__name__) + +PREPROCESSOR_MAP = { + # nlp + # bart + (Models.bart, Tasks.text_error_correction): + Preprocessors.text_error_correction, + + # bert + (Models.bert, Tasks.backbone): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.document_segmentation): + Preprocessors.document_segmentation, + (Models.bert, Tasks.fill_mask): + Preprocessors.fill_mask, + (Models.bert, Tasks.sentence_embedding): + Preprocessors.sentence_embedding, + (Models.bert, Tasks.text_classification): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.nli): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.sentiment_classification): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.sentence_similarity): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.zero_shot_classification): + Preprocessors.sen_cls_tokenizer, + (Models.bert, Tasks.text_ranking): + Preprocessors.text_ranking, + (Models.bert, Tasks.part_of_speech): + Preprocessors.token_cls_tokenizer, + (Models.bert, Tasks.token_classification): + Preprocessors.token_cls_tokenizer, + (Models.bert, Tasks.word_segmentation): + Preprocessors.token_cls_tokenizer, + + # bloom + (Models.bloom, Tasks.backbone): + Preprocessors.text_gen_tokenizer, + + # gpt_neo + # gpt_neo may have different preprocessors, but now only one + (Models.gpt_neo, Tasks.backbone): + Preprocessors.sentence_piece, + + # gpt3 has different preprocessors by different sizes of models, so they are not listed here. + + # palm_v2 + (Models.palm, Tasks.backbone): + Preprocessors.text_gen_tokenizer, + + # T5 + (Models.T5, Tasks.backbone): + Preprocessors.text2text_gen_preprocessor, + (Models.T5, Tasks.text2text_generation): + Preprocessors.text2text_gen_preprocessor, + + # deberta_v2 + (Models.deberta_v2, Tasks.backbone): + Preprocessors.sen_cls_tokenizer, + (Models.deberta_v2, Tasks.fill_mask): + Preprocessors.fill_mask, + + # ponet + (Models.ponet, Tasks.fill_mask): + Preprocessors.fill_mask_ponet, + + # structbert + (Models.structbert, Tasks.backbone): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.fill_mask): + Preprocessors.fill_mask, + (Models.structbert, Tasks.faq_question_answering): + Preprocessors.faq_question_answering_preprocessor, + (Models.structbert, Tasks.text_classification): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.nli): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.sentiment_classification): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.sentence_similarity): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.zero_shot_classification): + Preprocessors.sen_cls_tokenizer, + (Models.structbert, Tasks.part_of_speech): + Preprocessors.token_cls_tokenizer, + (Models.structbert, Tasks.token_classification): + Preprocessors.token_cls_tokenizer, + (Models.structbert, Tasks.word_segmentation): + Preprocessors.token_cls_tokenizer, + + # veco + (Models.veco, Tasks.backbone): + Preprocessors.sen_cls_tokenizer, + (Models.veco, Tasks.fill_mask): + Preprocessors.fill_mask, + (Models.veco, Tasks.text_classification): + Preprocessors.sen_cls_tokenizer, + (Models.veco, Tasks.nli): + Preprocessors.sen_cls_tokenizer, + (Models.veco, Tasks.sentiment_classification): + Preprocessors.sen_cls_tokenizer, + (Models.veco, Tasks.sentence_similarity): + Preprocessors.sen_cls_tokenizer, + + # space +} + + +class Preprocessor(ABC): + + def __init__(self, mode=ModeKeys.INFERENCE, *args, **kwargs): + self._mode = mode + self.device = int( + os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None + pass + + @abstractmethod + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + pass + + @property + def mode(self): + return self._mode + + @mode.setter + def mode(self, value): + self._mode = value + + @classmethod + def from_pretrained(cls, + model_name_or_path: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + cfg_dict: Config = None, + preprocessor_mode=ModeKeys.INFERENCE, + **kwargs): + """Instantiate a preprocessor from local directory or remote model repo. Note + that when loading from remote, the model revision can be specified. + + Args: + model_name_or_path(str): A model dir or a model id used to load the preprocessor out. + revision(str, `optional`): The revision used when the model_name_or_path is + a model id of the remote hub. default `master`. + cfg_dict(Config, `optional`): An optional config. If provided, it will replace + the config read out of the `model_name_or_path` + preprocessor_mode(str, `optional`): Specify the working mode of the preprocessor, can be `train`, `eval`, + or `inference`. Default value `inference`. + The preprocessor field in the config may contain two sub preprocessors: + >>> { + >>> "train": { + >>> "type": "some-train-preprocessor" + >>> }, + >>> "val": { + >>> "type": "some-eval-preprocessor" + >>> } + >>> } + In this scenario, the `train` preprocessor will be loaded in the `train` mode, the `val` preprocessor + will be loaded in the `eval` or `inference` mode. The `mode` field in the preprocessor class + will be assigned in all the modes. + Or just one: + >>> { + >>> "type": "some-train-preprocessor" + >>> } + In this scenario, the sole preprocessor will be loaded in all the modes, + and the `mode` field in the preprocessor class will be assigned. + + **kwargs: + task(str, `optional`): The `Tasks` enumeration value to replace the task value + read out of config in the `model_name_or_path`. + This is useful when the preprocessor does not have a `type` field and the task to be used is not + equal to the task of which the model is saved. + Other kwargs will be directly fed into the preprocessor, to replace the default configs. + + Returns: + The preprocessor instance. + + Examples: + >>> from modelscope.preprocessors import Preprocessor + >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base') + + """ + if not os.path.exists(model_name_or_path): + model_dir = snapshot_download( + model_name_or_path, revision=revision) + else: + model_dir = model_name_or_path + if cfg_dict is None: + cfg = read_config(model_dir) + else: + cfg = cfg_dict + task = cfg.task + if 'task' in kwargs: + task = kwargs.pop('task') + field_name = Tasks.find_field_by_task(task) + if 'field' in kwargs: + field_name = kwargs.pop('field') + sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val' + + if not hasattr(cfg, 'preprocessor') or len(cfg.preprocessor) == 0: + logger.warn('No preprocessor field found in cfg.') + preprocessor_cfg = ConfigDict() + else: + preprocessor_cfg = cfg.preprocessor + + if 'type' not in preprocessor_cfg: + if sub_key in preprocessor_cfg: + sub_cfg = getattr(preprocessor_cfg, sub_key) + else: + logger.warn(f'No {sub_key} key and type key found in ' + f'preprocessor domain of configuration.json file.') + sub_cfg = preprocessor_cfg + else: + sub_cfg = preprocessor_cfg + + sub_cfg.update({'model_dir': model_dir}) + sub_cfg.update(kwargs) + if 'type' in sub_cfg: + if isinstance(sub_cfg, Sequence): + # TODO: for Sequence, need adapt to `mode` and `mode_dir` args, + # and add mode for Compose or other plans + raise NotImplementedError('Not supported yet!') + sub_cfg = deepcopy(sub_cfg) + + preprocessor = build_preprocessor(sub_cfg, field_name) + else: + logger.warn( + f'Cannot find available config to build preprocessor at mode {preprocessor_mode}, ' + f'current config: {sub_cfg}. trying to build by task and model information.' + ) + model_cfg = getattr(cfg, 'model', ConfigDict()) + model_type = model_cfg.type if hasattr( + model_cfg, 'type') else getattr(model_cfg, 'model_type', None) + if task is None or model_type is None: + logger.warn( + f'Find task: {task}, model type: {model_type}. ' + f'Insufficient information to build preprocessor, skip building preprocessor' + ) + return None + if (model_type, task) not in PREPROCESSOR_MAP: + logger.warn( + f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, ' + f'skip building preprocessor.') + return None + + sub_cfg = ConfigDict({ + 'type': PREPROCESSOR_MAP[(model_type, task)], + **sub_cfg + }) + preprocessor = build_preprocessor(sub_cfg, field_name) + preprocessor.mode = preprocessor_mode + return preprocessor diff --git a/modelscope/preprocessors/builder.py b/modelscope/preprocessors/builder.py new file mode 100644 index 00000000..918f8d17 --- /dev/null +++ b/modelscope/preprocessors/builder.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Fields +from modelscope.utils.registry import Registry, build_from_cfg + +PREPROCESSORS = Registry('preprocessors') + + +def build_preprocessor(cfg: ConfigDict, + field_name: str = None, + default_args: dict = None): + """ build preprocessor given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + field_name (str, optional): application field name, refer to + :obj:`Fields` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, PREPROCESSORS, group_key=field_name, default_args=default_args) diff --git a/modelscope/preprocessors/common.py b/modelscope/preprocessors/common.py new file mode 100644 index 00000000..aa1db84c --- /dev/null +++ b/modelscope/preprocessors/common.py @@ -0,0 +1,143 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import time +from collections.abc import Sequence +from typing import Mapping + +import numpy as np +import torch + +from .builder import PREPROCESSORS, build_preprocessor + + +@PREPROCESSORS.register_module() +class Compose(object): + """Compose a data pipeline with a sequence of transforms. + Args: + transforms (list[dict | callable]): + Either config dicts of transforms or transform objects. + profiling (bool, optional): If set True, will profile and + print preprocess time for each step. + """ + + def __init__(self, transforms, field_name=None, profiling=False): + assert isinstance(transforms, Sequence) + self.profiling = profiling + self.transforms = [] + self.field_name = field_name + for transform in transforms: + if isinstance(transform, dict): + if self.field_name is None: + transform = build_preprocessor(transform, field_name) + else: + # if not found key in field_name, try field_name=None(default_group) + try: + transform = build_preprocessor(transform, field_name) + except KeyError: + transform = build_preprocessor(transform, None) + elif callable(transform): + pass + else: + raise TypeError('transform must be callable or a dict, but got' + f' {type(transform)}') + self.transforms.append(transform) + + def __call__(self, data): + for t in self.transforms: + if self.profiling: + start = time.time() + + data = t(data) + + if self.profiling: + print(f'{t} time {time.time()-start}') + + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += f'\n {t}' + format_string += '\n)' + return format_string + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not isinstance(data, str): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PREPROCESSORS.register_module() +class ToTensor(object): + """Convert target object to tensor. + + Args: + keys (Sequence[str]): Key of data to be converted to Tensor. + Only valid when data is type of `Mapping`. If `keys` is None, + all values of keys ​​will be converted to tensor by default. + """ + + def __init__(self, keys=None): + self.keys = keys + + def __call__(self, data): + if isinstance(data, Mapping): + if self.keys is None: + self.keys = list(data.keys()) + + for key in self.keys: + data[key] = to_tensor(data[key]) + else: + data = to_tensor(data) + + return data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PREPROCESSORS.register_module() +class Filter(object): + """This is usually the last stage of the dataloader transform. + Only data of reserved keys will be kept and passed directly to the model, others will be removed. + + Args: + keys (Sequence[str]): Keys of data to be reserved, others will be removed. + """ + + def __init__(self, reserved_keys): + self.reserved_keys = reserved_keys + + def __call__(self, data): + assert isinstance(data, Mapping) + + reserved_data = {} + for key in self.reserved_keys: + reserved_data[key] = data[key] + + return reserved_data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.reserved_keys})' diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py new file mode 100644 index 00000000..60f6e0eb --- /dev/null +++ b/modelscope/preprocessors/image.py @@ -0,0 +1,291 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import io +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +from numpy import ndarray +from PIL import Image, ImageOps + +from modelscope.fileio import File +from modelscope.metainfo import Preprocessors +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from .base import Preprocessor +from .builder import PREPROCESSORS + + +@PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image) +class LoadImage: + """Load an image from file or url. + Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + Args: + mode (str): See :ref:`PIL.Mode`. + """ + + def __init__(self, mode='rgb'): + self.mode = mode.upper() + + def __call__(self, input: Union[str, Dict[str, str]]): + """Call functions to load image and get image meta information. + Args: + input (str or dict): input image path or input dict with + a key `filename`. + Returns: + dict: The dict contains loaded image. + """ + if isinstance(input, dict): + image_path_or_url = input['filename'] + else: + image_path_or_url = input + + bytes = File.read(image_path_or_url) + # TODO @wenmeng.zwm add opencv decode as optional + # we should also look at the input format which is the most commonly + # used in Mind' image related models + with io.BytesIO(bytes) as infile: + img = Image.open(infile) + img = ImageOps.exif_transpose(img) + img = img.convert(self.mode) + + results = { + 'filename': image_path_or_url, + 'img': img, + 'img_shape': (img.size[1], img.size[0], 3), + 'img_field': 'img', + } + return results + + def __repr__(self): + repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})' + return repr_str + + @staticmethod + def convert_to_ndarray(input) -> ndarray: + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + return img + + @staticmethod + def convert_to_img(input) -> ndarray: + if isinstance(input, str): + img = load_image(input) + elif isinstance(input, PIL.Image.Image): + img = input.convert('RGB') + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] + img = Image.fromarray(img.astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + return img + + +def load_image(image_path_or_url: str) -> Image.Image: + """ simple interface to load an image from file or url + + Args: + image_path_or_url (str): image file path or http url + """ + loader = LoadImage() + return loader(image_path_or_url)['img'] + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor) +class ImageColorEnhanceFinetunePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + @type_assert(object, object) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (tuple): [sentence1, sentence2] + sentence1 (str): a sentence + Example: + 'you are so handsome.' + sentence2 (str): a sentence + Example: + 'you are so beautiful.' + Returns: + Dict[str, Any]: the preprocessed data + """ + + return data + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.image_denoie_preprocessor) +class ImageDenoisePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """ + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + from .common import Filter + + # TODO: `Filter` should be moved to configurarion file of each model + self._transforms = [Filter(reserved_keys=['input', 'target'])] + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict[str, Any] + + Returns: + Dict[str, Any]: the preprocessed data + """ + for t in self._transforms: + data = t(data) + + return data + + +@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_portrait_enhancement_preprocessor) +class ImagePortraitEnhancementPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """ + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict[str, Any] + + Returns: + Dict[str, Any]: the preprocessed data + """ + return data + + +@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_instance_segmentation_preprocessor) +class ImageInstanceSegmentationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """image instance segmentation preprocessor in the fine-tune scenario + """ + + super().__init__(*args, **kwargs) + + self.training = kwargs.pop('training', True) + self.preprocessor_train_cfg = kwargs.pop('train', None) + self.preprocessor_test_cfg = kwargs.pop('val', None) + + self.train_transforms = [] + self.test_transforms = [] + + from modelscope.models.cv.image_instance_segmentation.datasets import \ + build_preprocess_transform + + if self.preprocessor_train_cfg is not None: + if isinstance(self.preprocessor_train_cfg, dict): + self.preprocessor_train_cfg = [self.preprocessor_train_cfg] + for cfg in self.preprocessor_train_cfg: + transform = build_preprocess_transform(cfg) + self.train_transforms.append(transform) + + if self.preprocessor_test_cfg is not None: + if isinstance(self.preprocessor_test_cfg, dict): + self.preprocessor_test_cfg = [self.preprocessor_test_cfg] + for cfg in self.preprocessor_test_cfg: + transform = build_preprocess_transform(cfg) + self.test_transforms.append(transform) + + def train(self): + self.training = True + return + + def eval(self): + self.training = False + return + + @type_assert(object, object) + def __call__(self, results: Dict[str, Any]): + """process the raw input data + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + Dict[str, Any] | None: the preprocessed data + """ + + if self.training: + transforms = self.train_transforms + else: + transforms = self.test_transforms + + for t in transforms: + + results = t(results) + + if results is None: + return None + + return results + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.video_summarization_preprocessor) +class VideoSummarizationPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """ + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict[str, Any] + + Returns: + Dict[str, Any]: the preprocessed data + """ + return data diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py new file mode 100644 index 00000000..6f09d545 --- /dev/null +++ b/modelscope/preprocessors/kws.py @@ -0,0 +1,143 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict, List, Union + +import yaml + +from modelscope.metainfo import Preprocessors +from modelscope.models.base import Model +from modelscope.utils.constant import Fields +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['WavToLists'] + + +@PREPROCESSORS.register_module( + Fields.audio, module_name=Preprocessors.wav_to_lists) +class WavToLists(Preprocessor): + """generate audio lists file from wav + """ + + def __init__(self): + pass + + def __call__(self, model: Model, audio_in: Union[List[str], str, + bytes]) -> Dict[str, Any]: + """Call functions to load model and wav. + + Args: + model (Model): model should be provided + audio_in (Union[List[str], str, bytes]): + audio_in[0] is positive wav path, audio_in[1] is negative wav path; + audio_in (str) is positive wav path; + audio_in (bytes) is audio pcm data; + Returns: + Dict[str, Any]: the kws result + """ + + self.model = model + out = self.forward(self.model.forward(), audio_in) + return out + + def forward(self, model: Dict[str, Any], + audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]: + assert len( + model['config_path']) > 0, 'preprocess model[config_path] is empty' + assert os.path.exists( + model['config_path']), 'model config.yaml is absent' + + inputs = model.copy() + + import kws_util.common + kws_type = kws_util.common.type_checking(audio_in) + assert kws_type in [ + 'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc' + ], f'kws_type {kws_type} is invalid, please check audio data' + + inputs['kws_type'] = kws_type + if kws_type == 'wav': + inputs['pos_wav_path'] = audio_in + elif kws_type == 'pcm': + inputs['pos_data'] = audio_in + if kws_type in ['pos_testsets', 'roc']: + inputs['pos_wav_path'] = audio_in[0] + if kws_type in ['neg_testsets', 'roc']: + inputs['neg_wav_path'] = audio_in[1] + + out = self.read_config(inputs) + out = self.generate_wav_lists(out) + + return out + + def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """read and parse config.yaml to get all model files + """ + + assert os.path.exists( + inputs['config_path']), 'model config yaml file does not exist' + + config_file = open(inputs['config_path']) + root = yaml.full_load(config_file) + config_file.close() + + inputs['cfg_file'] = root['cfg_file'] + inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'], + root['cfg_file']) + inputs['keyword_grammar'] = root['keyword_grammar'] + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], root['keyword_grammar']) + inputs['sample_rate'] = root['sample_rate'] + + return inputs + + def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """assemble wav lists + """ + import kws_util.common + + if inputs['kws_type'] == 'wav': + wav_list = [] + wave_scp_content: str = inputs['pos_wav_path'] + wav_list.append(wave_scp_content) + inputs['pos_wav_list'] = wav_list + inputs['pos_wav_count'] = 1 + inputs['pos_num_thread'] = 1 + + if inputs['kws_type'] == 'pcm': + inputs['pos_wav_list'] = ['pcm_data'] + inputs['pos_wav_count'] = 1 + inputs['pos_num_thread'] = 1 + + if inputs['kws_type'] in ['pos_testsets', 'roc']: + # find all positive wave + wav_list = [] + wav_dir = inputs['pos_wav_path'] + wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) + inputs['pos_wav_list'] = wav_list + + list_count: int = len(wav_list) + inputs['pos_wav_count'] = list_count + + if list_count <= 128: + inputs['pos_num_thread'] = list_count + else: + inputs['pos_num_thread'] = 128 + + if inputs['kws_type'] in ['neg_testsets', 'roc']: + # find all negative wave + wav_list = [] + wav_dir = inputs['neg_wav_path'] + wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir) + inputs['neg_wav_list'] = wav_list + + list_count: int = len(wav_list) + inputs['neg_wav_count'] = list_count + + if list_count <= 128: + inputs['neg_num_thread'] = list_count + else: + inputs['neg_num_thread'] = 128 + + return inputs diff --git a/modelscope/preprocessors/movie_scene_segmentation/__init__.py b/modelscope/preprocessors/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..b28ccabc --- /dev/null +++ b/modelscope/preprocessors/movie_scene_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .transforms import get_transform +else: + _import_structure = { + 'transforms': ['get_transform'], + } + + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/movie_scene_segmentation/transforms.py b/modelscope/preprocessors/movie_scene_segmentation/transforms.py new file mode 100644 index 00000000..5b84003c --- /dev/null +++ b/modelscope/preprocessors/movie_scene_segmentation/transforms.py @@ -0,0 +1,308 @@ +# The implementation here is modified based on BaSSL, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/kakaobrain/bassl +import numbers +import os.path as osp +import random +from typing import List + +import numpy as np +import torch +import torchvision.transforms as TF +import torchvision.transforms.functional as F +from PIL import Image, ImageFilter + + +def get_transform(lst): + assert len(lst) > 0 + transform_lst = [] + for item in lst: + transform_lst.append(build_transform(item)) + transform = TF.Compose(transform_lst) + return transform + + +def build_transform(cfg): + assert isinstance(cfg, dict) + cfg = cfg.copy() + type = cfg.pop('type') + + if type == 'VideoResizedCenterCrop': + return VideoResizedCenterCrop(**cfg) + elif type == 'VideoToTensor': + return VideoToTensor(**cfg) + elif type == 'VideoRandomResizedCrop': + return VideoRandomResizedCrop(**cfg) + elif type == 'VideoRandomHFlip': + return VideoRandomHFlip() + elif type == 'VideoRandomColorJitter': + return VideoRandomColorJitter(**cfg) + elif type == 'VideoRandomGaussianBlur': + return VideoRandomGaussianBlur(**cfg) + else: + raise NotImplementedError + + +class VideoResizedCenterCrop(torch.nn.Module): + + def __init__(self, image_size, crop_size): + self.tfm = TF.Compose([ + TF.Resize(size=image_size, interpolation=Image.BICUBIC), + TF.CenterCrop(crop_size), + ]) + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + return [self.tfm(img) for img in imgmap] + + +class VideoToTensor(torch.nn.Module): + + def __init__(self, mean=None, std=None, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + assert self.mean is not None + assert self.std is not None + + def __to_tensor__(self, img): + return F.to_tensor(img) + + def __normalize__(self, img): + return F.normalize(img, self.mean, self.std, self.inplace) + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + return [self.__normalize__(self.__to_tensor__(img)) for img in imgmap] + + +class VideoRandomResizedCrop(torch.nn.Module): + + def __init__(self, size, bottom_area=0.2): + self.p = 1.0 + self.interpolation = Image.BICUBIC + self.size = size + self.bottom_area = bottom_area + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.p: # do RandomResizedCrop, consistent=True + top, left, height, width = TF.RandomResizedCrop.get_params( + imgmap[0], + scale=(self.bottom_area, 1.0), + ratio=(3 / 4.0, 4 / 3.0)) + return [ + F.resized_crop( + img=img, + top=top, + left=left, + height=height, + width=width, + size=(self.size, self.size), + ) for img in imgmap + ] + else: + return [ + F.resize(img=img, size=[self.size, self.size]) + for img in imgmap + ] + + +class VideoRandomHFlip(torch.nn.Module): + + def __init__(self, consistent=True, command=None, seq_len=0): + self.consistent = consistent + if seq_len != 0: + self.consistent = False + if command == 'left': + self.threshold = 0 + elif command == 'right': + self.threshold = 1 + else: + self.threshold = 0.5 + self.seq_len = seq_len + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if self.consistent: + if random.random() < self.threshold: + return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] + else: + return imgmap + else: + result = [] + for idx, i in enumerate(imgmap): + if idx % self.seq_len == 0: + th = random.random() + if th < self.threshold: + result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) + else: + result.append(i) + assert len(result) == len(imgmap) + return result + + +class VideoRandomColorJitter(torch.nn.Module): + """Randomly change the brightness, contrast and saturation of an image. + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__( + self, + brightness=0, + contrast=0, + saturation=0, + hue=0, + consistent=True, + p=1.0, + seq_len=0, + ): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input( + hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.consistent = consistent + self.threshold = p + self.seq_len = seq_len + + def _check_input(self, + value, + name, + center=1, + bound=(0, float('inf')), + clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + 'If {} is a single number, it must be non negative.'. + format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError('{} values should be between {}'.format( + name, bound)) + else: + raise TypeError( + '{} should be a single number or a list/tuple with lenght 2.'. + format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append( + TF.Lambda( + lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append( + TF.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append( + TF.Lambda( + lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append( + TF.Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = TF.Compose(transforms) + + return transform + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.threshold: # do ColorJitter + if self.consistent: + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + + return [transform(i) for i in imgmap] + else: + if self.seq_len == 0: + return [ + self.get_params(self.brightness, self.contrast, + self.saturation, self.hue)(img) + for img in imgmap + ] + else: + result = [] + for idx, img in enumerate(imgmap): + if idx % self.seq_len == 0: + transform = self.get_params( + self.brightness, + self.contrast, + self.saturation, + self.hue, + ) + result.append(transform(img)) + return result + + else: + return imgmap + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + + +class VideoRandomGaussianBlur(torch.nn.Module): + + def __init__(self, radius_min=0.1, radius_max=2.0, p=0.5): + self.radius_min = radius_min + self.radius_max = radius_max + self.p = p + + def __call__(self, imgmap): + assert isinstance(imgmap, list) + if random.random() < self.p: + result = [] + for _, img in enumerate(imgmap): + _radius = random.uniform(self.radius_min, self.radius_max) + result.append( + img.filter(ImageFilter.GaussianBlur(radius=_radius))) + return result + else: + return imgmap + + +def apply_transform(images, trans): + return torch.stack(trans(images), dim=0) diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py new file mode 100644 index 00000000..3a3ae820 --- /dev/null +++ b/modelscope/preprocessors/multi_modal.py @@ -0,0 +1,386 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +from io import BytesIO +from typing import Any, Dict, List, Tuple, Union + +import json +import torch +from PIL import Image +from timm.data import create_transform +from torchvision.transforms import Compose, Normalize, Resize, ToTensor + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Preprocessors +from modelscope.pipelines.base import Input +from modelscope.preprocessors import load_image +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks +from .base import Preprocessor +from .builder import PREPROCESSORS +from .ofa import * # noqa +from .ofa.utils.collate import collate_fn +from .ofa.utils.constant import OFA_TASK_KEY_MAPPING + +__all__ = [ + 'OfaPreprocessor', + 'MPlugPreprocessor', +] + + +@PREPROCESSORS.register_module( + Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor) +class OfaPreprocessor(Preprocessor): + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + mode: preprocessor mode (model mode) + """ + super().__init__(*args, **kwargs) + preprocess_mapping = { + Tasks.ocr_recognition: OfaOcrRecognitionPreprocessor, + Tasks.image_captioning: OfaImageCaptioningPreprocessor, + Tasks.visual_grounding: OfaVisualGroundingPreprocessor, + Tasks.visual_question_answering: + OfaVisualQuestionAnsweringPreprocessor, + Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, + Tasks.image_classification: OfaImageClassificationPreprocessor, + Tasks.text_classification: OfaTextClassificationPreprocessor, + Tasks.text_summarization: OfaSummarizationPreprocessor, + Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor + } + model_dir = model_dir if osp.exists(model_dir) else snapshot_download( + model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.preprocess = preprocess_mapping[self.cfg.task]( + cfg=self.cfg, model_dir=model_dir, mode=mode) + self.keys = OFA_TASK_KEY_MAPPING[self.cfg.task] + self.tokenizer = self.preprocess.tokenizer + if kwargs.get('no_collate', None): + self.no_collate = True + else: + self.no_collate = False + + # just for modelscope demo + def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: + data = dict() + if not isinstance(input, tuple) and not isinstance(input, list): + input = (input, ) + for key, item in zip(self.keys, input): + data[key] = item + return data + + def _ofa_input_compatibility_conversion(self, data): # fake + if 'image' in data and self.cfg.model.get('type', None) == 'ofa': + if isinstance(data['image'], str): + image = load_image(data['image']) + else: + image = data['image'] + if image.mode != 'RGB': + image = image.convert('RGB') + img_buffer = BytesIO() + image.save(img_buffer, format='JPEG') + data['image'] = Image.open(img_buffer) + return data + + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, + **kwargs) -> Dict[str, Any]: + if isinstance(input, dict): + data = input + else: + data = self._build_dict(input) + sample = self.preprocess(data) + str_data = dict() + for k, v in data.items(): + str_data[k] = str(v) + sample['sample'] = str_data + if self.no_collate: + return sample + else: + return collate_fn([sample], + pad_idx=self.tokenizer.pad_token_id, + eos_idx=self.tokenizer.eos_token_id) + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +@PREPROCESSORS.register_module( + Fields.multi_modal, module_name=Preprocessors.clip_preprocessor) +class CLIPPreprocessor(Preprocessor): + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + mode: preprocessor mode (model mode) + """ + super().__init__(*args, **kwargs) + model_dir = model_dir if osp.exists(model_dir) else snapshot_download( + model_dir) + self.mode = mode + # text tokenizer + from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer + if 'tokenizer' in kwargs and isinstance(kwargs['tokenizer'], + FullTokenizer): + self.tokenizer = kwargs['tokenizer'] + else: + vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}' + self.tokenizer = FullTokenizer(vocab_file=vocab_file) + # image preprocessor + if 'resolution' in kwargs and isinstance(kwargs['resolution'], int): + self.image_resolution = kwargs['resolution'] + else: + self.image_resolution = json.load( + open('{}/vision_model_config.json'.format( + model_dir)))['image_resolution'] + self.img_preprocess = self._build_image_transform() + # key mapping + # specify the input keys, compatible with training and inference whose key names may be different + self.input_keys = {'img': 'img', 'text': 'text'} + + def _build_image_transform(self): + + if self.mode == ModeKeys.TRAIN: + transform = create_transform( + input_size=self.image_resolution, + scale=(0.9, 1.0), + is_training=True, + color_jitter=None, + auto_augment='original', + interpolation='bicubic', + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ) + transform = Compose(transform.transforms[:-3] + [_convert_to_rgb] + + transform.transforms[-3:]) + else: + transform = Compose([ + Resize((self.image_resolution, self.image_resolution), + interpolation=Image.BICUBIC), + _convert_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + return transform + + def tokenize(self, + texts: Union[str, List[str]], + context_length: int = 52) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all baseline models use 24 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + all_tokens.append( + [self.tokenizer.vocab['[CLS]']] + + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:context_length - 2] + + [self.tokenizer.vocab['[SEP]']]) + + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def set_input_img_key(self, new_key: str): + self.input_keys['img'] = new_key + + def set_input_text_key(self, new_key: str): + self.input_keys['text'] = new_key + + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, + **kwargs) -> Dict[str, Any]: + output = {} + # preprocess the image input + input_img_key = self.input_keys['img'] + if input_img_key in input and input[input_img_key] is not None: + image_input = input[input_img_key] + + # single image input + if isinstance(image_input, Image.Image): + image_tensor = self.img_preprocess(image_input).unsqueeze(0) + # multi images input + elif isinstance(image_input, list): + if all([isinstance(elem, Image.Image) + for elem in image_input]): + image_tensor = torch.stack( + [self.img_preprocess(elem) + for elem in image_input], # noqa + dim=0) # noqa + else: + unsupported_elem_type = [ + type(elem) for elem in image_input + if not isinstance(elem, Image.Image) + ][0] + raise TypeError( + f'img should be PIL.Image or List[PIL.Image], \ + but got a List containing one {unsupported_elem_type}' + ) + # others + else: + raise TypeError( + f'img should be PIL.Image or List[PIL.Image], but got {type(image_input)}' + ) + output['img'] = image_tensor + + # preprocess the text input + input_text_key = self.input_keys['text'] + if input_text_key in input and input[input_text_key] is not None: + text_input = input[input_text_key] + + # single text input + if isinstance(text_input, str): + text_tensor = self.tokenize(text_input) + # multi texts input + elif isinstance(text_input, list): + if all([isinstance(elem, str) for elem in text_input]): + text_tensor = self.tokenize(text_input) + else: + unsupported_elem_type = [ + type(elem) for elem in text_input + if not isinstance(elem, str) + ][0] + raise TypeError( + f'text should be str or List[str], but got a List containing one {unsupported_elem_type}' + ) + # others + else: + raise TypeError( + f'text should be str or List[str], but got {type(text_input)}' + ) + output['text'] = text_tensor + + return output + + +@PREPROCESSORS.register_module( + Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) +class MPlugPreprocessor(Preprocessor): + + def __init__(self, + model_dir: str, + mode: str = ModeKeys.INFERENCE, + tokenizer_max_length: int = 25, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.model_dir = model_dir + self.mode = mode + self.tokenizer_max_length = tokenizer_max_length + + self._tokenizer = None + self._patch_resize_transform = None + self._image_map = {} + + @property + def tokenizer(self): + from transformers import BertTokenizer + + if self._tokenizer is None: + self._tokenizer = BertTokenizer.from_pretrained(self.model_dir) + return self._tokenizer + + @property + def patch_resize_transform(self): + if self._patch_resize_transform is None: + from torchvision import transforms + from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig + + config = MPlugConfig.from_yaml_file( + osp.join(self.model_dir, CONFIG_NAME)) + + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + + self._patch_resize_transform = transforms.Compose([ + transforms.Resize((config.image_res, config.image_res), + interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + return self._patch_resize_transform + + def image_open(self, path: str) -> Tuple[Image.Image, int]: + if path not in self._image_map: + index = len(self._image_map) + self._image_map[path] = (load_image(path), index) + return self._image_map[path] + + def __call__( + self, data: Union[Image.Image, tuple, + Dict[str, Any]]) -> Dict[str, Any]: + self.cfg = Config.from_file( + osp.join(self.model_dir, ModelFile.CONFIGURATION)) + + if isinstance(data, (Image.Image, str)): + image = data + elif isinstance(data, tuple): + image = data[0] + else: + image = data['image'] + index = 0 + if isinstance(image, str): + image, index = self.image_open(image) + image = image.convert('RGB') + image = self.patch_resize_transform(image) + question = '' if self.cfg.task == Tasks.image_captioning \ + else data[1 if isinstance(data, tuple) + else ('text' if 'text' in data else 'question')] + question = self.tokenizer( + question.lower(), + padding='max_length', + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors='pt') + + if self.mode == ModeKeys.INFERENCE: + image = torch.stack([image], dim=0) + return {'image': image, 'question': question} + else: + answer = data['answer'] + answer = self.tokenizer( + answer, + padding='max_length', + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors='pt') + output = { + 'image': image, + 'question_input_ids': question.input_ids.squeeze(), + 'question_attention_mask': question.attention_mask.squeeze(), + 'answer_input_ids': answer.input_ids.squeeze(), + 'answer_attention_mask': answer.attention_mask.squeeze(), + } + if self.cfg.task == Tasks.image_text_retrieval: + output['index'] = index + return output diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py new file mode 100644 index 00000000..7c48fb3c --- /dev/null +++ b/modelscope/preprocessors/nlp/__init__.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .text_error_correction import TextErrorCorrectionPreprocessor + from .nlp_base import (NLPTokenizerPreprocessorBase, NLPBasePreprocessor) + from .text_generation_jieba_preprocessor import TextGenerationJiebaPreprocessor + from .sentence_piece_preprocessor import SentencePiecePreprocessor + from .bert_seq_cls_tokenizer import Tokenize + from .document_segmentation_preprocessor import DocumentSegmentationPreprocessor + from .faq_question_answering_preprocessor import FaqQuestionAnsweringPreprocessor + from .fill_mask_preprocessor import FillMaskPoNetPreprocessor, NLPPreprocessor + from .text_ranking_preprocessor import TextRankingPreprocessor + from .relation_extraction_preprocessor import RelationExtractionPreprocessor + from .sentence_classification_preprocessor import SequenceClassificationPreprocessor + from .sentence_embedding_preprocessor import SentenceEmbeddingPreprocessor + from .text_generation_preprocessor import TextGenerationPreprocessor + from .text2text_generation_preprocessor import Text2TextGenerationPreprocessor + from .token_classification_preprocessor import TokenClassificationPreprocessor, \ + WordSegmentationBlankSetToLabelPreprocessor + from .token_classification_thai_preprocessor import WordSegmentationPreprocessorThai, NERPreprocessorThai + from .token_classification_viet_preprocessor import NERPreprocessorViet + from .zero_shot_classification_reprocessor import ZeroShotClassificationPreprocessor + from .space import (DialogIntentPredictionPreprocessor, + DialogModelingPreprocessor, + DialogStateTrackingPreprocessor, InputFeatures, + MultiWOZBPETextField, IntentBPETextField) + from .space_T_en import ConversationalTextToSqlPreprocessor + from .space_T_cn import TableQuestionAnsweringPreprocessor + from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor +else: + _import_structure = { + 'nlp_base': [ + 'NLPTokenizerPreprocessorBase', + 'NLPBasePreprocessor', + ], + 'text_generation_jieba_preprocessor': + ['TextGenerationJiebaPreprocessor'], + 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], + 'bert_seq_cls_tokenizer': ['Tokenize'], + 'document_segmentation_preprocessor': + ['DocumentSegmentationPreprocessor'], + 'faq_question_answering_preprocessor': + ['FaqQuestionAnsweringPreprocessor'], + 'fill_mask_preprocessor': + ['FillMaskPoNetPreprocessor', 'NLPPreprocessor'], + 'text_ranking_preprocessor': ['TextRankingPreprocessor'], + 'relation_extraction_preprocessor': ['RelationExtractionPreprocessor'], + 'sentence_classification_preprocessor': + ['SequenceClassificationPreprocessor'], + 'sentence_embedding_preprocessor': ['SentenceEmbeddingPreprocessor'], + 'text_generation_preprocessor': ['TextGenerationPreprocessor'], + 'text2text_generation_preprocessor': + ['Text2TextGenerationPreprocessor'], + 'token_classification_preprocessor': [ + 'TokenClassificationPreprocessor', + 'WordSegmentationBlankSetToLabelPreprocessor' + ], + 'zero_shot_classification_reprocessor': + ['ZeroShotClassificationPreprocessor'], + 'text_error_correction': [ + 'TextErrorCorrectionPreprocessor', + ], + 'mglm_summarization_preprocessor': ['MGLMSummarizationPreprocessor'], + 'token_classification_thai_preprocessor': [ + 'NERPreprocessorThai', + 'WordSegmentationPreprocessorThai', + ], + 'token_classification_viet_preprocessor': [ + 'NERPreprocessorViet', + ], + 'space': [ + 'DialogIntentPredictionPreprocessor', + 'DialogModelingPreprocessor', + 'DialogStateTrackingPreprocessor', + 'InputFeatures', + 'MultiWOZBPETextField', + 'IntentBPETextField', + ], + 'space_T_en': ['ConversationalTextToSqlPreprocessor'], + 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py b/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py new file mode 100644 index 00000000..576687ce --- /dev/null +++ b/modelscope/preprocessors/nlp/bert_seq_cls_tokenizer.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from transformers import AutoTokenizer + +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, InputFields + + +@PREPROCESSORS.register_module(Fields.nlp) +class Tokenize(Preprocessor): + + def __init__(self, tokenizer_name) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(data, str): + data = {InputFields.text: data} + token_dict = self.tokenizer(data[InputFields.text]) + data.update(token_dict) + return data diff --git a/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py b/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py new file mode 100644 index 00000000..5ab0a0c6 --- /dev/null +++ b/modelscope/preprocessors/nlp/document_segmentation_preprocessor.py @@ -0,0 +1,220 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.logger import get_logger +from .nlp_base import NLPBasePreprocessor + +logger = get_logger() + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.document_segmentation) +class DocumentSegmentationPreprocessor(NLPBasePreprocessor): + + def __init__(self, model_dir: str, config, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + + super().__init__(model_dir, *args, **kwargs) + from transformers import BertTokenizerFast + self.tokenizer = BertTokenizerFast.from_pretrained( + model_dir, + use_fast=True, + ) + self.question_column_name = 'labels' + self.context_column_name = 'sentences' + self.example_id_column_name = 'example_id' + self.label_to_id = {'B-EOP': 0, 'O': 1} + self.target_specical_ids = set() + self.target_specical_ids.add(self.tokenizer.eos_token_id) + self.max_seq_length = config.max_position_embeddings + self.label_list = ['B-EOP', 'O'] + + def __call__(self, examples) -> Dict[str, Any]: + questions = examples[self.question_column_name] + contexts = examples[self.context_column_name] + example_ids = examples[self.example_id_column_name] + num_examples = len(questions) + + sentences = [] + for sentence_list in contexts: + sentence_list = [_ + '[EOS]' for _ in sentence_list] + sentences.append(sentence_list) + + try: + tokenized_examples = self.tokenizer( + sentences, + is_split_into_words=True, + add_special_tokens=False, + return_token_type_ids=True, + return_attention_mask=True, + ) + except Exception as e: + logger.error(e) + return {} + + segment_ids = [] + token_seq_labels = [] + for example_index in range(num_examples): + example_input_ids = tokenized_examples['input_ids'][example_index] + example_labels = questions[example_index] + example_labels = [ + self.label_to_id[_] if _ in self.label_to_id else -100 + for _ in example_labels + ] + example_token_labels = [] + segment_id = [] + cur_seg_id = 1 + for token_index in range(len(example_input_ids)): + if example_input_ids[token_index] in self.target_specical_ids: + example_token_labels.append(example_labels[cur_seg_id - 1]) + segment_id.append(cur_seg_id) + cur_seg_id += 1 + else: + example_token_labels.append(-100) + segment_id.append(cur_seg_id) + + segment_ids.append(segment_id) + token_seq_labels.append(example_token_labels) + + tokenized_examples['segment_ids'] = segment_ids + tokenized_examples['token_seq_labels'] = token_seq_labels + + new_segment_ids = [] + new_token_seq_labels = [] + new_input_ids = [] + new_token_type_ids = [] + new_attention_mask = [] + new_example_ids = [] + new_sentences = [] + + for example_index in range(num_examples): + example_input_ids = tokenized_examples['input_ids'][example_index] + example_token_type_ids = tokenized_examples['token_type_ids'][ + example_index] + example_attention_mask = tokenized_examples['attention_mask'][ + example_index] + example_segment_ids = tokenized_examples['segment_ids'][ + example_index] + example_token_seq_labels = tokenized_examples['token_seq_labels'][ + example_index] + example_sentences = contexts[example_index] + example_id = example_ids[example_index] + example_total_num_sentences = len(questions[example_index]) + example_total_num_tokens = len( + tokenized_examples['input_ids'][example_index]) + accumulate_length = [ + i for i, x in enumerate(tokenized_examples['input_ids'] + [example_index]) + if x == self.tokenizer.eos_token_id + ] + samples_boundary = [] + left_index = 0 + sent_left_index = 0 + sent_i = 0 + + # for sent_i, length in enumerate(accumulate_length): + while sent_i < len(accumulate_length): + length = accumulate_length[sent_i] + right_index = length + 1 + sent_right_index = sent_i + 1 + if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens: + samples_boundary.append([left_index, right_index]) + + sample_input_ids = [ + self.tokenizer.cls_token_id + ] + example_input_ids[left_index:right_index] + sample_input_ids = sample_input_ids[:self.max_seq_length] + + sample_token_type_ids = [ + 0 + ] + example_token_type_ids[left_index:right_index] + sample_token_type_ids = sample_token_type_ids[:self. + max_seq_length] + + sample_attention_mask = [ + 1 + ] + example_attention_mask[left_index:right_index] + sample_attention_mask = sample_attention_mask[:self. + max_seq_length] + + sample_segment_ids = [ + 0 + ] + example_segment_ids[left_index:right_index] + sample_segment_ids = sample_segment_ids[:self. + max_seq_length] + + sample_token_seq_labels = [ + -100 + ] + example_token_seq_labels[left_index:right_index] + sample_token_seq_labels = sample_token_seq_labels[:self. + max_seq_length] + + if sent_right_index - 1 == sent_left_index: + left_index = right_index + sample_input_ids[-1] = self.tokenizer.eos_token_id + sample_token_seq_labels[-1] = -100 + else: + left_index = accumulate_length[sent_i - 1] + 1 + if sample_token_seq_labels[-1] != -100: + sample_token_seq_labels[-1] = -100 + + if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens: + sample_sentences = example_sentences[ + sent_left_index:sent_right_index] + sent_left_index = sent_right_index + sent_i += 1 + else: + sample_sentences = example_sentences[ + sent_left_index:sent_right_index - 1] + sent_left_index = sent_right_index - 1 + + if (len([_ for _ in sample_token_seq_labels if _ != -100 + ])) != len(sample_sentences) - 1 and (len([ + _ + for _ in sample_token_seq_labels if _ != -100 + ])) != len(sample_sentences): + tmp = [] + for w_i, w, l in zip( + sample_input_ids, + self.tokenizer.decode(sample_input_ids).split( + ' '), sample_token_seq_labels): + tmp.append((w_i, w, l)) + while len(sample_input_ids) < self.max_seq_length: + sample_input_ids.append(self.tokenizer.pad_token_id) + sample_token_type_ids.append(0) + sample_attention_mask.append(0) + sample_segment_ids.append(example_total_num_sentences + + 1) + sample_token_seq_labels.append(-100) + + new_input_ids.append(sample_input_ids) + new_token_type_ids.append(sample_token_type_ids) + new_attention_mask.append(sample_attention_mask) + new_segment_ids.append(sample_segment_ids) + new_token_seq_labels.append(sample_token_seq_labels) + new_example_ids.append(example_id) + new_sentences.append(sample_sentences) + else: + sent_i += 1 + continue + + output_samples = {} + + output_samples['input_ids'] = new_input_ids + output_samples['token_type_ids'] = new_token_type_ids + output_samples['attention_mask'] = new_attention_mask + + output_samples['segment_ids'] = new_segment_ids + output_samples['example_id'] = new_example_ids + output_samples['labels'] = new_token_seq_labels + output_samples['sentences'] = new_sentences + + return output_samples diff --git a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py new file mode 100644 index 00000000..873a8448 --- /dev/null +++ b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py @@ -0,0 +1,98 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import Fields, ModeKeys, ModelFile +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor) +class FaqQuestionAnsweringPreprocessor(NLPBasePreprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + super(FaqQuestionAnsweringPreprocessor, self).__init__( + model_dir, mode=ModeKeys.INFERENCE, **kwargs) + + from transformers import BertTokenizer + + preprocessor_config = Config.from_file( + os.path.join(model_dir, ModelFile.CONFIGURATION)).get( + ConfigFields.preprocessor, {}) + if preprocessor_config.get('tokenizer', + 'BertTokenizer') == 'XLMRoberta': + from transformers import XLMRobertaTokenizer + self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir) + else: + self.tokenizer = BertTokenizer.from_pretrained(model_dir) + + self.MAX_LEN = preprocessor_config.get('max_seq_length', 50) + self.label_dict = None + + def pad(self, samples, max_len): + result = [] + for sample in samples: + pad_len = max_len - len(sample[:max_len]) + result.append(sample[:max_len] + + [self.tokenizer.pad_token_id] * pad_len) + return result + + def set_label_dict(self, label_dict): + self.label_dict = label_dict + + def get_label(self, label_id): + assert self.label_dict is not None and label_id < len(self.label_dict) + return self.label_dict[label_id] + + def encode_plus(self, text): + return [ + self.tokenizer.cls_token_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id] + + @type_assert(object, Dict) + def __call__(self, data: Dict[str, Any], + **preprocessor_param) -> Dict[str, Any]: + TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN) + queryset = data['query_set'] + if not isinstance(queryset, list): + queryset = [queryset] + supportset = data['support_set'] + supportset = sorted(supportset, key=lambda d: d['label']) + + queryset_tokenized = [self.encode_plus(text) for text in queryset] + supportset_tokenized = [ + self.encode_plus(item['text']) for item in supportset + ] + + max_len = max( + [len(seq) for seq in queryset_tokenized + supportset_tokenized]) + max_len = min(TMP_MAX_LEN, max_len) + queryset_padded = self.pad(queryset_tokenized, max_len) + supportset_padded = self.pad(supportset_tokenized, max_len) + + supportset_labels_ori = [item['label'] for item in supportset] + label_dict = [] + for label in supportset_labels_ori: + if label not in label_dict: + label_dict.append(label) + self.set_label_dict(label_dict) + supportset_labels_ids = [ + label_dict.index(label) for label in supportset_labels_ori + ] + return { + 'query': queryset_padded, + 'support': supportset_padded, + 'support_labels': supportset_labels_ids + } + + def batch_encode(self, sentence_list: list, max_length=None): + if not max_length: + max_length = self.MAX_LEN + return self.tokenizer.batch_encode_plus( + sentence_list, padding=True, max_length=max_length) diff --git a/modelscope/preprocessors/nlp/fill_mask_preprocessor.py b/modelscope/preprocessors/nlp/fill_mask_preprocessor.py new file mode 100644 index 00000000..b0638dbc --- /dev/null +++ b/modelscope/preprocessors/nlp/fill_mask_preprocessor.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +import re +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModeKeys, ModelFile +from modelscope.utils.nlp import import_external_nltk_data +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.feature_extraction) +class NLPPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in MLM task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + True) + super().__init__(model_dir, mode=mode, **kwargs) + + @property + def mask_id(self): + return self.tokenizer.mask_token_id + + def decode(self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + return self.tokenizer.decode(token_ids, skip_special_tokens, + clean_up_tokenization_spaces, **kwargs) + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.fill_mask_ponet) +class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in PoNet model's MLM task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 512) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + True) + super().__init__(model_dir, mode=mode, **kwargs) + + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.language = self.cfg.model.get('language', 'en') + if self.language == 'en': + from nltk.tokenize import sent_tokenize + import_external_nltk_data( + osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt') + elif self.language in ['zh', 'cn']: + + def sent_tokenize(para): + para = re.sub(r'([。!!?\?])([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa * + para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', + para) # noqa * + para = para.rstrip() + return [_ for _ in para.split('\n') if _] + else: + raise NotImplementedError + + self.sent_tokenize = sent_tokenize + self.max_length = kwargs['max_length'] + + def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (tuple): [sentence1, sentence2] + sentence1 (str): a sentence + Example: + 'you are so handsome.' + sentence2 (str): a sentence + Example: + 'you are so beautiful.' + Returns: + Dict[str, Any]: the preprocessed data + """ + + text_a, text_b, labels = self.parse_text_and_label(data) + output = self.tokenizer( + text_a, + text_b, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + **self.tokenize_kwargs) + max_seq_length = self.max_length + + if text_b is None: + segment_ids = [] + seg_lens = list( + map( + len, + self.tokenizer( + self.sent_tokenize(text_a), + add_special_tokens=False, + truncation=True)['input_ids'])) + segment_id = [0] + sum( + [[i] * sl for i, sl in enumerate(seg_lens, start=1)], []) + segment_id = segment_id[:max_seq_length - 1] + segment_ids.append(segment_id + [segment_id[-1] + 1] + * (max_seq_length - len(segment_id))) + if self.mode == ModeKeys.INFERENCE: + segment_ids = torch.tensor(segment_ids) + output['segment_ids'] = segment_ids + + output = { + k: np.array(v) if isinstance(v, list) else v + for k, v in output.items() + } + + self.labels_to_id(labels, output) + return output + + @property + def mask_id(self): + return self.tokenizer.mask_token_id + + def decode(self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + return self.tokenizer.decode(token_ids, skip_special_tokens, + clean_up_tokenization_spaces, **kwargs) diff --git a/modelscope/preprocessors/nlp/mglm_summarization_preprocessor.py b/modelscope/preprocessors/nlp/mglm_summarization_preprocessor.py new file mode 100644 index 00000000..0a68a9fa --- /dev/null +++ b/modelscope/preprocessors/nlp/mglm_summarization_preprocessor.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 Zhipu.AI + +import os.path as osp +import re +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +from modelscope.metainfo import Models, Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config, ConfigFields +from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile +from modelscope.utils.hub import get_model_type, parse_label_mapping +from modelscope.utils.logger import get_logger +from modelscope.utils.nlp import import_external_nltk_data +from modelscope.utils.type_assert import type_assert + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.mglm_summarization) +class MGLMSummarizationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """preprocess the data + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + @type_assert(object, (str, tuple, Dict)) + def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: + return data diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py new file mode 100644 index 00000000..45efc6e7 --- /dev/null +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -0,0 +1,289 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from abc import ABC +from collections.abc import Mapping +from typing import Any, Dict, List, Tuple, Union + +import json +import numpy as np +import torch +from transformers import AutoTokenizer + +from modelscope.metainfo import Models +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.base import Preprocessor +from modelscope.utils.constant import ModeKeys +from modelscope.utils.hub import get_model_type, parse_label_mapping +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = [ + 'NLPBasePreprocessor', + 'NLPTokenizerPreprocessorBase', +] + + +class NLPBasePreprocessor(Preprocessor, ABC): + + def __init__(self, + model_dir: str, + first_sequence=None, + second_sequence=None, + label=None, + label2id=None, + mode=ModeKeys.INFERENCE, + use_fast=None, + **kwargs): + """The NLP preprocessor base class. + + Args: + model_dir (str): The local model path + first_sequence: The key for the first sequence + second_sequence: The key for the second sequence + label: The label key + label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping + if this mapping is not supplied. + mode: Run this preprocessor in either 'train'/'eval'/'inference' mode + use_fast: use the fast version of tokenizer + + """ + self.model_dir = model_dir + self.first_sequence = first_sequence + self.second_sequence = second_sequence + self.label = label + + self.use_fast = use_fast + if self.use_fast is None and model_dir is None: + self.use_fast = False + elif self.use_fast is None and os.path.isfile( + os.path.join(model_dir, 'tokenizer_config.json')): + with open(os.path.join(model_dir, 'tokenizer_config.json'), + 'r') as f: + json_config = json.load(f) + self.use_fast = json_config.get('use_fast') + self.use_fast = False if self.use_fast is None else self.use_fast + + self.label2id = label2id + if self.label2id is None and model_dir is not None: + self.label2id = parse_label_mapping(model_dir) + super().__init__(mode, **kwargs) + + @property + def mask_id(self): + """Child preprocessor can override this property to return the id of mask token. + + Returns: + The id of mask token, default None. + """ + return None + + def decode(self, + token_ids: Union[int, List[int], 'np.ndarray', 'torch.Tensor', + 'tf.Tensor'], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + **kwargs): + """Turn the token_ids to real sentence. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + Returns: + The real sentence decoded by the preprocessor. + """ + raise NotImplementedError() + + +class NLPTokenizerPreprocessorBase(NLPBasePreprocessor): + + def __init__(self, + model_dir: str, + first_sequence: str = None, + second_sequence: str = None, + label: str = 'label', + label2id: dict = None, + mode: str = ModeKeys.INFERENCE, + use_fast: bool = None, + **kwargs): + """The NLP tokenizer preprocessor base class. + + Any nlp preprocessor which uses the hf tokenizer can inherit from this class. + + Args: + model_dir (str): The local model path + first_sequence: The key for the first sequence + second_sequence: The key for the second sequence + label: The key for the label + label2id: An optional label2id dict. + If label2id is None, the preprocessor will try to parse label-id mapping from: + - configuration.json model.label2id/model.id2label + - config.json label2id/id2label + - label_mapping.json + mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different. + use_fast: use the fast version of tokenizer + kwargs: These kwargs will be directly fed into the tokenizer. + """ + + super().__init__(model_dir, first_sequence, second_sequence, label, + label2id, mode, use_fast, **kwargs) + self.model_dir = model_dir + self.tokenize_kwargs = kwargs + self.tokenizer = self.build_tokenizer(model_dir) + logger.info(f'The key of sentence1: {self.first_sequence}, ' + f'The key of sentence2: {self.second_sequence}, ' + f'The key of label: {self.label}') + if self.first_sequence is None: + logger.warning('[Important] first_sequence attribute is not set, ' + 'this will cause an error if your input is a dict.') + + @property + def id2label(self): + """Return the id2label mapping according to the label2id mapping. + + @return: The id2label mapping if exists. + """ + if self.label2id is not None: + return {id: label for label, id in self.label2id.items()} + return None + + def build_tokenizer(self, model_dir): + """Build a tokenizer by the model type. + + NOTE: This default implementation only returns slow tokenizer, because the fast tokenizers have a + multi-thread problem. + + Args: + model_dir: The local model dir. + + Returns: + The initialized tokenizer. + """ + self.is_transformer_based_model = 'lstm' not in model_dir + # fast version lead to parallel inference failed + model_type = get_model_type(model_dir) + if model_type in (Models.structbert, Models.gpt3, Models.palm, + Models.plug): + from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast + tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer + return tokenizer.from_pretrained(model_dir) + elif model_type == Models.veco: + from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast + tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer + return tokenizer.from_pretrained(model_dir) + elif model_type == Models.deberta_v2: + from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast + tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer + return tokenizer.from_pretrained(model_dir) + elif not self.is_transformer_based_model: + from transformers import BertTokenizer, BertTokenizerFast + tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer + return tokenizer.from_pretrained(model_dir) + else: + return AutoTokenizer.from_pretrained( + model_dir, use_fast=self.use_fast) + + def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (tuple): [sentence1, sentence2] + sentence1 (str): a sentence + Example: + 'you are so handsome.' + sentence2 (str): a sentence + Example: + 'you are so beautiful.' + Returns: + Dict[str, Any]: the preprocessed data + """ + + text_a, text_b, labels = self.parse_text_and_label(data) + output = self.tokenizer( + text_a, + text_b, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + **self.tokenize_kwargs) + output = { + k: np.array(v) if isinstance(v, list) else v + for k, v in output.items() + } + self.labels_to_id(labels, output) + return output + + def parse_text_and_label(self, data): + """Parse the input and return the sentences and labels. + + When input type is tuple or list and its size is 2: + If the pair param is False, data will be parsed as the first_sentence and the label, + else it will be parsed as the first_sentence and the second_sentence. + + Args: + data: The input data. + + Returns: + The sentences and labels tuple. + """ + text_a, text_b, labels = None, None, None + if isinstance(data, str): + text_a = data + elif isinstance(data, tuple) or isinstance(data, list): + if len(data) == 3: + text_a, text_b, labels = data + elif len(data) == 2: + if self._mode == ModeKeys.INFERENCE: + text_a, text_b = data + else: + text_a, labels = data + elif isinstance(data, Mapping): + text_a = data.get(self.first_sequence) + text_b = data.get(self.second_sequence) + labels = data.get(self.label) + + return text_a, text_b, labels + + def labels_to_id(self, labels, output): + """Turn the labels to id with the type int or float. + + If the original label's type is str or int, the label2id mapping will try to convert it to the final label. + If the original label's type is float, or the label2id mapping does not exist, + the original label will be returned. + + Args: + labels: The input labels. + output: The label id. + + Returns: + The final labels. + """ + + def label_can_be_mapped(label): + return isinstance(label, str) or isinstance(label, int) + + try: + if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ + and self.label2id is not None: + output[OutputKeys.LABELS] = [ + self.label2id[label] + if label in self.label2id else self.label2id[str(label)] + for label in labels + ] + elif label_can_be_mapped(labels) and self.label2id is not None: + output[OutputKeys.LABELS] = self.label2id[ + labels] if labels in self.label2id else self.label2id[str( + labels)] + elif labels is not None: + output[OutputKeys.LABELS] = labels + except KeyError as e: + logger.error( + f'Label {labels} cannot be found in the label mapping {self.label2id},' + f'which comes from the user input or the configuration files. ' + f'Please consider matching your labels with this mapping.') + raise e diff --git a/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py b/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py new file mode 100644 index 00000000..9a426ab7 --- /dev/null +++ b/modelscope/preprocessors/nlp/relation_extraction_preprocessor.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from transformers import AutoTokenizer + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.re_tokenizer) +class RelationExtractionPreprocessor(NLPBasePreprocessor): + """The relation extraction preprocessor used in normal RE task. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + + super().__init__(model_dir, *args, **kwargs) + + self.model_dir: str = model_dir + self.sequence_length = kwargs.pop('sequence_length', 512) + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir, use_fast=True) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + # preprocess the data for the model input + text = data + output = self.tokenizer([text], return_tensors='pt') + return { + 'text': text, + 'input_ids': output['input_ids'], + 'attention_mask': output['attention_mask'], + 'offsets': output[0].offsets + } diff --git a/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py b/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py new file mode 100644 index 00000000..f1295c50 --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_classification_preprocessor.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.nli_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) +class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in sequence classification. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) diff --git a/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py new file mode 100644 index 00000000..519de60c --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_embedding_preprocessor.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_embedding) +class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in sentence embedding. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict: + keys: "source_sentence" && "sentences_to_compare" + values: list of sentences + Example: + {"source_sentence": ["how long it take to get a master's degree"], + "sentences_to_compare": ["On average, students take about 18 to 24 months + to complete a master's degree.", + "On the other hand, some students prefer to go at a slower pace + and choose to take several years to complete their studies.", + "It can take anywhere from two semesters"]} + Returns: + Dict[str, Any]: the preprocessed data + """ + source_sentence = data['source_sentence'] + compare_sentences = data['sentences_to_compare'] + sentences = [] + sentences.append(source_sentence[0]) + for sent in compare_sentences: + sentences.append(sent) + + tokenized_inputs = self.tokenizer( + sentences, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + padding=True, + truncation=True) + return tokenized_inputs diff --git a/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py b/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py new file mode 100644 index 00000000..1d1ef19d --- /dev/null +++ b/modelscope/preprocessors/nlp/sentence_piece_preprocessor.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +import sentencepiece as spm +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_piece) +class SentencePiecePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + import os + + super().__init__(*args, **kwargs) + self.tokenizer = None + for file_name in os.listdir(model_dir): + if file_name.endswith('.model'): + m_file = osp.join(model_dir, file_name) + self.tokenizer = spm.SentencePieceProcessor(model_file=m_file) + break + assert self.tokenizer is not None, 'Can not find .model file' + + def __call__(self, data: str) -> Dict[str, Any]: + return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long) diff --git a/modelscope/preprocessors/nlp/space/__init__.py b/modelscope/preprocessors/nlp/space/__init__.py new file mode 100644 index 00000000..b484dabe --- /dev/null +++ b/modelscope/preprocessors/nlp/space/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .data_loader import DataLoader + from .dialog_intent_prediction_preprocessor import \ + DialogIntentPredictionPreprocessor + from .dialog_modeling_preprocessor import DialogModelingPreprocessor + from .dialog_state_tracking_preprocessor import DialogStateTrackingPreprocessor + from .dst_processors import InputFeatures + from .fields import MultiWOZBPETextField, IntentBPETextField + +else: + _import_structure = { + 'data_loader': ['DataLoader'], + 'dialog_intent_prediction_preprocessor': + ['DialogIntentPredictionPreprocessor'], + 'dialog_modeling_preprocessor': ['DialogModelingPreprocessor'], + 'dialog_state_tracking_preprocessor': + ['DialogStateTrackingPreprocessor'], + 'dst_processors': ['InputFeatures'], + 'fields': ['MultiWOZBPETextField', 'IntentBPETextField'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/space/args.py b/modelscope/preprocessors/nlp/space/args.py new file mode 100644 index 00000000..17c6828b --- /dev/null +++ b/modelscope/preprocessors/nlp/space/args.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse + +import json + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Unsupported value encountered.') + + +class HParams(dict): + """ Hyper-parameters class + + Store hyper-parameters in training / infer / ... scripts. + """ + + def __getattr__(self, name): + if name in self.keys(): + return self[name] + for v in self.values(): + if isinstance(v, HParams): + if name in v: + return v[name] + raise AttributeError(f"'HParams' object has no attribute '{name}'") + + def __setattr__(self, name, value): + self[name] = value + + def save(self, filename): + with open(filename, 'w', encoding='utf-8') as fp: + json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) + + def load(self, filename): + with open(filename, 'r', encoding='utf-8') as fp: + params_dict = json.load(fp) + for k, v in params_dict.items(): + if isinstance(v, dict): + self[k].update(HParams(v)) + else: + self[k] = v + + +def parse_args(parser): + """ Parse hyper-parameters from cmdline. """ + parsed = parser.parse_args() + args = HParams() + optional_args = parser._action_groups[1] + for action in optional_args._group_actions[1:]: + arg_name = action.dest + args[arg_name] = getattr(parsed, arg_name) + for group in parser._action_groups[2:]: + group_args = HParams() + for action in group._group_actions: + arg_name = action.dest + group_args[arg_name] = getattr(parsed, arg_name) + if len(group_args) > 0: + args[group.title] = group_args + return args diff --git a/modelscope/preprocessors/nlp/space/batch.py b/modelscope/preprocessors/nlp/space/batch.py new file mode 100644 index 00000000..d27776f5 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/batch.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + +def batch(reader, batch_size, drop_last=False): + """ + This operator creates a batched reader which combines the data from the + input reader to batched data. + + Args: + reader(generator): the data reader to read from. + batch_size(int): size of each mini-batch. + drop_last(bool, optional): If set to True, the last batch is dropped when + the size of last batch is not equal to batch_size, if set to False, + it will not. Default: False. + Returns: + The batched reader. + + Return Type: + generator + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + def reader(): + for i in range(10): + yield i + batch_reader = fluid.io.batch(reader, batch_size=2) + + for data in batch_reader(): + print(data) + + # Output is + # [0, 1] + # [2, 3] + # [4, 5] + # [6, 7] + # [8, 9] + """ + + def batch_reader(): + r = reader() + b = [] + for instance in r: + b.append(instance) + if len(b) == batch_size: + yield b + b = [] + if drop_last is False and len(b) != 0: + yield b + + # Batch size check + batch_size = int(batch_size) + if batch_size <= 0: + raise ValueError('batch_size should be a positive integeral value, ' + 'but got batch_size={}'.format(batch_size)) + + return batch_reader diff --git a/modelscope/preprocessors/nlp/space/data_loader.py b/modelscope/preprocessors/nlp/space/data_loader.py new file mode 100644 index 00000000..290b64f3 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/data_loader.py @@ -0,0 +1,110 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os + +import numpy as np + +from modelscope.preprocessors.nlp.space.args import str2bool +from modelscope.preprocessors.nlp.space.batch import batch +from modelscope.preprocessors.nlp.space.lazy_dataset import LazyDataset +from modelscope.preprocessors.nlp.space.sampler import (RandomSampler, + SequentialSampler, + SortedSampler) + + +def get_data_loader(batch_size, reader, hparams, file, collate_fn, is_test): + assert os.path.exists(file), f"{file} doesn't exist" + dataset = LazyDataset(file, reader=reader) + data_loader = DataLoader( + dataset, + batch_size, + hparams.Trainer, + collate_fn=collate_fn, + is_test=is_test) + return data_loader + + +def get_sequential_data_loader(batch_size, reader, hparams, data_paths, + collate_fn, data_type): + data_loaders = [] + for data_path in data_paths: + file = os.path.join( + data_path, + f'{data_type}.{hparams.BPETextField.tokenizer_type}.jsonl') + data_loaders.append( + get_data_loader( + batch_size=batch_size, + reader=reader, + hparams=hparams, + file=file, + collate_fn=collate_fn, + is_test=(data_type != 'train'))) + data_loader = SequentialDataLoaderWrapper(data_loaders) + return data_loader + + +class DataLoader(object): + """ Implement of DataLoader. """ + + @classmethod + def add_cmdline_argument(cls, group): + group.add_argument('--shuffle', type=str2bool, default=True) + group.add_argument('--sort_pool_size', type=int, default=0) + return group + + def __init__(self, + dataset, + batch_size, + hparams, + collate_fn=None, + sampler=None, + is_test=False): + self.dataset = dataset + self.collate_fn = collate_fn + self.gpu = hparams.gpu + self.sort_pool_size = hparams.sort_pool_size + + if sampler is None: + if hparams.shuffle and not is_test: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + + if self.sort_pool_size > 0 and not is_test: + sampler = SortedSampler(sampler, self.sort_pool_size) + + def reader(): + for idx in sampler: + yield idx + + drop_last = False if self.gpu <= 1 or is_test else True + self.reader = batch(reader, batch_size=batch_size, drop_last=drop_last) + self.num_batches = math.floor(len(dataset) / batch_size) if drop_last \ + else math.ceil(len(dataset) / batch_size) + + def __len__(self): + return self.num_batches + + def __iter__(self): + for batch_indices in self.reader(): + samples = [self.dataset[idx] for idx in batch_indices] + yield self.collate_fn(samples) + + +class SequentialDataLoaderWrapper: + + def __init__(self, data_loaders): + self.data_loaders = data_loaders + self.data_file_to_dataset = { + data_loader.dataset.data_file: data_loader.dataset + for data_loader in self.data_loaders + } + + def __iter__(self): + for data_loader in self.data_loaders: + for tmp_batch in data_loader: + yield data_loader.dataset.data_file, tmp_batch + + def __len__(self): + return np.sum([len(data_loader) for data_loader in self.data_loaders]) diff --git a/modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py new file mode 100644 index 00000000..2923157e --- /dev/null +++ b/modelscope/preprocessors/nlp/space/dialog_intent_prediction_preprocessor.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +import json + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.nlp import IntentBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile +from modelscope.utils.type_assert import type_assert + +__all__ = ['DialogIntentPredictionPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_intent_preprocessor) +class DialogIntentPredictionPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + self.text_field = IntentBPETextField( + self.model_dir, config=self.config) + + self.categories = None + with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: + self.categories = json.load(f) + assert len(self.categories) == 77 + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'What do I need to do for the card activation?' + + Returns: + Dict[str, Any]: the preprocessed data + Example: + { + 'src_token': array([[13, 2054, 2079, 1045...]]), + 'src_pos': array([[ 0, 1, 2, 3...]]), + 'src_type': array([[1, 1, 1, 1...]]), + 'src_turn': array([[1, 1, 1, 1...]]), + 'src_mask': array([[1, 1, 1, 1...]]), + 'mlm_token': array([[13, 2054, 2079, 1045...]]), + 'mlm_label': array([[0, 0, 0, 0...]]), + 'mlm_mask': array([[0, 0, 0, 0...]]), + 'tgt_token': array([[29, 30, 31, 32...]]), + 'tgt_mask': array([[1, 1, 1, 1...]]), + 'ids': array([0]), + 'intent_label': array([-1]) + } + """ + samples = self.text_field.preprocessor([data]) + samples, _ = self.text_field.collate_fn_multi_turn(samples) + + return samples diff --git a/modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py new file mode 100644 index 00000000..ae3c214a --- /dev/null +++ b/modelscope/preprocessors/nlp/space/dialog_modeling_preprocessor.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.nlp import MultiWOZBPETextField +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile +from modelscope.utils.type_assert import type_assert + +__all__ = ['DialogModelingPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_modeling_preprocessor) +class DialogModelingPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + import torch + self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() + + self.text_field = MultiWOZBPETextField( + config=self.config, model_dir=self.model_dir) + + @type_assert(object, Dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (Dict[str, Any]): A sentence and dialogue history info. + Example: + { + 'user_input': 'i want to leave after 17:15 .', + 'history': { + 'labels': [[13, 1045, 2052, 2066...]], + 'resp': [14, 1045, 2064, 2393...], + 'bspn': [15, 43, 7688, 10733...], + 'db': [19, 24, 20], + 'aspn': [16, 43, 48, 2681, 7180, 10], + 'output': ['i', 'can', 'help', 'with'...] + } + } + + Returns: + Dict[str, Any]: the preprocessed data + """ + import torch + first_turn = True if len(data['history']) == 0 else False + user_ids = self.text_field.get_ids(data['user_input']) + inputs, prompt_id = self.text_field.convert_turn_eval( + turn={'user': user_ids}, + pv_turn=data['history'], + first_turn=first_turn) + batch, batch_size = self.text_field.collate_fn_multi_turn( + samples=[inputs]) + + data['first_turn'] = first_turn + data['batch'] = batch + data['batch_size'] = batch_size + data['prompt_id'] = prompt_id + data['labels'] = [ + torch.Tensor(item).int() for item in inputs['labels'] + ] + + return data diff --git a/modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py new file mode 100644 index 00000000..cff39577 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/dialog_state_tracking_preprocessor.py @@ -0,0 +1,136 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from .dst_processors import convert_examples_to_features, multiwoz22Processor + +__all__ = ['DialogStateTrackingPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor) +class DialogStateTrackingPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + from modelscope.models.nlp.space import SpaceConfig, SpaceTokenizer + self.model_dir: str = model_dir + self.config = SpaceConfig.from_pretrained(self.model_dir) + self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) + self.processor = multiwoz22Processor() + + @type_assert(object, dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (Dict[str, Any]): a sentence + Example: + { + 'utter': {'User-1': "Hi, I'm looking for a train that is going" + "to cambridge and arriving there by 20:45, is there anything like that?"}, + 'history_states': [{}] + } + + Returns: + Dict[str, Any]: the preprocessed data + """ + import torch + from torch.utils.data import (DataLoader, RandomSampler, + SequentialSampler) + + utter = data['utter'] + history_states = data['history_states'] + example = self.processor.create_example( + inputs=utter, + history_states=history_states, + set_type='test', + slot_list=self.config.dst_slot_list, + label_maps={}, + append_history=True, + use_history_labels=True, + swap_utterances=True, + label_value_repetitions=True, + delexicalize_sys_utts=True, + unk_token='[UNK]', + analyze=False) + + features = convert_examples_to_features( + examples=[example], + slot_list=self.config.dst_slot_list, + class_types=self.config.dst_class_types, + model_type=self.config.model_type, + tokenizer=self.tokenizer, + max_seq_length=180, # args.max_seq_length + slot_value_dropout=(0.0)) + + all_input_ids = torch.tensor([f.input_ids for f in features], + dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in features], + dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in features], + dtype=torch.long) + all_example_index = torch.arange( + all_input_ids.size(0), dtype=torch.long) + f_start_pos = [f.start_pos for f in features] + f_end_pos = [f.end_pos for f in features] + f_inform_slot_ids = [f.inform_slot for f in features] + f_refer_ids = [f.refer_id for f in features] + f_diag_state = [f.diag_state for f in features] + f_class_label_ids = [f.class_label_id for f in features] + all_start_positions = {} + all_end_positions = {} + all_inform_slot_ids = {} + all_refer_ids = {} + all_diag_state = {} + all_class_label_ids = {} + for s in self.config.dst_slot_list: + all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], + dtype=torch.long) + all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], + dtype=torch.long) + all_inform_slot_ids[s] = torch.tensor( + [f[s] for f in f_inform_slot_ids], dtype=torch.long) + all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], + dtype=torch.long) + all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], + dtype=torch.long) + all_class_label_ids[s] = torch.tensor( + [f[s] for f in f_class_label_ids], dtype=torch.long) + dataset = [ + all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions, all_inform_slot_ids, + all_refer_ids, all_diag_state, all_class_label_ids, + all_example_index + ] + + with torch.no_grad(): + diag_state = { + slot: + torch.tensor([0 for _ in range(self.config.eval_batch_size) + ]).to(self.config.device) + for slot in self.config.dst_slot_list + } + + if len(history_states) > 2: + ds = history_states[-2] + else: + ds = {slot: 'none' for slot in self.config.dst_slot_list} + + return { + 'batch': dataset, + 'features': features, + 'diag_state': diag_state, + 'ds': ds + } diff --git a/modelscope/preprocessors/nlp/space/dst_processors.py b/modelscope/preprocessors/nlp/space/dst_processors.py new file mode 100644 index 00000000..1f9920a9 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/dst_processors.py @@ -0,0 +1,1441 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +import json +import numpy as np +import six +from tqdm import tqdm + +logger = logging.getLogger(__name__) +USER_NAME = 'User' +SYSTEM_NAME = 'System' +DIALOG_ACT = 'Dialog_Act' + + +class DSTProcessor(object): + ACTS_DICT = { + 'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time', + } + + LABEL_MAPS = {} # Loaded from file + + def __init__(self): + # Required for mapping slot names in dialogue_acts.json file + # to proper designations. + pass + + def _convert_inputs_to_utterances(self, inputs: dict, + history_states: list): + """This method is to generate the utterances with user, sys, dialog_acts and metadata, + while metadata is from the history_states or the output from the inference pipline""" + + utterances = [] + user_inputs = [] + sys_gen_inputs = [] + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == USER_NAME: + user_inputs.insert(int(turn) - 1, inputs[item]) + elif name == SYSTEM_NAME: + sys_gen_inputs.insert(int(turn) - 1, inputs[item]) + else: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + + # user is leading the topic should aways larger than sys and dialog acts + assert len(user_inputs) - 1 == len(sys_gen_inputs) + assert len(user_inputs) - 1 == len(dialog_acts_inputs) + # the history states record both user and sys states + assert len(history_states) == len(user_inputs) + len(sys_gen_inputs) + + # the dialog_act at user turn is useless + for i, item in enumerate(history_states): + utterance = {} + # the dialog_act at user turn is useless + utterance['dialog_act'] = dialog_acts_inputs[ + i // 2] if i % 2 == 1 else {} + utterance['text'] = sys_gen_inputs[ + i // 2] if i % 2 == 1 else user_inputs[i // 2] + utterance['metadata'] = item + utterance['span_info'] = [] + utterances.append(utterance) + + return utterances + + def _load_acts(self, inputs: dict, dialog_id='example.json'): + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == DIALOG_ACT: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + s_dict = {} + + for j, item in enumerate(dialog_acts_inputs): + if isinstance(item, dict): + for a in item: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or \ + aa[1] == 'select' or aa[1] == 'book': + for i in item[a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = dialog_id, str(int(j) + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + + return s_dict + + +class multiwoz22Processor(DSTProcessor): + + def __init__(self): + super().__init__() + + def normalize_time(self, text): + text = re.sub(r'(\d{1})(a\.?m\.?|p\.?m\.?)', r'\1 \2', + text) # am/pm without space + text = re.sub(r'(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)', r'\1\2:00 \3', + text) # am/pm short to long form + text = re.sub( + r'(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)', + r'\1\2 \3:\4\5', text) # Missing separator + text = re.sub(r'(^| )(\d{2})[;.,](\d{2})', r'\1\2:\3', + text) # Wrong separator + text = re.sub(r'(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)', + r'\1\2 \3:00\4', text) # normalize simple full hour time + text = re.sub(r'(^| )(\d{1}:\d{2})', r'\g<1>0\2', + text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = \ + re.sub( + r'(\d{2})(:\d{2}) ?p\.?m\.?', + lambda x: str(int(x.groups()[0]) + 12 + if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub(r'(^| )24:(\d{2})', r'\g<1>00:\2', + text) # Correct times that use 24 as hour + return text + + def normalize_text(self, text): + text = self.normalize_time(text) + text = re.sub("n't", ' not', text) + text = re.sub('(^| )zero(-| )star([s.,? ]|$)', r'\g<1>0 star\3', text) + text = re.sub('(^| )one(-| )star([s.,? ]|$)', r'\g<1>1 star\3', text) + text = re.sub('(^| )two(-| )star([s.,? ]|$)', r'\g<1>2 star\3', text) + text = re.sub('(^| )three(-| )star([s.,? ]|$)', r'\g<1>3 star\3', text) + text = re.sub('(^| )four(-| )star([s.,? ]|$)', r'\g<1>4 star\3', text) + text = re.sub('(^| )five(-| )star([s.,? ]|$)', r'\g<1>5 star\3', text) + text = re.sub('archaelogy', 'archaeology', text) # Systematic typo + text = re.sub('guesthouse', 'guest house', text) # Normalization + text = re.sub('(^| )b ?& ?b([.,? ]|$)', r'\1bed and breakfast\2', + text) # Normalization + text = re.sub('bed & breakfast', 'bed and breakfast', + text) # Normalization + return text + + # Loads the dialogue_acts.json and returns a list + # of slot-value pairs. + def load_acts(self, input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + if int(t) % 2 == 0: + continue + # Only process, if turn has annotation + if isinstance(acts[d][t]['dialog_act'], dict): + for a in acts[d][t]['dialog_act']: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' \ + or aa[1] == 'select' or aa[1] == 'book': + for i in acts[d][t]['dialog_act'][a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = d, str(int(t) // 2 + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + return s_dict + + # This should only contain label normalizations. All other mappings should + # be defined in LABEL_MAPS. + def normalize_label(self, slot, value_label): + # Normalization of empty slots + if value_label == '' or value_label == 'not mentioned': + return 'none' + + # Normalization of time slots + if 'leaveAt' in slot or 'arriveBy' in slot or slot == 'restaurant-book_time': + return self.normalize_time(value_label) + + # Normalization + if 'type' in slot or 'name' in slot or 'destination' in slot or 'departure' in slot: + value_label = re.sub('guesthouse', 'guest house', value_label) + + # Map to boolean slots + if slot == 'hotel-parking' or slot == 'hotel-internet': + if value_label == 'yes' or value_label == 'free': + return 'true' + if value_label == 'no': + return 'false' + if slot == 'hotel-type': + if value_label == 'hotel': + return 'true' + if value_label == 'guest house': + return 'false' + + return value_label + + def tokenize(self, utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = self.normalize_text(utt_lower) + utt_tok = [ + tok for tok in map(str.strip, re.split(r'(\W+)', utt_lower)) + if len(tok) > 0 + ] + return utt_tok + + def delex_utt(self, utt, values, unk_token='[UNK]'): + utt_norm = self.tokenize(utt) + for s, vals in values.items(): + for v in vals: + if v != 'none': + v_norm = self.tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value_label)) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def check_label_existence(self, value_label, usr_utt_tok): + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, + value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + def check_slot_referral(self, value_label, slot, seen_slots): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. + if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match('(hotel|restaurant)-book_people', + s) and slot == 'hotel-book_stay': + continue + if re.match('(hotel|restaurant)-book_people', + slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots + or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + def is_in_list(self, tok, value): + found = False + tok_list = [ + item for item in map(str.strip, re.split(r'(\W+)', tok)) + if len(item) > 0 + ] + value_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value)) + if len(item) > 0 + ] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + # Fuzzy matching to label informed slot values + def check_slot_inform(self, value_label, inform_label): + result = False + informed_value = 'none' + vl = ' '.join(self.tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif self.is_in_list(il, vl): + result = True + elif self.is_in_list(vl, il): + result = True + elif il in self.LABEL_MAPS: + for il_variant in self.LABEL_MAPS[il]: + if vl == il_variant: + result = True + break + elif self.is_in_list(il_variant, vl): + result = True + break + elif self.is_in_list(vl, il_variant): + result = True + break + elif vl in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[vl]: + if value_label_variant == il: + result = True + break + elif self.is_in_list(il, value_label_variant): + result = True + break + elif self.is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + def get_turn_label(self, value_label, inform_label, sys_utt_tok, + usr_utt_tok, slot, seen_slots, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = self.check_label_existence( + value_label, usr_utt_tok) + is_informed, informed_value = self.check_slot_inform( + value_label, inform_label) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' + else: + referred_slot = self.check_slot_referral( + value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + def _create_example(self, + utterances, + sys_inform_dict, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='example.json'): + + # Collects all slot changes throughout the dialog + # cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + utt_tok_list = [[]] + mod_slots_list = [] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + + inform_dict = {slot: 'none' for slot in slot_list} + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print( + 'WARN: Wrong order of system and user utterances. Skipping rest of the dialog %s' + % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), + slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), + str(turn_itr), + slot)] + utt_tok_list.append( + self.delex_utt(utt['text'], inform_dict, + unk_token)) # normalize utterances + else: + utt_tok_list.append(self.tokenize( + utt['text'])) # normalize utterances + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + hst_utt_tok = [] + hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + + ###### + mod_slots_list = [] + ##### + + for i in range(0, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + # inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + sys_utt_tok = utt_tok_list[i] + usr_utt_tok = utt_tok_list[i + 1] + turn_slots = mod_slots_list[ + i + 1] if len(mod_slots_list) > 1 else {} + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print('%15s %2s %s ||| %s' % + (dialog_id, turn_itr, ' '.join(sys_utt_tok), + ' '.join(usr_utt_tok))) + print('%15s %2s [' % (dialog_id, turn_itr), end='') + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. + value_dict[slot] = value_label + elif label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), slot)] + ]) + inform_slot_dict[slot] = 1 + elif (str(dialog_id), str(turn_itr), + 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), 'booking-' + + slot.split('-')[1])] + ]) + inform_slot_dict[slot] = 1 + + (informed_value, referred_slot, usr_utt_tok_label, + class_type) = self.get_turn_label( + value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True) + + # inform_dict[slot] = informed_value + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list( + diag_seen_slots_value_dict.values()).count( + value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[ + slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[ + slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[slot] = [ + 0 for _ in sys_utt_tok_label + usr_utt_tok_label + + new_hst_utt_tok_label_dict[slot] + ] + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ + slot]: + print('(%s): %s, ' % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] \ + and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print(']') + + if swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + """ + text_a: dialog text + text_b: dialog text + history: dialog text + text_a_label: label,ignore during inference,turns to start/end pos + text_b_label: label,ignore during inference,turns to start/end pos + history_label: label,ignore during inference,turns to start/end pos + values: ignore during inference + inform_label: ignore during inference + inform_slot_label: input, system dialog action + refer_label: label,ignore during inference,turns to start/end pos refer_id + diag_state: input, history dialog state + class_label: label,ignore during inference,turns to start/end pos class_label_id + """ + example = DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict) + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + turn_itr += 1 + return example + + def create_example(self, + inputs, + history_states, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='0'): + utterances = self._convert_inputs_to_utterances(inputs, history_states) + sys_inform_dict = self._load_acts(inputs) + self.LABEL_MAPS = label_maps + example = self._create_example(utterances, sys_inform_dict, set_type, + slot_list, label_maps, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + + return example + + def create_examples(self, + input_file, + acts_file, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = self.load_acts(acts_file) + + with open(input_file, 'r', encoding='utf-8') as reader: + input_data = json.load(reader) + + self.LABEL_MAPS = label_maps + + examples = [] + for dialog_id in tqdm(input_data): + entry = input_data[dialog_id] + utterances = entry['log'] + + example = self._create_example( + utterances, sys_inform_dict, set_type, slot_list, label_maps, + append_history, use_history_labels, swap_utterances, + label_value_repetitions, delexicalize_sys_utts, unk_token, + analyze) + examples.append(example) + + return examples + + +class DSTExample(object): + """ + A single training/test example for the DST dataset. + """ + + def __init__(self, + guid, + text_a, + text_b, + history, + text_a_label=None, + text_b_label=None, + history_label=None, + values=None, + inform_label=None, + inform_slot_label=None, + refer_label=None, + diag_state=None, + class_label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.history = history + self.text_a_label = text_a_label + self.text_b_label = text_b_label + self.history_label = history_label + self.values = values + self.inform_label = inform_label + self.inform_slot_label = inform_slot_label + self.refer_label = refer_label + self.diag_state = diag_state + self.class_label = class_label + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s_dict = dict() + s_dict['guid'] = self.guid + s_dict['text_a'] = self.text_a + s_dict['text_b'] = self.text_b + s_dict['history'] = self.history + if self.text_a_label: + s_dict['text_a_label'] = self.text_a_label + if self.text_b_label: + s_dict['text_b_label'] = self.text_b_label + if self.history_label: + s_dict['history_label'] = self.history_label + if self.values: + s_dict['values'] = self.values + if self.inform_label: + s_dict['inform_label'] = self.inform_label + if self.inform_slot_label: + s_dict['inform_slot_label'] = self.inform_slot_label + if self.refer_label: + s_dict['refer_label'] = self.refer_label + if self.diag_state: + s_dict['diag_state'] = self.diag_state + if self.class_label: + s_dict['class_label'] = self.class_label + + s = json.dumps(s_dict) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + input_ids, + input_ids_unmasked, + input_mask, + segment_ids, + start_pos=None, + end_pos=None, + values=None, + inform=None, + inform_slot=None, + refer_id=None, + diag_state=None, + class_label_id=None, + guid='NONE'): + self.guid = guid + self.input_ids = input_ids + self.input_ids_unmasked = input_ids_unmasked + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_pos = start_pos + self.end_pos = end_pos + self.values = values + self.inform = inform + self.inform_slot = inform_slot + self.refer_id = refer_id + self.diag_state = diag_state + self.class_label_id = class_label_id + + +def convert_examples_to_features(examples, + slot_list, + class_types, + model_type, + tokenizer, + max_seq_length, + slot_value_dropout=0.0): + """Loads a data file into a list of `InputBatch`s.""" + + if model_type == 'bert': + model_specs = { + 'MODEL_TYPE': 'bert', + 'CLS_TOKEN': '[CLS]', + 'UNK_TOKEN': '[UNK]', + 'SEP_TOKEN': '[SEP]', + 'TOKEN_CORRECTION': 4 + } + else: + logger.error('Unknown model type (%s). Aborting.' % (model_type)) + exit(1) + + def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, + model_specs, slot_value_dropout): + joint_text_label = [0 for _ in text_label_dict[slot] + ] # joint all slots' label + for slot_text_label in text_label_dict.values(): + for idx, label in enumerate(slot_text_label): + if label == 1: + joint_text_label[idx] = 1 + + text_label = text_label_dict[slot] + tokens = [] + tokens_unmasked = [] + token_labels = [] + for token, token_label, joint_label in zip(text, text_label, + joint_text_label): + token = convert_to_unicode(token) + sub_tokens = tokenizer.tokenize(token) # Most time intensive step + tokens_unmasked.extend(sub_tokens) + if slot_value_dropout == 0.0 or joint_label == 0: + tokens.extend(sub_tokens) + else: + rn_list = np.random.random_sample((len(sub_tokens), )) + for rn, sub_token in zip(rn_list, sub_tokens): + if rn > slot_value_dropout: + tokens.append(sub_token) + else: + tokens.append(model_specs['UNK_TOKEN']) + token_labels.extend([token_label for _ in sub_tokens]) + assert len(tokens) == len(token_labels) + assert len(tokens_unmasked) == len(token_labels) + return tokens, tokens_unmasked, token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + len(history) + if total_length <= max_length: + break + if len(history) > 0: + history.pop() + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, + model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + if len(tokens_a) + len(tokens_b) + len( + history) > max_seq_length - model_specs['TOKEN_CORRECTION']: + # logger.info('Truncate Example %s. Total len=%d.' % + # (guid, len(tokens_a) + len(tokens_b) + len(history))) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, + max_seq_length - model_specs['TOKEN_CORRECTION']) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, + token_labels_history, max_seq_length, + model_specs): + token_label_ids = [] + token_label_ids.append(0) # [CLS] + for token_label in token_labels_a: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_b: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_history: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + while len(token_label_ids) < max_seq_length: + token_label_ids.append(0) # padding + assert len(token_label_ids) == max_seq_length + return token_label_ids + + def _get_start_end_pos(class_type, token_label_ids, max_seq_length): + if class_type == 'copy_value' and 1 not in token_label_ids: + class_type = 'none' + start_pos = 0 + end_pos = 0 + if 1 in token_label_ids: + start_pos = token_label_ids.index(1) + # Parsing is supposed to find only first location of wanted value + if 0 not in token_label_ids[start_pos:]: + end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 + else: + end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 + for i in range(max_seq_length): + if i >= start_pos and i <= end_pos: + assert token_label_ids[i] == 1 + return class_type, start_pos, end_pos + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, + tokenizer, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append(model_specs['CLS_TOKEN']) + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + for token in history: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + return tokens, input_ids, input_mask, segment_ids + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = ['none'] + slot_list + + features = [] + # Convert single example + for (example_index, example) in enumerate(examples): + + total_cnt += 1 + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + end_pos_dict = {} + for slot in slot_list: + tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( + example.text_a, example.text_a_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( + example.text_b, example.text_b_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( + example.history, example.history_label, slot, tokenizer, + model_specs, slot_value_dropout) + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, + model_specs, example.guid) + + if input_text_too_long: + + token_labels_a = token_labels_a[:len(tokens_a)] + token_labels_b = token_labels_b[:len(tokens_b)] + token_labels_history = token_labels_history[:len(tokens_history + )] + tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] + tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] + tokens_history_unmasked = tokens_history_unmasked[:len( + tokens_history)] + + assert len(token_labels_a) == len(tokens_a) + assert len(token_labels_b) == len(tokens_b) + assert len(token_labels_history) == len(tokens_history) + assert len(token_labels_a) == len(tokens_a_unmasked) + assert len(token_labels_b) == len(tokens_b_unmasked) + assert len(token_labels_history) == len(tokens_history_unmasked) + token_label_ids = _get_token_label_ids(token_labels_a, + token_labels_b, + token_labels_history, + max_seq_length, model_specs) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + class_label_mod, start_pos_dict[slot], end_pos_dict[ + slot] = _get_start_end_pos(example.class_label[slot], + token_label_ids, max_seq_length) + if class_label_mod != example.class_label[slot]: + example.class_label[slot] = class_label_mod + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_types.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_types.index( + example.class_label[slot]) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids = _get_transformer_input( + tokens_a, tokens_b, tokens_history, max_seq_length, tokenizer, + model_specs) + if slot_value_dropout > 0.0: + _, input_ids_unmasked, _, _ = _get_transformer_input( + tokens_a_unmasked, tokens_b_unmasked, tokens_history_unmasked, + max_seq_length, tokenizer, model_specs) + else: + input_ids_unmasked = input_ids + + assert (len(input_ids) == len(input_ids_unmasked)) + + features.append( + InputFeatures( + guid=example.guid, + input_ids=input_ids, + input_ids_unmasked=input_ids_unmasked, + input_mask=input_mask, + segment_ids=segment_ids, + start_pos=start_pos_dict, + end_pos=end_pos_dict, + values=value_dict, + inform=inform_dict, + inform_slot=inform_slot_dict, + refer_id=refer_id_dict, + diag_state=diag_state_dict, + class_label_id=class_label_id_dict)) + + return features + + +# From bert.tokenization (TF code) +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode('utf-8', 'ignore') + elif isinstance(text, unicode): + return text + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +if __name__ == '__main__': + processor = multiwoz22Processor() + set_type = 'test' + slot_list = [ + 'taxi-leaveAt', 'taxi-destination', 'taxi-departure', 'taxi-arriveBy', + 'restaurant-book_people', 'restaurant-book_day', + 'restaurant-book_time', 'restaurant-food', 'restaurant-pricerange', + 'restaurant-name', 'restaurant-area', 'hotel-book_people', + 'hotel-book_day', 'hotel-book_stay', 'hotel-name', 'hotel-area', + 'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-internet', + 'hotel-type', 'attraction-type', 'attraction-name', 'attraction-area', + 'train-book_people', 'train-leaveAt', 'train-destination', 'train-day', + 'train-arriveBy', 'train-departure' + ] + append_history = True + use_history_labels = True + swap_utterances = True + label_value_repetitions = True + delexicalize_sys_utts = True, + unk_token = '[UNK]' + analyze = False + + utter1 = { + 'User-1': + 'am looking for a place to to stay that has cheap price range it should be in a type of hotel' + } + history_states1 = [ + {}, + ] + utter2 = { + 'User-1': + 'am looking for a place to to stay that has cheap price range it should be in a type of hotel', + 'System-1': + 'Okay, do you have a specific area you want to stay in?', + 'Dialog_Act-1': { + 'Hotel-Request': [['Area', '?']] + }, + 'User-2': + 'no, i just need to make sure it\'s cheap. oh, and i need parking', + } + + history_states2 = [{}, { + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': '' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'time': '' + }, + 'semi': { + 'food': '', + 'pricerange': '', + 'name': '', + 'area': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'stay': '' + }, + 'semi': { + 'name': 'not mentioned', + 'area': 'not mentioned', + 'parking': 'not mentioned', + 'pricerange': 'cheap', + 'stars': 'not mentioned', + 'internet': 'not mentioned', + 'type': 'hotel' + } + }, + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'type': '', + 'name': '', + 'area': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '', + 'departure': '' + } + } + }, {}] + + utter3 = { + 'User-1': + 'am looking for a place to to stay that has cheap price range it should be in a type of hotel', + 'System-1': 'Okay, do you have a specific area you want to stay in?', + 'Dialog_Act-1': { + 'Hotel-Request': [['Area', '?']] + }, + 'User-2': + 'no, i just need to make sure it\'s cheap. oh, and i need parking', + 'System-2': + 'I found 1 cheap hotel for you that includes parking. Do you like me to book it?', + 'Dialog_Act-2': { + 'Booking-Inform': [['none', 'none']], + 'Hotel-Inform': [['Price', 'cheap'], ['Choice', '1'], + ['Parking', 'none']] + }, + 'User-3': 'Yes, please. 6 people 3 nights starting on tuesday.' + } + + history_states3 = [{}, { + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': '' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'time': '' + }, + 'semi': { + 'food': '', + 'pricerange': '', + 'name': '', + 'area': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'stay': '' + }, + 'semi': { + 'name': 'not mentioned', + 'area': 'not mentioned', + 'parking': 'not mentioned', + 'pricerange': 'cheap', + 'stars': 'not mentioned', + 'internet': 'not mentioned', + 'type': 'hotel' + } + }, + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'type': '', + 'name': '', + 'area': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '', + 'departure': '' + } + } + }, {}, { + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': '' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'time': '' + }, + 'semi': { + 'food': '', + 'pricerange': '', + 'name': '', + 'area': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'stay': '' + }, + 'semi': { + 'name': 'not mentioned', + 'area': 'not mentioned', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': 'not mentioned', + 'internet': 'not mentioned', + 'type': 'hotel' + } + }, + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'type': '', + 'name': '', + 'area': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '', + 'departure': '' + } + } + }, {}] + + example = processor.create_example(utter2, history_states2, set_type, + slot_list, {}, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + print(f'utterances is {example}') diff --git a/modelscope/preprocessors/nlp/space/fields/__init__.py b/modelscope/preprocessors/nlp/space/fields/__init__.py new file mode 100644 index 00000000..475a99dc --- /dev/null +++ b/modelscope/preprocessors/nlp/space/fields/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .gen_field import MultiWOZBPETextField + from .intent_field import IntentBPETextField +else: + _import_structure = { + 'gen_field': ['MultiWOZBPETextField'], + 'intent_field': ['IntentBPETextField'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/space/fields/gen_field.py b/modelscope/preprocessors/nlp/space/fields/gen_field.py new file mode 100644 index 00000000..1d1879fe --- /dev/null +++ b/modelscope/preprocessors/nlp/space/fields/gen_field.py @@ -0,0 +1,889 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import random +from asyncio import constants +from collections import OrderedDict +from itertools import chain + +import json +import numpy as np + +from modelscope.preprocessors.nlp.space.tokenizer import Tokenizer +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.nlp.space import ontology, utils +from modelscope.utils.nlp.space.db_ops import MultiWozDB +from modelscope.utils.nlp.space.utils import list2np + +logger = get_logger() + + +class BPETextField(object): + + pad_token = '[PAD]' + bos_token = '[BOS]' + eos_token = '[EOS]' + unk_token = '[UNK]' + sos_u_token = '' + eos_u_token = '' + sos_b_token = '' + eos_b_token = '' + sos_d_token = '' + eos_d_token = '' + sos_a_token = '' + eos_a_token = '' + sos_db_token = '' + eos_db_token = '' + sos_r_token = '' + eos_r_token = '' + + @property + def bot_id(self): + return 0 + + @property + def user_id(self): + return 1 + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def num_specials(self): + return len(self.tokenizer.special_tokens) + + @property + def pad_id(self): + return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] + + @property + def bos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] + + @property + def eos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] + + @property + def unk_id(self): + return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] + + @property + def sos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] + + @property + def eos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] + + @property + def sos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] + + @property + def eos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] + + @property + def sos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] + + @property + def eos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] + + @property + def sos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] + + @property + def eos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] + + @property + def sos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] + + @property + def eos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] + + @property + def sos_d_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_d_token])[0] + + @property + def eos_d_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] + + def __init__(self, config): + self.train, self.dev, self.test = [], [], [] + self.gpu = config.Trainer.gpu + self.tokenizer = None + self.vocab = None + self.db = None + self.set_stats = {} + + self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand + self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy + self.understand_tokens = ontology.get_understand_tokens( + self.prompt_num_for_understand) + self.policy_tokens = ontology.get_policy_tokens( + self.prompt_num_for_policy) + + self.with_query_bow = config.BPETextField.with_query_bow + self.understand = config.BPETextField.understand + self.policy = config.BPETextField.policy + + self.batch_size = config.Trainer.batch_size + self.filtered = config.BPETextField.filtered + self.max_len = config.BPETextField.max_len + self.min_utt_len = config.BPETextField.min_utt_len + self.max_utt_len = config.BPETextField.max_utt_len + self.min_ctx_turn = config.BPETextField.min_ctx_turn + self.max_ctx_turn = config.BPETextField.max_ctx_turn - 1 # subtract reply turn + + self.use_true_prev_bspn = config.Generator.use_true_prev_bspn + self.use_true_prev_aspn = config.Generator.use_true_prev_aspn + self.use_true_db_pointer = config.Generator.use_true_db_pointer + self.use_true_prev_resp = config.Generator.use_true_prev_resp + self.use_true_curr_bspn = config.Generator.use_true_curr_bspn + self.use_true_curr_aspn = config.Generator.use_true_curr_aspn + self.use_all_previous_context = config.Generator.use_all_previous_context + self.use_true_bspn_for_ctr_eval = config.Generator.use_true_bspn_for_ctr_eval + self.use_true_domain_for_ctr_eval = config.Generator.use_true_domain_for_ctr_eval + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + src = [sp['src'][-self.max_ctx_turn:] for sp in samples] + query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] + for utts in src: + query_token.append(utts[-1]) + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + role = [ + [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l + for i, l in enumerate(utt_lens) + ] + src_role.append(list(chain(*role))[-self.max_len:]) + + # src sequence and tgt sequence should be padded separately,to make sure the first word is aligned + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype('int64') + + if self.with_query_bow: + query_token = list2np(query_token, padding=self.pad_id) + batch['query_token'] = query_token + batch['query_mask'] = (query_token != self.pad_id).astype('int64') + + if self.understand_ids and self.understand: + understand = [self.understand_ids for _ in samples] + understand_token = np.array(understand).astype('int64') + batch['understand_token'] = understand_token + batch['understand_mask'] = \ + (understand_token != self.pad_id).astype('int64') + + if self.policy_ids and self.policy: + policy = [self.policy_ids for _ in samples] + policy_token = np.array(policy).astype('int64') + batch['policy_token'] = policy_token + batch['policy_mask'] = \ + (policy_token != self.pad_id).astype('int64') + + if 'tgt' in samples[0]: + tgt = [sp['tgt'] for sp in samples] + + # Token ids & Label ids + tgt_token = list2np(tgt, padding=self.pad_id) + + # Position ids + tgt_pos = np.zeros_like(tgt_token) + tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) + + # Turn ids + tgt_turn = np.zeros_like(tgt_token) + + # Role ids + tgt_role = np.full_like(tgt_token, self.bot_id) + + batch['tgt_token'] = tgt_token + batch['tgt_pos'] = tgt_pos + batch['tgt_type'] = tgt_role + batch['tgt_turn'] = tgt_turn + batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64') + + return batch, batch_size + + def _bucket_by_turn(self, encoded_data): + turn_bucket = {} + for dial in encoded_data: + turn_len = len(dial) + if turn_len not in turn_bucket: + turn_bucket[turn_len] = [] + turn_bucket[turn_len].append(dial) + return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0])) + + def _construct_mini_batch(self, data): + all_batches = [] + batch = [] + for dial in data: + batch.append(dial) + if len(batch) == self.batch_size: + all_batches.append(batch) + batch = [] + + # TODO deal with deleted data + if self.gpu <= 1: + if len(batch) > 0.5 * self.batch_size: + all_batches.append(batch) + elif len(all_batches): + all_batches[-1].extend(batch) + else: + all_batches.append(batch) + + return all_batches + + def transpose_batch(self, batch): + dial_batch = [] + turn_num = len(batch[0]) + for turn in range(turn_num): + turn_l = {} + for dial in batch: + this_turn = dial[turn] + for k in this_turn: + if k not in turn_l: + turn_l[k] = [] + turn_l[k].append(this_turn[k]) + dial_batch.append(turn_l) + return dial_batch + + def get_eval_data(self, set_name='dev'): + name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} + dial = name_to_set[set_name] + + if set_name not in self.set_stats: + self.set_stats[set_name] = {} + num_turns = 0 + num_dials = len(dial) + for d in dial: + num_turns += len(d) + + self.set_stats[set_name]['num_turns'] = num_turns + self.set_stats[set_name]['num_dials'] = num_dials + + return dial + + def get_nontranspose_data_iterator(self, all_batches): + for i, batch in enumerate(all_batches): + yield batch + + def get_data_iterator(self, all_batches): + for i, batch in enumerate(all_batches): + yield self.transpose_batch(batch) + + +class MultiWOZBPETextField(BPETextField): + + def __init__(self, config, **kwargs): + super(MultiWOZBPETextField, self).__init__(config) + + import spacy + try: + import en_core_web_sm + except ImportError: + logger.warn('Miss module en_core_web_sm!') + logger.warn('We will download en_core_web_sm automatically.') + try: + spacy.cli.download('en_core_web_sm') + except Exception as e: + logger.error(e) + raise ImportError( + 'Download en_core_web_sm error. ' + 'Please use \'python -m spacy download en_core_web_sm\' to download it by yourself!' + ) + self.nlp = spacy.load('en_core_web_sm') + + if config.do_train: + db_dir = kwargs['data_dir'] + else: + db_dir = kwargs['model_dir'] + self.db = MultiWozDB( + db_dir, { + 'attraction': 'db/attraction_db_processed.json', + 'hospital': 'db/hospital_db_processed.json', + 'hotel': 'db/hotel_db_processed.json', + 'police': 'db/police_db_processed.json', + 'restaurant': 'db/restaurant_db_processed.json', + 'taxi': 'db/taxi_db_processed.json', + 'train': 'db/train_db_processed.json', + }) + self._build_vocab(db_dir) + + special_tokens = [ + self.pad_token, self.bos_token, self.eos_token, self.unk_token + ] + special_tokens.extend(self.add_sepcial_tokens()) + self.tokenizer = Tokenizer( + vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE), + special_tokens=special_tokens, + tokenizer_type=config.BPETextField.tokenizer_type) + self.understand_ids = self.tokenizer.convert_tokens_to_ids( + self.understand_tokens) + self.policy_ids = self.tokenizer.convert_tokens_to_ids( + self.policy_tokens) + + if config.do_train: + test_list = [ + line.strip().lower() for line in open( + os.path.join(kwargs['data_dir'], 'testListFile.json'), + 'r').readlines() + ] + dev_list = [ + line.strip().lower() for line in open( + os.path.join(kwargs['data_dir'], 'valListFile.json'), + 'r').readlines() + ] + + self.dev_files, self.test_files = {}, {} + for fn in test_list: + self.test_files[fn.replace('.json', '')] = 1 + for fn in dev_list: + self.dev_files[fn.replace('.json', '')] = 1 + + self._load_data(kwargs['data_dir']) + + return + + def get_ids(self, data: str): + result = [self.sos_u_id] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(data))) + [self.eos_u_id] + return result + + def inverse_transpose_turn(self, turn_list): + """ + eval, one dialog at a time + """ + dialogs = {} + turn_num = len(turn_list) + dial_id = turn_list[0]['dial_id'] + dialogs[dial_id] = [] + for turn_idx in range(turn_num): + dial_turn = {} + turn = turn_list[turn_idx] + for key, value in turn.items(): + if key == 'dial_id': + continue + if key == 'pointer' and self.db is not None: + turn_domain = turn['turn_domain'][-1] + value = self.db.pointerBack(value, turn_domain) + dial_turn[key] = value + dialogs[dial_id].append(dial_turn) + return dialogs + + def inverse_transpose_batch(self, turn_batch_list): + """ + :param turn_batch_list: list of transpose dial batch + """ + dialogs = {} + total_turn_num = len(turn_batch_list) + # initialize + for idx_in_batch, dial_id in enumerate(turn_batch_list[0]['dial_id']): + dialogs[dial_id] = [] + for turn_n in range(total_turn_num): + dial_turn = {} + turn_batch = turn_batch_list[turn_n] + for key, v_list in turn_batch.items(): + if key == 'dial_id': + continue + value = v_list[idx_in_batch] + if key == 'pointer' and self.db is not None: + turn_domain = turn_batch['turn_domain'][idx_in_batch][ + -1] + value = self.db.pointerBack(value, turn_domain) + dial_turn[key] = value + dialogs[dial_id].append(dial_turn) + return dialogs + + def get_batches(self, set_name): + """ + compute dataset stats. + """ + global dia_count + log_str = '' + name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} + dial = name_to_set[set_name] + turn_bucket = self._bucket_by_turn(dial) + all_batches = [] + + if set_name not in self.set_stats: + self.set_stats[set_name] = {} + num_training_steps = 0 + num_turns = 0 + num_dials = 0 + + for k in turn_bucket: + if set_name != 'test' and k == 1 or k >= 17: + continue + batches = self._construct_mini_batch(turn_bucket[k]) + try: + log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( + k, len(turn_bucket[k]), len(batches), len(batches[-1])) + except Exception: + log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( + k, len(turn_bucket[k]), len(batches), 0.0) + + num_training_steps += k * len(batches) + num_turns += k * len(turn_bucket[k]) + num_dials += len(turn_bucket[k]) + all_batches += batches + log_str += 'total batch num: %d\n' % len(all_batches) + + self.set_stats[set_name][ + 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps + self.set_stats[set_name]['num_turns'] = num_turns + self.set_stats[set_name]['num_dials'] = num_dials + + if set_name == 'train': + random.shuffle(all_batches) + return all_batches + + def add_sepcial_tokens(self): + """ + add special tokens to gpt tokenizer + serves a similar role of Vocab.construt() + make a dict of special tokens + """ + special_tokens = [] + prompt_tokens = self.understand_tokens + self.policy_tokens + special_tokens.extend( + ontology.get_special_tokens(other_tokens=prompt_tokens)) + + for word in ontology.all_domains + ['general']: + word = '[' + word + ']' + special_tokens.append(word) + for word in ontology.all_acts: + word = '[' + word + ']' + special_tokens.append(word) + for word in self.vocab._word2idx.keys(): + if word.startswith('[value_') and word.endswith(']'): + special_tokens.append(word) + + return special_tokens + + def _build_vocab(self, model_dir: str): + self.vocab = utils.MultiWOZVocab(3000) + vp = os.path.join('{}/vocab'.format(model_dir)) + self.vocab.load_vocab(vp) + return self.vocab.vocab_size + + def _load_data(self, data_dir, save_temp=True): + """ + load processed data and encode, or load already encoded data + """ + + def load_data_from_resource(data_resource): + data = json.loads( + open( + os.path.join(data_dir, data_resource), + 'r', + encoding='utf-8').read().lower()) + train, dev, test = [], [], [] + for fn, dial in data.items(): + if '.json' in fn: + fn = fn.replace('.json', '') + if self.dev_files.get(fn): + dev.append(self._get_encoded_data(fn, dial)) + elif self.test_files.get(fn): + test.append(self._get_encoded_data(fn, dial)) + else: + train.append(self._get_encoded_data(fn, dial)) + return train, dev, test + + data_processed = 'new_db_se_blank_encoded_domain.data.json' + data_resource = 'data_for_damd.json' + if save_temp: # save encoded data + # encoded: no sos, se_encoded: sos and eos + encoded_file = os.path.join(data_dir, data_processed) + + if os.path.exists(encoded_file): + logger.info( + 'Reading encoded data from {}'.format(encoded_file)) + self.data = json.loads( + open( + os.path.join(data_dir, data_resource), + 'r', + encoding='utf-8').read().lower()) + encoded_data = json.loads( + open(encoded_file, 'r', encoding='utf-8').read()) + self.train = encoded_data['train'] + self.dev = encoded_data['dev'] + self.test = encoded_data['test'] + else: + logger.info( + 'Encoding data now and save the encoded data in {}'.format( + encoded_file)) + # not exists, encode data and save + self.train, self.dev, self.test = load_data_from_resource( + data_resource) + # save encoded data + encoded_data = { + 'train': self.train, + 'dev': self.dev, + 'test': self.test + } + json.dump(encoded_data, open(encoded_file, 'w'), indent=2) + else: # directly read processed data and encode + self.train, self.dev, self.test = load_data_from_resource( + data_resource) + + random.seed(10) + random.shuffle(self.train) + logger.info('train size:{}, dev size:{}, test size:{}'.format( + len(self.train), len(self.dev), len(self.test))) + + def _get_convert_str(self, sent): + assert isinstance(sent, str) + return ' '.join([ + self.tokenizer.spec_convert_dict.get(tok, tok) + for tok in sent.split() + ]) + + def _get_encoded_data(self, fn, dial): + encoded_dial = [] + for idx, t in enumerate(dial['log']): # tokenize to list of ids + enc = {} + enc['dial_id'] = fn + + enc_info_list = [ + ('user', self.sos_u_id, 'user', self.eos_u_id), + ('usdx', self.sos_u_id, 'user', self.eos_u_id), + ('resp', self.sos_r_id, 'resp', self.eos_r_id), + ('bspn', self.sos_b_id, 'constraint', self.eos_b_id), + ('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id), + ('aspn', self.sos_a_id, 'sys_act', self.eos_a_id) + ] + for enc_key, start_token, item_key, end_token in enc_info_list: + enc[enc_key] = [ + start_token + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(t[item_key]))) + [end_token] + + enc['turn_num'] = t['turn_num'] + + if idx > 0 and t['turn_domain'] == '[general]': + enc['dspn'] = encoded_dial[idx - 1]['dspn'] + enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [ + int(i) for i in t['pointer'].split(',') + ][-2:] + enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain'] + enc['db'] = encoded_dial[idx - 1]['db'] + else: + if t['turn_domain'] == '[general]': + assert not t['constraint'], f'{fn}-{idx}' + enc['dspn'] = [ + self.sos_d_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str( + t['turn_domain']))) + [self.eos_d_id] + enc['pointer'] = [int(i) for i in t['pointer'].split(',')] + enc['turn_domain'] = t['turn_domain'].split() + db_pointer = self.bspan_to_DBpointer(t['constraint'], + t['turn_domain'].split()) + enc['db'] = [ + self.sos_db_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(db_pointer))) + [self.eos_db_id] + + encoded_dial.append(enc) + return encoded_dial + + def bspan_to_DBpointer(self, bspan, turn_domain): + constraint_dict = self.bspan_to_constraint_dict(bspan) + matnums = self.db.get_match_num(constraint_dict) + match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] + match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom + match = matnums[match_dom] + + vector = self.db.addDBIndicator(match_dom, match) + return vector + + def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'): + """ + ['[hotel]', 'pricerange', 'cheap', 'type', 'hotel'] -> {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}} + """ + bspan = bspan.split() if isinstance(bspan, str) else bspan + constraint_dict = {} + domain = None + conslen = len(bspan) + for idx, cons in enumerate(bspan): + cons = self.vocab.decode(cons) if type(cons) is not str else cons + if cons == '': + break + if '[' in cons: + if cons[1:-1] not in ontology.all_domains: + continue + domain = cons[1:-1] + elif cons in ontology.get_slot: + if domain is None: + continue + if cons == 'people': + # handle confusion of value name "people's portraits..." and slot people + try: + ns = bspan[idx + 1] + ns = self.vocab.decode(ns) if type( + ns) is not str else ns + if ns == "'s": + continue + except Exception: + continue + if not constraint_dict.get(domain): + constraint_dict[domain] = {} + if bspn_mode == 'bsdx': + constraint_dict[domain][cons] = 1 + continue + vidx = idx + 1 + if vidx == conslen: + break + vt_collect = [] + vt = bspan[vidx] + vt = self.vocab.decode(vt) if type(vt) is not str else vt + while vidx < conslen and vt != '' and '[' not in vt and vt not in ontology.get_slot: + vt_collect.append(vt) + vidx += 1 + if vidx == conslen: + break + vt = bspan[vidx] + vt = self.vocab.decode(vt) if type(vt) is not str else vt + if vt_collect: + constraint_dict[domain][cons] = ' '.join(vt_collect) + + return constraint_dict + + def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): + """ + convert the current and the last turn + concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] + firts turn: [U_t, B_t, A_t, R_t] + try: [user, bspn, db, aspn, resp] + + """ + inputs = [] + if first_turn: + batch_zipped = zip(turn_batch['user'], turn_batch['bspn'], + turn_batch['db'], turn_batch['aspn'], + turn_batch['resp']) + for u, b, db, a, r in batch_zipped: + if self.use_true_curr_bspn: + src = [u + b + db] + tgt = a + r + else: + src = [u] + tgt = b + db + a + r + inputs.append({'src': src, 'tgt': tgt}) + pv = [src[-1], tgt] + pv_batch.append(pv) + else: + batch_zipped = zip(pv_batch, turn_batch['user'], + turn_batch['bspn'], turn_batch['db'], + turn_batch['aspn'], turn_batch['resp']) + for i, (pv, u, b, db, a, r) in enumerate(batch_zipped): + if self.use_true_curr_bspn: + src = pv + [u + b + db] + tgt = a + r + else: + src = pv + [u] + tgt = b + db + a + r + inputs.append({'src': src, 'tgt': tgt}) + pv = [src[-1], tgt] + pv_batch[i].extend(pv) + + return inputs, pv_batch + + def wrap_result_lm(self, result_dict, eos_syntax=None): + results = [] + eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax + sos_syntax = ontology.sos_tokens + # ground truth bs, as, ds.. generate response + field = [ + 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bsdx', 'resp_gen', + 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer', + 'qspn_gen', 'qspn' + ] + + for dial_id, turns in result_dict.items(): + entry = {'dial_id': dial_id, 'trun_num': len(turns)} + for f in field[2:]: + entry[f] = '' # TODO ??? + results.append(entry) + for turn_idx, turn in enumerate(turns): + entry = {'dial_id': dial_id} + for key in field: + if key in ['dial_id']: + continue + v = turn.get(key, '') + if key == 'turn_domain': + v = ' '.join(v) + + if key in eos_syntax and v != '': + # remove eos tokens + v = self.tokenizer.decode(v) + v = v.split() + # remove eos/sos in span + if eos_syntax[key] in v: + v.remove(eos_syntax[key]) + if sos_syntax[key] in v: + v.remove(sos_syntax[key]) + v = ' '.join(v) + else: + pass # v = v + entry[key] = v + + results.append(entry) + + return results, field + + def convert_turn_eval(self, turn, pv_turn, first_turn=False): + """ + input: [all previous ubar, U_t, B_t, A_t] predict R_t + firts turn: [U_t, B_t, A_t] predict R_t + + regarding the context, all previous ubar is too slow, try the previous ubar + """ + inputs = {} + + context_list = [] + prompt_id = None + if self.use_true_curr_bspn: + if self.use_true_curr_aspn: # only predict resp + context_list = ['user', 'bspn', 'db', 'aspn'] + prompt_id = self.sos_r_id + else: # predicted aspn + context_list = ['user', 'bspn', 'db'] + prompt_id = self.sos_a_id + else: # predict bspn aspn resp. db are not predicted. this part tbd. + context_list = ['user'] + prompt_id = self.sos_b_id + + if first_turn: + context = [] + for c in context_list: + context += turn[c] + + inputs['src'] = [context] + inputs['labels'] = [context] + else: + context = [] + for c in context_list: + context += turn[c] + + if self.use_true_curr_bspn: + pv_context = pv_turn['labels'] + [ + pv_turn['aspn'] + pv_turn['resp'] + ] + else: + pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[ + 'aspn'] + pv_turn['resp'] + pv_context = pv_turn['labels'] + [pv_info] + + # prompt response, add sos_r + inputs['src'] = pv_context + [context] + + if self.use_all_previous_context: + inputs['labels'] = pv_context + [ + context + ] # use all previous ubar history + else: + inputs['labels'] = [context] # use previous turn + + return inputs, prompt_id + + def restore(self, resp, domain, constraint_dict, mat_ents): + restored = resp + + restored = restored.replace('[value_reference]', '53022') + restored = restored.replace('[value_car]', 'BMW') + + for d in domain: + constraint = constraint_dict.get(d, None) + if constraint: + replace_res_list = [('stay', '[value_stay]'), + ('day', '[value_day]'), + ('people', '[value_people]'), + ('time', '[value_time]'), + ('type', '[value_type]')] + for key, value_key in replace_res_list: + if key in constraint: + restored = restored.replace(value_key, constraint[key]) + + if d in mat_ents and len(mat_ents[d]) == 0: + for s in constraint: + if s == 'pricerange' and d in [ + 'hotel', 'restaurant' + ] and 'price]' in restored: + restored = restored.replace( + '[value_price]', constraint['pricerange']) + if s + ']' in restored: + restored = restored.replace( + '[value_%s]' % s, constraint[s]) + + if '[value_choice' in restored and mat_ents.get(d): + restored = restored.replace('[value_choice]', + str(len(mat_ents[d]))) + if '[value_choice' in restored: + restored = restored.replace('[value_choice]', '3') + + try: + ent = mat_ents.get(domain[-1], []) + if ent: + ent = ent[0] + + for t in restored.split(): + if '[value' in t: + slot = t[7:-1] + if ent.get(slot): + if domain[-1] == 'hotel' and slot == 'price': + slot = 'pricerange' + restored = restored.replace(t, ent[slot]) + elif slot == 'price': + if ent.get('pricerange'): + restored = restored.replace( + t, ent['pricerange']) + else: + logger.info(restored, domain) + except Exception: + logger.error(resp) + logger.error(restored) + quit() + + restored = restored.replace('[value_phone]', '62781111') + restored = restored.replace('[value_postcode]', 'CG9566') + restored = restored.replace('[value_address]', 'Parkside, Cambridge') + + return restored diff --git a/modelscope/preprocessors/nlp/space/fields/intent_field.py b/modelscope/preprocessors/nlp/space/fields/intent_field.py new file mode 100644 index 00000000..29ea915e --- /dev/null +++ b/modelscope/preprocessors/nlp/space/fields/intent_field.py @@ -0,0 +1,1082 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import glob +import multiprocessing +import os +import random +import re +import time +from collections import defaultdict +from itertools import chain + +import json +import numpy as np +from tqdm import tqdm + +from modelscope.preprocessors.nlp.space.tokenizer import Tokenizer +from modelscope.utils.constant import ModelFile +from modelscope.utils.nlp.space import ontology +from modelscope.utils.nlp.space.scores import hierarchical_set_score +from modelscope.utils.nlp.space.utils import list2np + + +class BPETextField(object): + + pad_token = '[PAD]' + bos_token = '[BOS]' + eos_token = '[EOS]' + unk_token = '[UNK]' + mask_token = '[MASK]' + sos_u_token = '' + eos_u_token = '' + sos_b_token = '' + eos_b_token = '' + sos_db_token = '' + eos_db_token = '' + sos_a_token = '' + eos_a_token = '' + sos_r_token = '' + eos_r_token = '' + + def __init__(self, model_dir, config): + self.score_matrixs = {} + self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand + self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy + self.understand_tokens = ontology.get_understand_tokens( + self.prompt_num_for_understand) + self.policy_tokens = ontology.get_policy_tokens( + self.prompt_num_for_policy) + special_tokens = [ + self.pad_token, self.bos_token, self.eos_token, self.unk_token + ] + special_tokens.extend(self.add_sepcial_tokens()) + self.tokenizer = Tokenizer( + vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE), + special_tokens=special_tokens, + tokenizer_type=config.BPETextField.tokenizer_type) + self.understand_ids = self.numericalize(self.understand_tokens) + self.policy_ids = self.numericalize(self.policy_tokens) + + self.tokenizer_type = config.BPETextField.tokenizer_type + self.filtered = config.BPETextField.filtered + self.max_len = config.BPETextField.max_len + self.min_utt_len = config.BPETextField.min_utt_len + self.max_utt_len = config.BPETextField.max_utt_len + self.min_ctx_turn = config.BPETextField.min_ctx_turn + self.max_ctx_turn = config.BPETextField.max_ctx_turn + self.policy = config.BPETextField.policy + self.generation = config.BPETextField.generation + self.with_mlm = config.Dataset.with_mlm + self.with_query_bow = config.BPETextField.with_query_bow + self.with_contrastive = config.Dataset.with_contrastive + self.num_process = config.Dataset.num_process + self.dynamic_score = config.Dataset.dynamic_score + self.abandon_label = config.Dataset.abandon_label + self.trigger_role = config.Dataset.trigger_role + self.trigger_data = config.Dataset.trigger_data.split( + ',') if config.Dataset.trigger_data else [] + + # data_paths = list(os.path.dirname(c) for c in sorted( + # glob.glob(hparams.data_dir + '/**/' + f'train.{hparams.tokenizer_type}.jsonl', recursive=True))) + # self.data_paths = self.filter_data_path(data_paths=data_paths) + # self.labeled_data_paths = [data_path for data_path in self.data_paths if 'UniDA' in data_path] + # self.unlabeled_data_paths = [data_path for data_path in self.data_paths if 'UnDial' in data_path] + # assert len(self.unlabeled_data_paths) + len(self.labeled_data_paths) == len(self.data_paths) + # assert len(self.labeled_data_paths) or len(self.unlabeled_data_paths), 'No dataset is loaded' + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def num_specials(self): + return len(self.tokenizer.special_tokens) + + @property + def pad_id(self): + return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] + + @property + def bos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] + + @property + def eos_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] + + @property + def unk_id(self): + return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] + + @property + def mask_id(self): + return self.tokenizer.convert_tokens_to_ids([self.mask_token])[0] + + @property + def sos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] + + @property + def eos_u_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] + + @property + def sos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] + + @property + def eos_b_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] + + @property + def sos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] + + @property + def eos_db_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] + + @property + def sos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] + + @property + def eos_a_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] + + @property + def sos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] + + @property + def eos_r_id(self): + return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] + + @property + def bot_id(self): + return 0 + + @property + def user_id(self): + return 1 + + def add_sepcial_tokens(self): + prompt_tokens = self.understand_tokens + self.policy_tokens + return ontology.get_special_tokens(other_tokens=prompt_tokens) + + def filter_data_path(self, data_paths): + if self.trigger_data: + filtered_data_paths = [] + for data_path in data_paths: + for data_name in self.trigger_data: + if data_path.endswith(f'/{data_name}'): + filtered_data_paths.append(data_path) + break + else: + filtered_data_paths = data_paths + return filtered_data_paths + + def load_score_matrix(self, data_type, data_iter=None): + """ + load score matrix for all labeled datasets + """ + for data_path in self.labeled_data_paths: + file_index = os.path.join( + data_path, f'{data_type}.{self.tokenizer_type}.jsonl') + file = os.path.join(data_path, f'{data_type}.Score.npy') + if self.dynamic_score: + score_matrix = {} + print(f"Created 1 score cache dict for data in '{file_index}'") + else: + # TODO add post score matrix + assert os.path.exists(file), f"{file} isn't exist" + print(f"Loading 1 score matrix from '{file}' ...") + fp = np.memmap(file, dtype='float32', mode='r') + assert len(fp.shape) == 1 + num = int(np.sqrt(fp.shape[0])) + score_matrix = fp.reshape(num, num) + print(f"Loaded 1 score matrix for data in '{file_index}'") + self.score_matrixs[file_index] = score_matrix + + def random_word(self, chars): + output_label = [] + output_chars = [] + + for i, char in enumerate(chars): + # TODO delete this part to learn special tokens + if char in [ + self.sos_u_id, self.eos_u_id, self.sos_r_id, self.eos_r_id + ]: + output_chars.append(char) + output_label.append(self.pad_id) + continue + + prob = random.random() + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + output_chars.append(self.mask_id) + + # 10% randomly change token to random token + elif prob < 0.9: + tmp = random.randint(1, self.vocab_size - 1) + output_chars.append(tmp) # start from 1, to exclude pad_id + + # 10% randomly change token to current token + else: + output_chars.append(char) + + output_label.append(char) + + else: + output_chars.append(char) + output_label.append(self.pad_id) + + return output_chars, output_label + + def create_masked_lm_predictions(self, sample): + src = sample['src'] + src_span_mask = sample['src_span_mask'] + mlm_inputs = [] + mlm_labels = [] + for chars, chars_span_mask in zip(src, src_span_mask): + if sum(chars_span_mask): + mlm_input, mlm_label = [], [] + for char, char_mask in zip(chars, chars_span_mask): + if char_mask: + mlm_input.append(self.mask_id) + mlm_label.append(char) + else: + mlm_input.append(char) + mlm_label.append(self.pad_id) + else: + mlm_input, mlm_label = self.random_word(chars) + mlm_inputs.append(mlm_input) + mlm_labels.append(mlm_label) + + sample['mlm_inputs'] = mlm_inputs + sample['mlm_labels'] = mlm_labels + return sample + + def create_span_masked_lm_predictions(self, sample): + src = sample['src'] + src_span_mask = sample['src_span_mask'] + mlm_inputs = [] + mlm_labels = [] + for chars, chars_span_mask in zip(src, src_span_mask): + mlm_input, mlm_label = [], [] + for char, char_mask in zip(chars, chars_span_mask): + if char_mask: + mlm_input.append(self.mask_id) + mlm_label.append(char) + else: + mlm_input.append(char) + mlm_label.append(self.pad_id) + mlm_inputs.append(mlm_input) + mlm_labels.append(mlm_label) + + sample['mlm_inputs'] = mlm_inputs + sample['mlm_labels'] = mlm_labels + return sample + + def create_token_masked_lm_predictions(self, sample): + mlm_inputs = sample['mlm_inputs'] + mlm_labels = sample['mlm_labels'] + + for i, span_mlm_label in enumerate(mlm_labels): + if not sum(span_mlm_label): + mlm_input, mlm_label = self.random_word(mlm_inputs[i]) + mlm_inputs[i] = mlm_input + mlm_labels[i] = mlm_label + + return sample + + def numericalize(self, tokens): + """ + here only "convert_tokens_to_ids", + which need be tokenized into tokens(sub-words) by "tokenizer.tokenize" before + """ + assert isinstance(tokens, list) + if len(tokens) == 0: + return [] + element = tokens[0] + if isinstance(element, list): + return [self.numericalize(s) for s in tokens] + else: + return self.tokenizer.convert_tokens_to_ids(tokens) + + def denumericalize(self, numbers): + """ + here first "convert_ids_to_tokens", then combine sub-words into origin words + """ + assert isinstance(numbers, list) + if len(numbers) == 0: + return [] + element = numbers[0] + if isinstance(element, list): + return [self.denumericalize(x) for x in numbers] + else: + return self.tokenizer.decode( + numbers, + ignore_tokens=[self.bos_token, self.eos_token, self.pad_token]) + + def save_examples(self, examples, filename): + start = time.time() + if filename.endswith('npy'): + print(f"Saving 1 object to '{filename}' ...") + assert len( + examples.shape) == 2 and examples.shape[0] == examples.shape[1] + num = examples.shape[0] + fp = np.memmap( + filename, dtype='float32', mode='w+', shape=(num, num)) + fp[:] = examples[:] + fp.flush() + elapsed = time.time() - start + print(f'Saved 1 object (elapsed {elapsed:.2f}s)') + elif filename.endswith('jsonl'): + print(f"Saving examples to '{filename}' ...") + with open(filename, 'w', encoding='utf-8') as fp: + for ex in examples: + fp.write(json.dumps(ex) + '\n') + elapsed = time.time() - start + print(f'Saved {len(examples)} examples (elapsed {elapsed:.2f}s)') + else: + print(f"Saving examples to '{filename}' ...") + raise ValueError(f'Unsport file format: {filename}') + + def load_examples(self, filename): + start = time.time() + if filename.endswith('npy'): + print(f"Loading 1 object from '{filename}' ...") + fp = np.memmap(filename, dtype='float32', mode='r') + assert len(fp.shape) == 1 + num = int(np.sqrt(fp.shape[0])) + examples = fp.reshape(num, num) + elapsed = time.time() - start + print(f'Loaded 1 object (elapsed {elapsed:.2f}s)') + else: + print(f"Loading examples from '{filename}' ...") + with open(filename, 'r', encoding='utf-8') as fp: + examples = list(map(lambda s: json.loads(s.strip()), fp)) + elapsed = time.time() - start + print(f'Loaded {len(examples)} examples (elapsed {elapsed:.2f}s)') + return examples + + def utt_filter_pred(self, utt): + return self.min_utt_len <= len(utt) \ + and (not self.filtered or len(utt) <= self.max_utt_len) + + def utts_filter_pred(self, utts): + return self.min_ctx_turn <= len(utts) \ + and (not self.filtered or len(utts) <= self.max_ctx_turn) + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item + for item in map(str.strip, re.split('(\\W+)', value_label.lower())) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def build_score_matrix(self, examples): + """ + build symmetric score matrix + """ + assert self.num_process == 1 + print('Building score matrix from examples ...') + num = len(examples) + score_matrix = np.eye( + num, num, dtype='float32' + ) # in case of empty label of self, resulting in score 0. + + for i in tqdm(range(num)): + for j in range(i): + # TODO change the score method + score = hierarchical_set_score( + frame1=examples[i]['label'], frame2=examples[j]['label']) + score_matrix[i][j] = score + score_matrix[j][i] = score + + print('Built score matrix') + return score_matrix + + def build_score_matrix_on_the_fly(self, + ids, + labels, + data_file, + is_post=False): + """ + build symmetric score matrix on the fly + @is_post: True for resp label of sample i and j, False for query label of sample i and j + """ + num = len(labels) + tag = 'r' if is_post else 'q' + assert len(ids) == len(labels) + score_matrix = np.eye( + num, num, dtype='float32' + ) # in case of empty label of self, resulting in score 0. + + for i in range(num): + for j in range(i): + score = self.score_matrixs[data_file].get( + f'{ids[i]}-{ids[j]}-{tag}', None) + if score is None: + score = self.score_matrixs[data_file].get( + f'{ids[j]}-{ids[i]}-{tag}', None) + if score is None: + # TODO change the score method + score = hierarchical_set_score( + frame1=labels[i], frame2=labels[j]) + self.score_matrixs[data_file][ + f'{ids[i]}-{ids[j]}-{tag}'] = score + score_matrix[i][j] = score + score_matrix[j][i] = score + + return score_matrix + + def build_score_matrix_func(self, examples, start, exclusive_end): + """ + build sub score matrix + """ + num = len(examples) + process_id = os.getpid() + description = f'PID: {process_id} Start: {start} End: {exclusive_end}' + print( + f'PID-{process_id}: Building {start} to {exclusive_end} lines score matrix from examples ...' + ) + score_matrix = np.zeros((exclusive_end - start, num), dtype='float32') + + for abs_i, i in enumerate( + tqdm(range(start, exclusive_end), desc=description)): + for j in range(num): + # TODO change the score method + score = hierarchical_set_score( + frame1=examples[i]['label'], frame2=examples[j]['label']) + score_matrix[abs_i][j] = score + + print( + f'PID-{process_id}: Built {start} to {exclusive_end} lines score matrix' + ) + return {'start': start, 'score_matrix': score_matrix} + + def build_score_matrix_multiprocessing(self, examples): + """ + build score matrix + """ + assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2 + print('Building score matrix from examples ...') + results = [] + num = len(examples) + sub_num, res_num = num // self.num_process, num % self.num_process + patches = [sub_num] * (self.num_process - 1) + [sub_num + res_num] + + start = 0 + pool = multiprocessing.Pool(processes=self.num_process) + for patch in patches: + exclusive_end = start + patch + results.append( + pool.apply_async(self.build_score_matrix_func, + (examples, start, exclusive_end))) + start = exclusive_end + pool.close() + pool.join() + + sub_score_matrixs = [result.get() for result in results] + sub_score_matrixs = sorted( + sub_score_matrixs, key=lambda sub: sub['start']) + sub_score_matrixs = [ + sub_score_matrix['score_matrix'] + for sub_score_matrix in sub_score_matrixs + ] + score_matrix = np.concatenate(sub_score_matrixs, axis=0) + assert score_matrix.shape == (num, num) + np.fill_diagonal( + score_matrix, + 1.) # in case of empty label of self, resulting in score 0. + + print('Built score matrix') + return score_matrix + + def extract_span_texts(self, text, label): + span_texts = [] + for domain, frame in label.items(): + for act, slot_values in frame.items(): + for slot, values in slot_values.items(): + for value in values: + if value['span']: + span_texts.append( + text[value['span'][0]:value['span'][1]]) + elif str(value['value']).strip().lower() in text.strip( + ).lower(): + span_texts.append(str(value['value'])) + return span_texts + + def fix_label(self, label): + for domain, frame in label.items(): + if not frame: + return {} + for act, slot_values in frame.items(): + if act == 'DEFAULT_INTENT' and not slot_values: + return {} + return label + + def build_examples_multi_turn(self, data_file, data_type='train'): + print(f"Reading examples from '{data_file}' ...") + examples = [] + ignored = 0 + + with open(data_file, 'r', encoding='utf-8') as f: + input_data = json.load(f) + for dialog_id in tqdm(input_data): + turns = input_data[dialog_id]['turns'] + history, history_role, history_span_mask, history_label = [], [], [], [] + for t, turn in enumerate(turns): + label = turn['label'] + role = turn['role'] + text = turn['text'] + utterance, span_mask = [], [] + + token_list = [ + tok for tok in map(str.strip, + re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + span_texts = self.extract_span_texts( + text=text, label=label) + + for span_text in span_texts: + found, find_pos = self.get_token_pos( + tok_list=token_list, value_label=span_text) + if found: + for start, exclusive_end in find_pos: + span_list[start:exclusive_end] = 1 + + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) + assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + history_label.append(self.fix_label(label)) + + tmp = self.utts_filter_pred(history[:-1]) and all( + map(self.utt_filter_pred, history)) + if ( + tmp or data_type == 'test' + ) and role in self.trigger_role and t: # TODO consider test + src = [ + s[-self.max_utt_len:] + for s in history[:-1][-self.max_ctx_turn:] + ] + src_span_mask = [ + s[-self.max_utt_len:] for s in + history_span_mask[:-1][-self.max_ctx_turn:] + ] + roles = [ + role + for role in history_role[:-1][-self.max_ctx_turn:] + ] + + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id + ] + self.numericalize(s) + user_or_sys + tmp = tmp + self.numericalize(s) + [self.eos_r_id] + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + tgt = [self.sos_r_id] + self.numericalize( + history[-1]) + [self.eos_r_id] + if data_type != 'test': + tgt = tgt[:self.max_utt_len + 2] + + ex = { + 'dialog_id': dialog_id, + 'turn_id': turn['turn_id'], + 'src': new_src, + 'src_span_mask': src_span_mask, + 'tgt': tgt, + 'query_label': history_label[-2], + 'resp_label': history_label[-1], + 'extra_info': turn.get('extra_info', '') + } + examples.append(ex) + else: + ignored += 1 + + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + print( + f'Built {len(examples)} {data_type.upper()} examples ({ignored} filtered)' + ) + return examples + + def preprocessor(self, text_list): + role = 'user' + examples = [] + + for text in text_list: + history, history_role, history_span_mask = [], [], [] + utterance, span_mask = [], [] + token_list = [ + tok for tok in map(str.strip, re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) + assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + + src = [s[-self.max_utt_len:] for s in history[-self.max_ctx_turn:]] + src_span_mask = [ + s[-self.max_utt_len:] + for s in history_span_mask[-self.max_ctx_turn:] + ] + roles = [role for role in history_role[-self.max_ctx_turn:]] + + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id] + self.numericalize(s) + user_or_sys + tmp = tmp + self.numericalize(s) + [self.eos_r_id] + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + ex = { + 'dialog_id': 'inference', + 'turn_id': 0, + 'role': role, + 'src': new_src, + 'src_span_mask': src_span_mask, + 'query_label': { + 'DEFAULT_DOMAIN': { + 'card_arrival': {} + } + }, + 'extra_info': { + 'intent_label': -1 + } + } + examples.append(ex) + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + return examples + + def build_examples_single_turn(self, data_file, data_type='train'): + print(f"Reading examples from '{data_file}' ...") + examples = [] + ignored = 0 + + with open(data_file, 'r', encoding='utf-8') as f: + input_data = json.load(f) + for dialog_id in tqdm(input_data): + turns = input_data[dialog_id]['turns'] + history, history_role, history_span_mask = [], [], [] + for turn in turns: + label = turn['label'] + role = turn['role'] + text = turn['text'] + utterance, span_mask = [], [] + + token_list = [ + tok for tok in map(str.strip, + re.split('(\\W+)', text.lower())) + if len(tok) > 0 + ] + span_list = np.zeros(len(token_list), dtype=np.int32) + span_texts = self.extract_span_texts( + text=text, label=label) + + for span_text in span_texts: + found, find_pos = self.get_token_pos( + tok_list=token_list, value_label=span_text) + if found: + for start, exclusive_end in find_pos: + span_list[start:exclusive_end] = 1 + + token_list = [ + self.tokenizer.tokenize(token) for token in token_list + ] + span_list = [[tag] * len(token_list[i]) + for i, tag in enumerate(span_list)] + for sub_tokens in token_list: + utterance.extend(sub_tokens) + for sub_spans in span_list: + span_mask.extend(sub_spans) + assert len(utterance) == len(span_mask) + + history.append(utterance) + history_role.append(role) + history_span_mask.append(span_mask) + + tmp = self.utts_filter_pred(history) and all( + map(self.utt_filter_pred, history)) + tmp = tmp or data_type == 'test' + if tmp and role in self.trigger_role: # TODO consider test + src = [ + s[-self.max_utt_len:] + for s in history[-self.max_ctx_turn:] + ] + src_span_mask = [ + s[-self.max_utt_len:] + for s in history_span_mask[-self.max_ctx_turn:] + ] + roles = [ + role for role in history_role[-self.max_ctx_turn:] + ] + new_src = [] + for i, s in enumerate(src): + if roles[i] == 'user': + user_or_sys = [self.eos_u_id] + else: + user_or_sys = [self.sos_r_id] + tmp = [self.sos_u_id + ] + self.numericalize(s) + user_or_sys + new_src.append(tmp) + + src_span_mask = [[0] + list(map(int, s)) + [0] + for s in src_span_mask] + + ex = { + 'dialog_id': dialog_id, + 'turn_id': turn['turn_id'], + 'role': role, + 'src': new_src, + 'src_span_mask': src_span_mask, + 'query_label': self.fix_label(label), + 'extra_info': turn.get('extra_info', '') + } + examples.append(ex) + else: + ignored += 1 + + # add span mlm inputs and span mlm labels in advance + if self.with_mlm: + examples = [ + self.create_span_masked_lm_predictions(example) + for example in examples + ] + + # add absolute id of the dataset for indexing scores in its score matrix + for i, example in enumerate(examples): + example['id'] = i + + print( + f'Built {len(examples)} {data_type.upper()} examples ({ignored} filtered)' + ) + return examples + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + src = [sp['src'] for sp in samples] + query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] + for utts in src: + query_token.append(utts[-1]) + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + role = [ + [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l + for i, l in enumerate(utt_lens) + ] + src_role.append(list(chain(*role))[-self.max_len:]) + + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype('int64') + + if self.with_query_bow: + query_token = list2np(query_token, padding=self.pad_id) + batch['query_token'] = query_token + batch['query_mask'] = (query_token != self.pad_id).astype('int64') + + if self.with_mlm: + mlm_token, mlm_label = [], [] + raw_mlm_input = [sp['mlm_inputs'] for sp in samples] + raw_mlm_label = [sp['mlm_labels'] for sp in samples] + for inputs in raw_mlm_input: + mlm_token.append(list(chain(*inputs))[-self.max_len:]) + for labels in raw_mlm_label: + mlm_label.append(list(chain(*labels))[-self.max_len:]) + + mlm_token = list2np(mlm_token, padding=self.pad_id) + mlm_label = list2np(mlm_label, padding=self.pad_id) + batch['mlm_token'] = mlm_token + batch['mlm_label'] = mlm_label + batch['mlm_mask'] = (mlm_label != self.pad_id).astype('int64') + + if self.dynamic_score and self.with_contrastive and not self.abandon_label: + query_labels = [sp['query_label'] for sp in samples] + batch['query_labels'] = query_labels + if self.trigger_role == 'system': + resp_labels = [sp['resp_label'] for sp in samples] + batch['resp_labels'] = resp_labels + batch['label_ids'] = np.arange( + batch_size) # to identify labels for each GPU when multi-gpu + + if self.understand_ids: + understand = [self.understand_ids for _ in samples] + understand_token = np.array(understand).astype('int64') + batch['understand_token'] = understand_token + batch['understand_mask'] = \ + (understand_token != self.pad_id).astype('int64') + + if self.policy_ids and self.policy: + policy = [self.policy_ids for _ in samples] + policy_token = np.array(policy).astype('int64') + batch['policy_token'] = policy_token + batch['policy_mask'] = \ + (policy_token != self.pad_id).astype('int64') + + if 'tgt' in samples[0]: + tgt = [sp['tgt'] for sp in samples] + + # Token ids & Label ids + tgt_token = list2np(tgt, padding=self.pad_id) + + # Position ids + tgt_pos = np.zeros_like(tgt_token) + tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) + + # Turn ids + tgt_turn = np.zeros_like(tgt_token) + + # Role ids + tgt_role = np.full_like(tgt_token, self.bot_id) + + batch['tgt_token'] = tgt_token + batch['tgt_pos'] = tgt_pos + batch['tgt_type'] = tgt_role + batch['tgt_turn'] = tgt_turn + batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64') + + if 'id' in samples[0]: + ids = [sp['id'] for sp in samples] + ids = np.array(ids).astype('int64') + batch['ids'] = ids + + return batch, batch_size + + +class IntentBPETextField(BPETextField): + + def __init__(self, model_dir, config): + super(IntentBPETextField, self).__init__(model_dir, config) + + def retrieve_examples(self, + dataset, + labels, + inds, + task, + num=None, + cache=None): + assert task == 'intent', 'Example-driven may only be used with intent prediction' + if num is None and labels is not None: + num = len(labels) * 2 + + # Populate cache + if cache is None: + cache = defaultdict(list) + for i, example in enumerate(dataset): + assert i == example['id'] + cache[example['extra_info']['intent_label']].append(i) + + # One example for each label + example_inds = [] + for lable in set(labels.tolist()): + if lable == -1: + continue + + ind = random.choice(cache[l]) + retries = 0 + while ind in inds.tolist() or type(ind) is not int: + ind = random.choice(cache[l]) + retries += 1 + if retries > len(dataset): + break + + example_inds.append(ind) + + # Sample randomly until we hit batch size + while len(example_inds) < min(len(dataset), num): + ind = random.randint(0, len(dataset) - 1) + if ind not in example_inds and ind not in inds.tolist(): + example_inds.append(ind) + + # Create examples + example_batch = {} + examples = [dataset[i] for i in example_inds] + examples, _ = self.collate_fn_multi_turn(examples) + example_batch['example_src_token'] = examples['src_token'] + example_batch['example_src_pos'] = examples['src_pos'] + example_batch['example_src_type'] = examples['src_type'] + example_batch['example_src_turn'] = examples['src_turn'] + example_batch['example_src_mask'] = examples['src_mask'] + example_batch['example_tgt_token'] = examples['tgt_token'] + example_batch['example_tgt_mask'] = examples['tgt_mask'] + example_batch['example_intent'] = examples['intent_label'] + + return example_batch + + def collate_fn_multi_turn(self, samples): + batch_size = len(samples) + batch = {} + + cur_roles = [sp['role'] for sp in samples] + src = [sp['src'] for sp in samples] + src_token, src_pos, src_turn, src_role = [], [], [], [] + for utts, cur_role in zip(src, cur_roles): + utt_lens = [len(utt) for utt in utts] + + # Token ids + src_token.append(list(chain(*utts))[-self.max_len:]) + + # Position ids + pos = [list(range(utt_len)) for utt_len in utt_lens] + src_pos.append(list(chain(*pos))[-self.max_len:]) + + # Turn ids + turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] + src_turn.append(list(chain(*turn))[-self.max_len:]) + + # Role ids + if cur_role == 'user': + role = [[ + self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id + ] * l for i, l in enumerate(utt_lens)] + else: + role = [[ + self.user_id if (len(utts) - i) % 2 == 0 else self.bot_id + ] * l for i, l in enumerate(utt_lens)] + src_role.append(list(chain(*role))[-self.max_len:]) + + src_token = list2np(src_token, padding=self.pad_id) + src_pos = list2np(src_pos, padding=self.pad_id) + src_turn = list2np(src_turn, padding=self.pad_id) + src_role = list2np(src_role, padding=self.pad_id) + batch['src_token'] = src_token + batch['src_pos'] = src_pos + batch['src_type'] = src_role + batch['src_turn'] = src_turn + batch['src_mask'] = (src_token != self.pad_id).astype( + 'int64') # input mask + + if self.with_mlm: + mlm_token, mlm_label = [], [] + raw_mlm_input = [sp['mlm_inputs'] for sp in samples] + raw_mlm_label = [sp['mlm_labels'] for sp in samples] + for inputs in raw_mlm_input: + mlm_token.append(list(chain(*inputs))[-self.max_len:]) + for labels in raw_mlm_label: + mlm_label.append(list(chain(*labels))[-self.max_len:]) + + mlm_token = list2np(mlm_token, padding=self.pad_id) + mlm_label = list2np(mlm_label, padding=self.pad_id) + batch['mlm_token'] = mlm_token + batch['mlm_label'] = mlm_label + batch['mlm_mask'] = (mlm_label != self.pad_id).astype( + 'int64') # label mask + + if self.understand_ids: + tgt = [self.understand_ids for _ in samples] + tgt_token = np.array(tgt).astype('int64') + batch['tgt_token'] = tgt_token + batch['tgt_mask'] = (tgt_token != self.pad_id).astype( + 'int64') # input mask + + if 'id' in samples[0]: + ids = [sp['id'] for sp in samples] + ids = np.array(ids).astype('int64') + batch['ids'] = ids + + if self.dynamic_score and self.with_contrastive: + query_labels = [sp['query_label'] for sp in samples] + batch['query_labels'] = query_labels + batch['label_ids'] = np.arange(batch_size) + + if 'intent_label' in samples[0]['extra_info']: + intent_label = [ + sample['extra_info']['intent_label'] for sample in samples + ] + intent_label = np.array(intent_label).astype('int64') + batch['intent_label'] = intent_label + + return batch, batch_size diff --git a/modelscope/preprocessors/nlp/space/lazy_dataset.py b/modelscope/preprocessors/nlp/space/lazy_dataset.py new file mode 100644 index 00000000..536d9341 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/lazy_dataset.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import json + + +class LazyDataset(object): + """ + Lazy load dataset from disk. + + Each line of data file is a preprocessed example. + """ + + def __init__(self, data_file, reader, transform=lambda s: json.loads(s)): + """ + Initialize lazy dataset. + + By default, loading .jsonl format. + + :param data_file + :type str + + :param transform + :type callable + """ + self.data_file = data_file + self.transform = transform + self.reader = reader + self.offsets = [0] + with open(data_file, 'r', encoding='utf-8') as fp: + while fp.readline() != '': + self.offsets.append(fp.tell()) + self.offsets.pop() + self.fp = open(data_file, 'r', encoding='utf-8') + + def __len__(self): + return len(self.offsets) + + def __getitem__(self, idx): + self.fp.seek(self.offsets[idx], 0) + sample = self.transform(self.fp.readline().strip()) + if self.reader.with_mlm: + sample = self.reader.create_token_masked_lm_predictions(sample) + return sample diff --git a/modelscope/preprocessors/nlp/space/preprocess.py b/modelscope/preprocessors/nlp/space/preprocess.py new file mode 100644 index 00000000..8aab4711 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/preprocess.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import glob +import os + +from modelscope.preprocessors.nlp.space.fields.intent_field import \ + IntentBPETextField + +FILE_NAME = 'train.json' + + +def intent_preprocess(path, cfg): + + bpe = IntentBPETextField(path, cfg) + args = cfg.Dataset + build_examples_fn = bpe.build_examples_multi_turn if args.trigger_role == 'system' \ + else bpe.build_examples_single_turn + build_score_matrix_fn = bpe.build_score_matrix + build_score_matrix_multiprocessing_fn = bpe.build_score_matrix_multiprocessing + data_paths = list( + os.path.dirname(c) for c in sorted( + glob.glob(args.data_dir + '/**/' + FILE_NAME, recursive=True))) + data_paths = bpe.filter_data_path(data_paths=data_paths) + + for mode in ['train', 'valid', 'test']: + for data_path in data_paths: + input_file = os.path.join(data_path, f'{mode}.json') + output_file = os.path.join(data_path, + f'{mode}.{bpe.tokenizer_type}.jsonl') + output_score_file = os.path.join(data_path, f'{mode}.Score.npy') + if os.path.exists(input_file) and not os.path.exists(output_file): + examples = build_examples_fn(input_file, data_type=mode) + if examples: + bpe.save_examples(examples, output_file) + else: + continue + if os.path.exists(output_file) and not os.path.exists(output_score_file) and \ + not args.dynamic_score and 'AnPreDial' in data_path: + examples = bpe.load_examples(output_file) + if args.num_process >= 2: + score_matrix = build_score_matrix_multiprocessing_fn( + examples) + else: + score_matrix = build_score_matrix_fn(examples) + bpe.save_examples(score_matrix, output_score_file) diff --git a/modelscope/preprocessors/nlp/space/sampler.py b/modelscope/preprocessors/nlp/space/sampler.py new file mode 100644 index 00000000..e549c343 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/sampler.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np + + +class Sampler(object): + + def __init__(self): + return + + def __len__(self): + raise NotImplementedError + + def __iter__(self): + raise NotImplementedError + + +class SequentialSampler(Sampler): + + def __init__(self, dataset): + self.dataset = dataset + return + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + return iter(range(len(self))) + + +class RandomSampler(Sampler): + + def __init__(self, dataset): + self.dataset = dataset + self.epoch = 0 + return + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + np.random.seed(self.epoch) + self.epoch += 1 + return iter(np.random.permutation(len(self))) + + +class SortedSampler(Sampler): + """ Sorted Sampler. + Sort each block of examples by key. + """ + + def __init__(self, sampler, sort_pool_size, key='src'): + self.sampler = sampler + self.sort_pool_size = sort_pool_size + self.key = lambda idx: len(self.sampler.dataset[idx][key]) + return + + def __len__(self): + return len(self.sampler) + + def __iter__(self): + pool = [] + for idx in self.sampler: + pool.append(idx) + if len(pool) == self.sort_pool_size: + pool = sorted(pool, key=self.key) + for i in pool: + yield i + pool = [] + if len(pool) > 0: + pool = sorted(pool, key=self.key) + for i in pool: + yield i diff --git a/modelscope/preprocessors/nlp/space/tensorlistdataset.py b/modelscope/preprocessors/nlp/space/tensorlistdataset.py new file mode 100644 index 00000000..45243261 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/tensorlistdataset.py @@ -0,0 +1,59 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.utils.data import Dataset + + +class TensorListDataset(Dataset): + r"""Dataset wrapping tensors, tensor dicts and tensor lists. + + Arguments: + *data (Tensor or dict or list of Tensors): tensors that have the same size + of the first dimension. + """ + + def __init__(self, *data): + if isinstance(data[0], dict): + size = list(data[0].values())[0].size(0) + elif isinstance(data[0], list): + size = data[0][0].size(0) + else: + size = data[0].size(0) + for element in data: + if isinstance(element, dict): + assert all( + size == tensor.size(0) + for name, tensor in element.items()) # dict of tensors + elif isinstance(element, list): + assert all(size == tensor.size(0) + for tensor in element) # list of tensors + else: + assert size == element.size(0) # tensor + self.size = size + self.data = data + + def __getitem__(self, index): + result = [] + for element in self.data: + if isinstance(element, dict): + result.append({k: v[index] for k, v in element.items()}) + elif isinstance(element, list): + result.append(v[index] for v in element) + else: + result.append(element[index]) + return tuple(result) + + def __len__(self): + return self.size diff --git a/modelscope/preprocessors/nlp/space/tokenizer.py b/modelscope/preprocessors/nlp/space/tokenizer.py new file mode 100644 index 00000000..1bd0ce11 --- /dev/null +++ b/modelscope/preprocessors/nlp/space/tokenizer.py @@ -0,0 +1,670 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import collections +import logging +import os +import sys +import unicodedata + +import json +import regex as re + + +def clean_string(string): + replace_mp = { + ' - ': '-', + " ' ": "'", + " n't": "n't", + " 'm": "'m", + ' do not': " don't", + " 's": "'s", + " 've": "'ve", + " 're": "'re" + } + for k, v in replace_mp.items(): + string = string.replace(k, v) + return string + + +class Tokenizer(object): + + def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'): + self.tokenizer_type = tokenizer_type + if tokenizer_type == 'Bert': + self.spec_convert_dict = { + '[BOS]': '[unused0]', + '[EOS]': '[unused1]' + } + for token in special_tokens: + if token not in self.spec_convert_dict and token not in [ + '[PAD]', '[UNK]' + ]: + self.spec_convert_dict[ + token] = f'[unused{len(self.spec_convert_dict)}]' + self.spec_revert_dict = { + v: k + for k, v in self.spec_convert_dict.items() + } + special_tokens = [ + self.spec_convert_dict.get(tok, tok) for tok in special_tokens + ] + self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]', + '[MASK]') + self.special_tokens += tuple(x for x in special_tokens + if x not in self.special_tokens) + + self._tokenizer = BertTokenizer( + vocab_path, never_split=self.special_tokens) + for tok in self.special_tokens: + assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" + self.vocab_size = len(self._tokenizer.vocab) + elif tokenizer_type == 'GPT2': + self.spec_convert_dict = {'[UNK]': ''} + self.spec_revert_dict = { + v: k + for k, v in self.spec_convert_dict.items() + } + special_tokens = [ + tok for tok in special_tokens + if tok not in self.spec_convert_dict + ] + vocab_file = os.path.join(vocab_path, 'vocab.json') + merges_file = os.path.join(vocab_path, 'merges.txt') + self._tokenizer = GPT2Tokenizer( + vocab_file, merges_file, special_tokens=special_tokens) + self.num_specials = len(special_tokens) + self.vocab_size = len(self._tokenizer) + else: + raise ValueError + + def tokenize(self, text): + return self._tokenizer.tokenize(text) + + def convert_tokens_to_ids(self, tokens): + if self.tokenizer_type == 'Bert': + tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] + ids = self._tokenizer.convert_tokens_to_ids(tokens) + return ids + else: + tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] + ids = self._tokenizer.convert_tokens_to_ids(tokens) + ids = [(i + self.num_specials) % self.vocab_size for i in ids] + return ids + + def convert_ids_to_tokens(self, ids): + if self.tokenizer_type == 'Bert': + tokens = self._tokenizer.convert_ids_to_tokens(ids) + tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] + return tokens + else: + ids = [(i - self.num_specials) % self.vocab_size for i in ids] + tokens = self._tokenizer.convert_ids_to_tokens(ids) + tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] + return tokens + + def decode(self, ids, ignore_tokens=[]): + tokens = self.convert_ids_to_tokens(ids) + if len(ignore_tokens) > 0: + ignore_tokens = set(ignore_tokens) + tokens = [tok for tok in tokens if tok not in ignore_tokens] + if self.tokenizer_type == 'Bert': + string = ' '.join(tokens).replace(' ##', '') + else: + string = ''.join(tokens) + string = bytearray([ + self._tokenizer.byte_decoder[c] for c in string + ]).decode('utf-8') + string = clean_string(string) + return string + + +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +logger = logging.getLogger(__name__) + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r', encoding='utf-8') as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, + vocab_file, + do_lower_case=True, + max_len=None, + do_basic_tokenize=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + .format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([ + (ids, tok) for tok, ids in self.vocab.items() + ]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this BERT model ({} > {}). Running this' + ' sequence through BERT will result in indexing errors'.format( + len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + tmp = (cp >= 0x4E00 and cp <= 0x9FFF) + tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF) + tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF) + tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F) + tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F) + tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF) + tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF) + tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F) + if tmp: + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + tmp = (cp >= 33 and cp <= 47) + tmp = tmp or (cp >= 58 and cp <= 64) + tmp = tmp or (cp >= 91 and cp <= 96) + tmp = tmp or (cp >= 123 and cp <= 126) + if tmp: + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False + + +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + + def __init__(self, + vocab_file, + merges_file, + errors='replace', + special_tokens=None, + max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) + for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = { + v: k + for k, v in self.special_tokens.items() + } + logger.info('Special tokens {}'.format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[ord(b)] for b in token + if ord(b) in self.byte_encoder) + if token == '': + continue + bpe_tokens.extend( + bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + python_version_3 = isinstance(tokens, str) + python_version_2 = ( + sys.version_info[0] == 2 and isinstance(tokens, unicode)) + if python_version_3 or python_version_2: + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + 'Token indices sequence length is longer than the specified maximum ' + ' sequence length for this OpenAI GPT model ({} > {}). Running this' + ' sequence through the model will result in indexing errors'. + format(len(ids), self.max_len)) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors=self.errors) + return text diff --git a/modelscope/preprocessors/nlp/space_T_cn/__init__.py b/modelscope/preprocessors/nlp/space_T_cn/__init__.py new file mode 100644 index 00000000..9aa562d7 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_cn/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .table_question_answering_preprocessor import TableQuestionAnsweringPreprocessor + from .fields import MultiWOZBPETextField, IntentBPETextField + +else: + _import_structure = { + 'table_question_answering_preprocessor': + ['TableQuestionAnsweringPreprocessor'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/__init__.py b/modelscope/preprocessors/nlp/space_T_cn/fields/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/database.py b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py new file mode 100644 index 00000000..2fef8d7e --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import sqlite3 + +import json +import tqdm + +from .struct import Trie + + +class Database: + + def __init__(self, + tokenizer, + table_file_path, + syn_dict_file_path, + is_use_sqlite=True): + self.tokenizer = tokenizer + self.is_use_sqlite = is_use_sqlite + if self.is_use_sqlite: + self.connection_obj = sqlite3.connect( + ':memory:', check_same_thread=False) + self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} + self.tables = self.init_tables(table_file_path=table_file_path) + self.syn_dict = self.init_syn_dict( + syn_dict_file_path=syn_dict_file_path) + + def __del__(self): + if self.is_use_sqlite: + self.connection_obj.close() + + def init_tables(self, table_file_path): + tables = {} + lines = [] + if type(table_file_path) == str: + with open(table_file_path, 'r') as fo: + for line in fo: + lines.append(line) + elif type(table_file_path) == list: + for path in table_file_path: + with open(path, 'r') as fo: + for line in fo: + lines.append(line) + else: + raise ValueError() + + for line in tqdm.tqdm(lines, desc='Load Tables'): + table = json.loads(line.strip()) + + table_header_length = 0 + headers_tokens = [] + for header in table['header_name']: + header_tokens = self.tokenizer.tokenize(header) + table_header_length += len(header_tokens) + headers_tokens.append(header_tokens) + empty_column = self.tokenizer.tokenize('空列') + table_header_length += len(empty_column) + headers_tokens.append(empty_column) + table['tablelen'] = table_header_length + table['header_tok'] = headers_tokens + table['headerid2name'] = {} + for hid, hname in zip(table['header_id'], table['header_name']): + table['headerid2name'][hid] = hname + + table['header_types'].append('null') + table['header_units'] = [ + self.tokenizer.tokenize(unit) for unit in table['header_units'] + ] + [[]] + + trie_set = [Trie() for _ in table['header_name']] + for row in table['rows']: + for ii, cell in enumerate(row): + if 'real' in table['header_types'][ii].lower() or \ + 'number' in table['header_types'][ii].lower() or \ + 'duration' in table['header_types'][ii].lower(): + continue + word = str(cell).strip().lower() + trie_set[ii].insert(word, word) + + table['value_trie'] = trie_set + + # create sqlite + if self.is_use_sqlite: + cursor_obj = self.connection_obj.cursor() + cursor_obj.execute('DROP TABLE IF EXISTS %s' % + (table['table_id'])) + header_string = ', '.join([ + '%s %s' % + (name, self.type_dict[htype]) for name, htype in zip( + table['header_id'], table['header_types']) + ]) + create_table_string = 'CREATE TABLE %s (%s);' % ( + table['table_id'], header_string) + cursor_obj.execute(create_table_string) + for row in table['rows']: + value_string = ', '.join(['"%s"' % (val) for val in row]) + insert_row_string = 'INSERT INTO %s VALUES(%s)' % ( + table['table_id'], value_string) + cursor_obj.execute(insert_row_string) + + tables[table['table_id']] = table + + return tables + + def init_syn_dict(self, syn_dict_file_path): + lines = [] + with open(syn_dict_file_path, encoding='utf-8') as fo: + for line in fo: + lines.append(line) + + syn_dict = {} + for line in tqdm.tqdm(lines, desc='Load Synonym Dict'): + tokens = line.strip().split('\t') + if len(tokens) != 2: + continue + keys = tokens[0].strip().split('|') + values = tokens[1].strip().split('|') + for key in keys: + key = key.lower().strip() + syn_dict.setdefault(key, []) + for value in values: + syn_dict[key].append(value.lower().strip()) + + return syn_dict diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py b/modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py new file mode 100644 index 00000000..b62d03e4 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/schema_link.py @@ -0,0 +1,439 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re + +from .struct import TypeInfo + + +class SchemaLinker: + + def __init__(self): + pass + + def find_in_list(self, comlist, words): + result = False + for com in comlist: + if words in com: + result = True + break + return result + + def get_continue_score(self, pstr, tstr): + comlist = [] + minlen = min(len(pstr), len(tstr)) + for slen in range(minlen, 1, -1): + for ts in range(0, len(tstr), 1): + if ts + slen > len(tstr): + continue + words = tstr[ts:ts + slen] + if words in pstr and not self.find_in_list(comlist, words): + comlist.append(words) + + comlen = 0 + for com in comlist: + comlen += len(com) * len(com) + weight = comlen / (len(tstr) * len(tstr) + 0.001) + if weight > 1.0: + weight = 1.0 + + return weight + + def get_match_score(self, ptokens, ttokens): + pset = set(ptokens) + tset = set(ttokens) + comset = pset & tset + allset = pset | tset + weight2 = len(comset) / (len(allset) + 0.001) + weight3 = self.get_continue_score(''.join(ptokens), ''.join(ttokens)) + return 0.4 * weight2 + 0.6 * weight3 + + def is_number(self, s): + try: + float(s) + return True + except ValueError: + pass + + try: + import unicodedata + unicodedata.numeric(s) + return True + except (TypeError, ValueError): + pass + + return False + + def get_match_phrase(self, query, target): + if target in query: + return target, 1.0 + + qtokens = [] + for i in range(0, len(query), 1): + qtokens.append(query[i:i + 1]) + ttokens = [] + for i in range(0, len(target), 1): + ttokens.append(target[i:i + 1]) + ttok_set = set(ttokens) + + phrase = '' + score = 0.0 + for qidx, qword in enumerate(qtokens): + if qword not in ttok_set: + continue + + eidx = (qidx + 2 * len(ttokens)) if ( + len(qtokens) > qidx + 2 * len(ttokens)) else len(qtokens) + while eidx > qidx: + ptokens = qtokens[qidx:eidx] + weight = self.get_match_score(ptokens, ttokens) + if weight + 0.001 > score: + score = weight + phrase = ''.join(ptokens) + eidx -= 1 + + if self.is_number(target) and phrase != target: + score = 0.0 + if len(phrase) > 1 and phrase in target: + score *= (1.0 + 0.05 * len(phrase)) + + return phrase, score + + def allfindpairidx(self, que_tok, value_tok, weight): + idxs = [] + for i in range(0, len(que_tok) - len(value_tok) + 1, 1): + s = i + e = i + matched = True + for j in range(0, len(value_tok), 1): + if value_tok[j].lower() == que_tok[i + j].lower(): + e = i + j + else: + matched = False + break + if matched: + idxs.append([s, e, weight]) + + return idxs + + def findnear(self, ps1, pe1, ps2, pe2): + if abs(ps1 - pe2) <= 2 or abs(pe1 - ps2) <= 2: + return True + return False + + def get_column_type(self, col_idx, table): + colType = table['header_types'][col_idx] + if 'number' in colType or 'duration' in colType or 'real' in colType: + colType = 'real' + elif 'date' in colType: + colType = 'date' + elif 'bool' in colType: + colType = 'bool' + else: + colType = 'text' + + return colType + + def add_type_all(self, typeinfos, index, idxs, label, linktype, value, + orgvalue): + for idx in idxs: + info = TypeInfo(label, index, linktype, value, orgvalue, idx[0], + idx[1], idx[2]) + flag = True + for i, typeinfo in enumerate(typeinfos): + if info.pstart < typeinfo.pstart: + typeinfos.insert(i, info) + flag = False + break + + if flag: + typeinfos.append(info) + + return typeinfos + + def save_info(self, tinfo, sinfo): + flag = True + if tinfo.pstart > sinfo.pend or tinfo.pend < sinfo.pstart: + pass + elif tinfo.pstart >= sinfo.pstart and \ + tinfo.pend <= sinfo.pend and tinfo.index == -1: + flag = False + elif tinfo.pstart == sinfo.pstart and sinfo.pend == tinfo.pend and \ + abs(tinfo.weight - sinfo.weight) < 0.01: + pass + else: + if sinfo.label == 'col' or sinfo.label == 'val': + if tinfo.label == 'col' or tinfo.label == 'val': + if (sinfo.pend + - sinfo.pstart) > (tinfo.pend - tinfo.pstart) or ( + sinfo.weight > tinfo.weight + and sinfo.index != -1): + flag = False + else: + flag = False + else: + if (tinfo.label == 'op' or tinfo.label == 'agg'): + if (sinfo.pend - sinfo.pstart) > ( + tinfo.pend + - tinfo.pstart) or sinfo.weight > tinfo.weight: + flag = False + + return flag + + def normal_type_infos(self, infos): + typeinfos = [] + for info in infos: + typeinfos = [x for x in typeinfos if self.save_info(x, info)] + flag = True + for i, typeinfo in enumerate(typeinfos): + if not self.save_info(info, typeinfo): + flag = False + break + if info.pstart < typeinfo.pstart: + typeinfos.insert(i, info) + flag = False + break + if flag: + typeinfos.append(info) + return typeinfos + + def findnear_typeinfo(self, info1, info2): + return self.findnear(info1.pstart, info1.pend, info2.pstart, + info2.pend) + + def find_real_column(self, infos, table): + for i, vinfo in enumerate(infos): + if vinfo.index != -1 or vinfo.label != 'val': + continue + eoidx = -1 + for j, oinfo in enumerate(infos): + if oinfo.label != 'op': + continue + if self.findnear_typeinfo(vinfo, oinfo): + eoidx = j + break + for j, cinfo in enumerate(infos): + if cinfo.label != 'col' or table['header_types'][ + cinfo.index] != 'real': + continue + if self.findnear_typeinfo(cinfo, vinfo) or ( + eoidx != -1 + and self.findnear_typeinfo(cinfo, infos[eoidx])): + infos[i].index = cinfo.index + break + + return infos + + def filter_column_infos(self, infos): + delid = [] + for i, info in enumerate(infos): + if info.label != 'col': + continue + for j in range(i + 1, len(infos), 1): + if infos[j].label == 'col' and \ + info.pstart == infos[j].pstart and \ + info.pend == infos[j].pend: + delid.append(i) + delid.append(j) + break + + typeinfos = [] + for idx, info in enumerate(infos): + if idx in set(delid): + continue + typeinfos.append(info) + + return typeinfos + + def filter_type_infos(self, infos, table): + infos = self.filter_column_infos(infos) + infos = self.find_real_column(infos, table) + + colvalMp = {} + for info in infos: + if info.label == 'col': + colvalMp[info.index] = [] + for info in infos: + if info.label == 'val' and info.index in colvalMp: + colvalMp[info.index].append(info) + + delid = [] + for idx, info in enumerate(infos): + if info.label != 'val' or info.index in colvalMp: + continue + for index in colvalMp.keys(): + valinfos = colvalMp[index] + for valinfo in valinfos: + if valinfo.pstart <= info.pstart and \ + valinfo.pend >= info.pend: + delid.append(idx) + break + + typeinfos = [] + for idx, info in enumerate(infos): + if idx in set(delid): + continue + typeinfos.append(info) + + return typeinfos + + def get_table_match_score(self, nlu_t, schema_link): + match_len = 0 + for info in schema_link: + scale = 0.6 + if info['question_len'] > 0 and info['column_index'] != -1: + scale = 1.0 + else: + scale = 0.5 + match_len += scale * info['question_len'] * info['weight'] + + return match_len / (len(nlu_t) + 0.1) + + def get_entity_linking(self, + tokenizer, + nlu, + nlu_t, + tables, + col_syn_dict, + table_id=None, + history_sql=None): + """ + get linking between question and schema column + """ + typeinfos = [] + numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) + + if table_id is not None and table_id in tables: + tables = {table_id: tables[table_id]} + + # search schema link in every table + search_result_list = [] + for tablename in tables: + table = tables[tablename] + trie_set = None + if 'value_trie' in table: + trie_set = table['value_trie'] + + typeinfos = [] + for ii, column in enumerate(table['header_name']): + column = column.lower() + column_new = column + cphrase, cscore = self.get_match_phrase( + nlu.lower(), column_new) + if cscore > 0.3 and cphrase.strip() != '': + phrase_tok = tokenizer.tokenize(cphrase) + cidxs = self.allfindpairidx(nlu_t, phrase_tok, cscore) + typeinfos = self.add_type_all(typeinfos, ii, cidxs, 'col', + 'column', cphrase, column) + if cscore < 0.8 and column_new in col_syn_dict: + columns = list(set(col_syn_dict[column_new])) + for syn_col in columns: + if syn_col not in nlu.lower() or syn_col == '': + continue + phrase_tok = tokenizer.tokenize(syn_col) + cidxs = self.allfindpairidx(nlu_t, phrase_tok, 1.0) + typeinfos = self.add_type_all(typeinfos, ii, cidxs, + 'col', 'column', syn_col, + column) + + for ii, trie in enumerate(trie_set): + ans = trie.match(nlu.lower()) + for cell in ans.keys(): + vphrase = cell + vscore = 1.0 + phrase_tok = tokenizer.tokenize(vphrase) + if len(phrase_tok) == 0 or len(vphrase) < 2: + continue + vidxs = self.allfindpairidx(nlu_t, phrase_tok, vscore) + linktype = self.get_column_type(ii, table) + typeinfos = self.add_type_all(typeinfos, ii, vidxs, 'val', + linktype, vphrase, ans[cell]) + + for number in set(numbers): + number_tok = tokenizer.tokenize(number.lower()) + if len(number_tok) == 0: + continue + nidxs = self.allfindpairidx(nlu_t, number_tok, 1.0) + typeinfos = self.add_type_all(typeinfos, -1, nidxs, 'val', + 'real', number, number) + + newtypeinfos = self.normal_type_infos(typeinfos) + + newtypeinfos = self.filter_type_infos(newtypeinfos, table) + + final_question = [0] * len(nlu_t) + final_header = [0] * len(table['header_name']) + for typeinfo in newtypeinfos: + pstart = typeinfo.pstart + pend = typeinfo.pend + 1 + if typeinfo.label == 'op' or typeinfo.label == 'agg': + score = int(typeinfo.linktype[-1]) + if typeinfo.label == 'op': + score += 6 + else: + score += 11 + for i in range(pstart, pend, 1): + final_question[i] = score + + elif typeinfo.label == 'col': + for i in range(pstart, pend, 1): + final_question[i] = 4 + if final_header[typeinfo.index] % 2 == 0: + final_header[typeinfo.index] += 1 + + elif typeinfo.label == 'val': + if typeinfo.index == -1: + for i in range(pstart, pend, 1): + final_question[i] = 5 + else: + for i in range(pstart, pend, 1): + final_question[i] = 2 + final_question[pstart] = 1 + final_question[pend - 1] = 3 + if final_header[typeinfo.index] < 2: + final_header[typeinfo.index] += 2 + + # collect schema_link + schema_link = [] + for sl in newtypeinfos: + if sl.label in ['val', 'col']: + schema_link.append({ + 'question_len': + max(0, sl.pend - sl.pstart + 1), + 'question_index': [sl.pstart, sl.pend], + 'question_span': + ''.join(nlu_t[sl.pstart:sl.pend + 1]), + 'column_index': + sl.index, + 'column_span': + table['header_name'][sl.index] + if sl.index != -1 else '空列', + 'label': + sl.label, + 'weight': + round(sl.weight, 4) + }) + + # get the match score of each table + match_score = self.get_table_match_score(nlu_t, schema_link) + + # cal table_score + if history_sql is not None and 'from' in history_sql: + table_score = int(table['table_id'] == history_sql['from'][0]) + else: + table_score = 0 + + search_result = { + 'table_id': table['table_id'], + 'question_knowledge': final_question, + 'header_knowledge': final_header, + 'schema_link': schema_link, + 'match_score': match_score, + 'table_score': table_score + } + search_result_list.append(search_result) + + search_result_list = sorted( + search_result_list, + key=lambda x: (x['match_score'], x['table_score']), + reverse=True)[0:1] + + return search_result_list diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py new file mode 100644 index 00000000..917e1aaa --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py @@ -0,0 +1,203 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC'] +agg_ops = [ + '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', 'SAME' +] +conn_ops = ['', 'AND', 'OR'] + + +class Context: + + def __init__(self): + self.history_sql = None + + def set_history_sql(self, sql): + self.history_sql = sql + + +class SQLQuery: + + def __init__(self, string, query, sql_result): + self.string = string + self.query = query + self.sql_result = sql_result + + +class TrieNode(object): + + def __init__(self): + """ + Initialize your data structure here. + """ + self.data = {} + self.is_word = False + self.term = None + + +class Trie(object): + + def __init__(self): + self.root = TrieNode() + + def insert(self, word, term): + """ + Inserts a word into the trie. + :type word: str + :rtype: void + """ + node = self.root + for letter in word: + child = node.data.get(letter) + if not child: + node.data[letter] = TrieNode() + node = node.data[letter] + node.is_word = True + node.term = term + + def search(self, word): + """ + Returns if the word is in the trie. + :type word: str + :rtype: bool + """ + node = self.root + for letter in word: + node = node.data.get(letter) + if not node: + return None, False + return node.term, True + + def match(self, query): + start = 0 + end = 1 + length = len(query) + ans = {} + while start < length and end < length: + sub = query[start:end] + term, flag = self.search(sub) + if flag: + if term is not None: + ans[sub] = term + end += 1 + else: + start += 1 + end = start + 1 + return ans + + def starts_with(self, prefix): + """ + Returns if there is any word in the trie + that starts with the given prefix. + :type prefix: str + :rtype: bool + """ + node = self.root + for letter in prefix: + node = node.data.get(letter) + if not node: + return False + return True + + def get_start(self, prefix): + """ + Returns words started with prefix + :param prefix: + :return: words (list) + """ + + def _get_key(pre, pre_node): + words_list = [] + if pre_node.is_word: + words_list.append(pre) + for x in pre_node.data.keys(): + words_list.extend(_get_key(pre + str(x), pre_node.data.get(x))) + return words_list + + words = [] + if not self.starts_with(prefix): + return words + if self.search(prefix): + words.append(prefix) + return words + node = self.root + for letter in prefix: + node = node.data.get(letter) + return _get_key(prefix, node) + + +class TypeInfo: + + def __init__(self, label, index, linktype, value, orgvalue, pstart, pend, + weight): + self.label = label + self.index = index + self.linktype = linktype + self.value = value + self.orgvalue = orgvalue + self.pstart = pstart + self.pend = pend + self.weight = weight + + +class Constant: + + def __init__(self): + self.action_ops = [ + 'add_cond', 'change_cond', 'del_cond', 'change_focus_total', + 'change_agg_only', 'del_focus', 'restart', 'switch_table', + 'out_of_scripts', 'repeat', 'firstTurn' + ] + + self.agg_ops = [ + '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', + 'SAME' + ] + + self.cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC'] + + self.cond_conn_ops = ['', 'AND', 'OR'] + + self.col_type_dict = { + 'null': 0, + 'text': 1, + 'number': 2, + 'duration': 3, + 'bool': 4, + 'date': 5 + } + + self.schema_link_dict = { + 'col_start': 1, + 'col_middle': 2, + 'col_end': 3, + 'val_start': 4, + 'val_middle': 5, + 'val_end': 6 + } + + self.max_select_num = 4 + + self.max_where_num = 6 + + self.limit_dict = { + '最': 1, + '1': 1, + '一': 1, + '2': 2, + '二': 2, + '3': 3, + '三': 3, + '4': 4, + '四': 4, + '5': 5, + '五': 5, + '6': 6, + '六': 6, + '7': 7, + '七': 7, + '8': 8, + '八': 8, + '9': 9, + '九': 9 + } diff --git a/modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py new file mode 100644 index 00000000..3aabc6a9 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_cn/table_question_answering_preprocessor.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import torch +from transformers import BertTokenizer + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database +from modelscope.preprocessors.nlp.space_T_cn.fields.schema_link import \ + SchemaLinker +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile +from modelscope.utils.type_assert import type_assert + +__all__ = ['TableQuestionAnsweringPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, + module_name=Preprocessors.table_question_answering_preprocessor) +class TableQuestionAnsweringPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, db: Database = None, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + db (Database): database instance + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # read tokenizer + self.tokenizer = BertTokenizer( + os.path.join(self.model_dir, ModelFile.VOCAB_FILE)) + + # read database + if db is None: + self.db = Database( + tokenizer=self.tokenizer, + table_file_path=os.path.join(self.model_dir, 'table.json'), + syn_dict_file_path=os.path.join(self.model_dir, 'synonym.txt')) + else: + self.db = db + + # get schema linker + self.schema_linker = SchemaLinker() + + # set device + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + + def construct_data(self, search_result_list, nlu, nlu_t, db, history_sql): + datas = [] + for search_result in search_result_list: + data = {} + data['table_id'] = search_result['table_id'] + data['question'] = nlu + data['question_tok'] = nlu_t + data['header_tok'] = db.tables[data['table_id']]['header_tok'] + data['types'] = db.tables[data['table_id']]['header_types'] + data['units'] = db.tables[data['table_id']]['header_units'] + data['action'] = 0 + data['sql'] = None + data['history_sql'] = history_sql + data['wvi_corenlp'] = [] + data['bertindex_knowledge'] = search_result['question_knowledge'] + data['header_knowledge'] = search_result['header_knowledge'] + data['schema_link'] = search_result['schema_link'] + datas.append(data) + + return datas + + @type_assert(object, dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (dict): + utterance: a sentence + last_sql: predicted sql of last utterance + Example: + utterance: 'Which of these are hiring?' + last_sql: '' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + # tokenize question + question = data['question'] + table_id = data.get('table_id', None) + history_sql = data.get('history_sql', None) + nlu = question.lower() + nlu_t = self.tokenizer.tokenize(nlu) + + # get linking + search_result_list = self.schema_linker.get_entity_linking( + tokenizer=self.tokenizer, + nlu=nlu, + nlu_t=nlu_t, + tables=self.db.tables, + col_syn_dict=self.db.syn_dict, + table_id=table_id, + history_sql=history_sql) + + # collect data + datas = self.construct_data( + search_result_list=search_result_list[0:1], + nlu=nlu, + nlu_t=nlu_t, + db=self.db, + history_sql=history_sql) + + return {'datas': datas} diff --git a/modelscope/preprocessors/nlp/space_T_en/__init__.py b/modelscope/preprocessors/nlp/space_T_en/__init__.py new file mode 100644 index 00000000..cef8f074 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .conversational_text_to_sql_preprocessor import \ + ConversationalTextToSqlPreprocessor + from .fields import (get_label, SubPreprocessor, preprocess_dataset, + process_dataset) + +else: + _import_structure = { + 'conversational_text_to_sql_preprocessor': + ['ConversationalTextToSqlPreprocessor'], + 'fields': [ + 'get_label', 'SubPreprocessor', 'preprocess_dataset', + 'process_dataset' + ] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py b/modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py new file mode 100644 index 00000000..00c7bcd7 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/conversational_text_to_sql_preprocessor.py @@ -0,0 +1,124 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import json +import torch +from text2sql_lgesql.preprocess.graph_utils import GraphProcessor +from text2sql_lgesql.preprocess.process_graphs import process_dataset_graph +from text2sql_lgesql.utils.batch import Batch +from text2sql_lgesql.utils.example import Example + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.nlp.space_T_en.fields import SubPreprocessor +from modelscope.preprocessors.nlp.space_T_en.fields.preprocess_dataset import \ + preprocess_dataset +from modelscope.preprocessors.nlp.space_T_en.fields.process_dataset import ( + process_dataset, process_tables) +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile +from modelscope.utils.type_assert import type_assert + +__all__ = ['ConversationalTextToSqlPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.conversational_text_to_sql) +class ConversationalTextToSqlPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + self.processor = None + self.table_path = os.path.join(self.model_dir, 'tables.json') + self.tables = json.load(open(self.table_path, 'r')) + self.output_tables = None + self.path_cache = [] + self.graph_processor = GraphProcessor() + + Example.configuration( + plm=self.config['model']['plm'], + tables=self.output_tables, + table_path=os.path.join(model_dir, 'tables.json'), + model_dir=self.model_dir, + db_dir=os.path.join(model_dir, 'db')) + + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + use_device = True if self.device == 'cuda' else False + self.processor = \ + SubPreprocessor(model_dir=model_dir, + db_content=True, + use_gpu=use_device) + self.output_tables = \ + process_tables(self.processor, + self.tables) + + @type_assert(object, dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (dict): + utterance: a sentence + last_sql: predicted sql of last utterance + Example: + utterance: 'Which of these are hiring?' + last_sql: '' + + Returns: + Dict[str, Any]: the preprocessed data + """ + # use local database + if data['local_db_path'] is not None and data[ + 'local_db_path'] not in self.path_cache: + self.path_cache.append(data['local_db_path']) + path = os.path.join(data['local_db_path'], 'tables.json') + self.tables = json.load(open(path, 'r')) + self.processor.db_dir = os.path.join(data['local_db_path'], 'db') + self.output_tables = process_tables(self.processor, self.tables) + Example.configuration( + plm=self.config['model']['plm'], + tables=self.output_tables, + table_path=path, + model_dir=self.model_dir, + db_dir=self.processor.db_dir) + + theresult, sql_label = \ + preprocess_dataset( + self.processor, + data, + self.output_tables, + data['database_id'], + self.tables + ) + output_dataset = process_dataset(self.model_dir, self.processor, + theresult, self.output_tables) + output_dataset = \ + process_dataset_graph( + self.graph_processor, + output_dataset, + self.output_tables, + method='lgesql' + ) + dev_ex = Example(output_dataset[0], + self.output_tables[data['database_id']], sql_label) + current_batch = Batch.from_example_list([dev_ex], + self.device, + train=False) + return {'batch': current_batch, 'db': data['database_id']} diff --git a/modelscope/preprocessors/nlp/space_T_en/fields/__init__.py b/modelscope/preprocessors/nlp/space_T_en/fields/__init__.py new file mode 100644 index 00000000..7049c43b --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/fields/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .common_utils import SubPreprocessor + from .parse import get_label + from .preprocess_dataset import \ + preprocess_dataset + from .process_dataset import \ + process_dataset, process_tables + +else: + _import_structure = { + 'common_utils': ['SubPreprocessor'], + 'parse': ['get_label'], + 'preprocess_dataset': ['preprocess_dataset'], + 'process_dataset': ['process_dataset', 'process_tables'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/nlp/space_T_en/fields/common_utils.py b/modelscope/preprocessors/nlp/space_T_en/fields/common_utils.py new file mode 100644 index 00000000..431e66b6 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/fields/common_utils.py @@ -0,0 +1,480 @@ +# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql. + +import os +import sqlite3 +from itertools import combinations, product + +import nltk +import numpy as np +from text2sql_lgesql.utils.constants import MAX_RELATIVE_DIST + +from modelscope.utils.logger import get_logger + +mwtokenizer = nltk.MWETokenizer(separator='') +mwtokenizer.add_mwe(('[', 'CLS', ']')) +logger = get_logger() + + +def is_number(s): + try: + float(s) + return True + except ValueError: + return False + + +def quote_normalization(question): + """ Normalize all usage of quotation marks into a separate \" """ + new_question, quotation_marks = [], [ + "'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’' + ] + for idx, tok in enumerate(question): + if len(tok) > 2 and tok[0] in quotation_marks and tok[ + -1] in quotation_marks: + new_question += ["\"", tok[1:-1], "\""] + elif len(tok) > 2 and tok[0] in quotation_marks: + new_question += ["\"", tok[1:]] + elif len(tok) > 2 and tok[-1] in quotation_marks: + new_question += [tok[:-1], "\""] + elif tok in quotation_marks: + new_question.append("\"") + elif len(tok) == 2 and tok[0] in quotation_marks: + # special case: the length of entity value is 1 + if idx + 1 < len(question) and question[idx + + 1] in quotation_marks: + new_question += ["\"", tok[1]] + else: + new_question.append(tok) + else: + new_question.append(tok) + return new_question + + +class SubPreprocessor(): + + def __init__(self, model_dir, use_gpu=False, db_content=True): + super(SubPreprocessor, self).__init__() + self.model_dir = model_dir + self.db_dir = os.path.join(model_dir, 'db') + self.db_content = db_content + + from nltk import data + from nltk.corpus import stopwords + data.path.append(os.path.join(self.model_dir, 'nltk_data')) + self.stopwords = stopwords.words('english') + + import stanza + from stanza.resources import common + from stanza.pipeline import core + self.nlp = stanza.Pipeline( + 'en', + use_gpu=use_gpu, + dir=self.model_dir, + processors='tokenize,pos,lemma', + tokenize_pretokenized=True, + download_method=core.DownloadMethod.REUSE_RESOURCES) + self.nlp1 = stanza.Pipeline( + 'en', + use_gpu=use_gpu, + dir=self.model_dir, + processors='tokenize,pos,lemma', + download_method=core.DownloadMethod.REUSE_RESOURCES) + + def pipeline(self, entry: dict, db: dict, verbose: bool = False): + """ db should be preprocessed """ + entry = self.preprocess_question(entry, db, verbose=verbose) + entry = self.schema_linking(entry, db, verbose=verbose) + entry = self.extract_subgraph(entry, db, verbose=verbose) + return entry + + def preprocess_database(self, db: dict, verbose: bool = False): + table_toks, table_names = [], [] + for tab in db['table_names']: + doc = self.nlp1(tab) + tab = [w.lemma.lower() for s in doc.sentences for w in s.words] + table_toks.append(tab) + table_names.append(' '.join(tab)) + db['processed_table_toks'], db[ + 'processed_table_names'] = table_toks, table_names + column_toks, column_names = [], [] + for _, c in db['column_names']: + doc = self.nlp1(c) + c = [w.lemma.lower() for s in doc.sentences for w in s.words] + column_toks.append(c) + column_names.append(' '.join(c)) + db['processed_column_toks'], db[ + 'processed_column_names'] = column_toks, column_names + column2table = list(map(lambda x: x[0], db['column_names'])) + table2columns = [[] for _ in range(len(table_names))] + for col_id, col in enumerate(db['column_names']): + if col_id == 0: + continue + table2columns[col[0]].append(col_id) + db['column2table'], db['table2columns'] = column2table, table2columns + + t_num, c_num, dtype = len(db['table_names']), len( + db['column_names']), ' 0: + col1, col2 = list(zip(*db['foreign_keys'])) + col_mat[col1, col2], col_mat[ + col2, col1] = 'column-column-fk', 'column-column-fkr' + col_mat[0, list(range(c_num))] = '*-column-generic' + col_mat[list(range(c_num)), 0] = 'column-*-generic' + col_mat[0, 0] = '*-*-identity' + + # relations between tables and columns, t_num*c_num and c_num*t_num + tab_col_mat = np.array([['table-column-generic'] * c_num + for _ in range(t_num)], + dtype=dtype) + col_tab_mat = np.array([['column-table-generic'] * t_num + for _ in range(c_num)], + dtype=dtype) + cols, tabs = list( + zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num))))) + col_tab_mat[cols, tabs], tab_col_mat[ + tabs, cols] = 'column-table-has', 'table-column-has' + if len(db['primary_keys']) > 0: + cols, tabs = list( + zip(*list( + map(lambda x: (x, column2table[x]), db['primary_keys'])))) + col_tab_mat[cols, tabs], tab_col_mat[ + tabs, cols] = 'column-table-pk', 'table-column-pk' + col_tab_mat[0, list(range(t_num))] = '*-table-generic' + tab_col_mat[list(range(t_num)), 0] = 'table-*-generic' + + relations = \ + np.concatenate([ + np.concatenate([tab_mat, tab_col_mat], axis=1), + np.concatenate([col_tab_mat, col_mat], axis=1) + ], axis=0) + db['relations'] = relations.tolist() + + if verbose: + print('Tables:', ', '.join(db['table_names'])) + print('Lemmatized:', ', '.join(table_names)) + print('Columns:', + ', '.join(list(map(lambda x: x[1], db['column_names'])))) + print('Lemmatized:', ', '.join(column_names), '\n') + return db + + def preprocess_question(self, + entry: dict, + db: dict, + verbose: bool = False): + """ Tokenize, lemmatize, lowercase question""" + # stanza tokenize, lemmatize and POS tag + question = ' '.join(quote_normalization(entry['question_toks'])) + + from nltk import data + data.path.append(os.path.join(self.model_dir, 'nltk_data')) + + zippath = os.path.join(self.model_dir, 'nltk_data/tokenizers/punkt') + if os.path.exists(zippath): + print('punkt has already exist!') + else: + import zipfile + with zipfile.ZipFile(zippath + '.zip') as zf: + zf.extractall( + os.path.join(self.model_dir, 'nltk_data/tokenizers/')) + question = nltk.word_tokenize(question) + question = mwtokenizer.tokenize(question) + + doc = self.nlp([question]) + raw_toks = [w.text.lower() for s in doc.sentences for w in s.words] + toks = [w.lemma.lower() for s in doc.sentences for w in s.words] + pos_tags = [w.xpos for s in doc.sentences for w in s.words] + + entry['raw_question_toks'] = raw_toks + entry['processed_question_toks'] = toks + entry['pos_tags'] = pos_tags + + q_num, dtype = len(toks), ' 0: + orderBy = orderBy[1] + for val_unit in orderBy: + if val_unit[0] == 0: + col_unit = val_unit[1] + used_schema['column'].add(col_unit[1]) + else: + col_unit1, col_unit2 = val_unit[1:] + used_schema['column'].add(col_unit1[1]) + used_schema['column'].add(col_unit2[1]) + # union, intersect and except clause + if sql['intersect']: + used_schema = self.extract_subgraph_from_sql( + sql['intersect'], used_schema) + if sql['union']: + used_schema = self.extract_subgraph_from_sql( + sql['union'], used_schema) + if sql['except']: + used_schema = self.extract_subgraph_from_sql( + sql['except'], used_schema) + return used_schema + + def extract_subgraph_from_conds(self, conds: list, used_schema: dict): + if len(conds) == 0: + return used_schema + for cond in conds: + if cond in ['and', 'or']: + continue + val_unit, val1, val2 = cond[2:] + if val_unit[0] == 0: + col_unit = val_unit[1] + used_schema['column'].add(col_unit[1]) + else: + col_unit1, col_unit2 = val_unit[1:] + used_schema['column'].add(col_unit1[1]) + used_schema['column'].add(col_unit2[1]) + if type(val1) == list: + used_schema['column'].add(val1[1]) + elif type(val1) == dict: + used_schema = self.extract_subgraph_from_sql(val1, used_schema) + if type(val2) == list: + used_schema['column'].add(val1[1]) + elif type(val2) == dict: + used_schema = self.extract_subgraph_from_sql(val2, used_schema) + return used_schema + + def schema_linking(self, entry: dict, db: dict, verbose: bool = False): + raw_question_toks, question_toks = entry['raw_question_toks'], entry[ + 'processed_question_toks'] + table_toks, column_toks = db['processed_table_toks'], db[ + 'processed_column_toks'] + table_names, column_names = db['processed_table_names'], db[ + 'processed_column_names'] + q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len( + column_toks), ' 1 + and phrase in name): + q_tab_mat[range(i, j), idx] = 'question-table-partialmatch' + tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch' + if verbose: + table_matched_pairs['partial'].append( + str((name, idx, phrase, i, j))) + + # relations between questions and columns + column_matched_pairs = {'partial': [], 'exact': [], 'value': []} + q_col_mat = np.array([['question-column-nomatch'] * c_num + for _ in range(q_num)], + dtype=dtype) + col_q_mat = np.array([['column-question-nomatch'] * q_num + for _ in range(c_num)], + dtype=dtype) + max_len = max([len(c) for c in column_toks]) + index_pairs = list( + filter(lambda x: x[1] - x[0] <= max_len, + combinations(range(q_num + 1), 2))) + index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0]) + for i, j in index_pairs: + phrase = ' '.join(question_toks[i:j]) + if phrase in self.stopwords: + continue + for idx, name in enumerate(column_names): + if phrase == name: + q_col_mat[range(i, j), idx] = 'question-column-exactmatch' + col_q_mat[idx, range(i, j)] = 'column-question-exactmatch' + if verbose: + column_matched_pairs['exact'].append( + str((name, idx, phrase, i, j))) + elif (j - i == 1 + and phrase in name.split()) or (j - i > 1 + and phrase in name): + q_col_mat[range(i, j), + idx] = 'question-column-partialmatch' + col_q_mat[idx, + range(i, j)] = 'column-question-partialmatch' + if verbose: + column_matched_pairs['partial'].append( + str((name, idx, phrase, i, j))) + if self.db_content: + db_file = os.path.join(self.db_dir, db['db_id'], + db['db_id'] + '.sqlite') + if not os.path.exists(db_file): + raise ValueError('[ERROR]: database file %s not found ...' % + (db_file)) + conn = sqlite3.connect(db_file) + conn.text_factory = lambda b: b.decode(errors='ignore') + conn.execute('pragma foreign_keys=ON') + for i, (tab_id, + col_name) in enumerate(db['column_names_original']): + if i == 0 or 'id' in column_toks[ + i]: # ignore * and special token 'id' + continue + tab_name = db['table_names_original'][tab_id] + try: + cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" + % (col_name, tab_name)) + cell_values = cursor.fetchall() + cell_values = [str(each[0]) for each in cell_values] + cell_values = [[str(float(each))] if is_number(each) else + each.lower().split() + for each in cell_values] + except Exception as e: + print(e) + for j, word in enumerate(raw_question_toks): + word = str(float(word)) if is_number(word) else word + for c in cell_values: + if word in c and 'nomatch' in q_col_mat[ + j, i] and word not in self.stopwords: + q_col_mat[j, i] = 'question-column-valuematch' + col_q_mat[i, j] = 'column-question-valuematch' + if verbose: + column_matched_pairs['value'].append( + str((column_names[i], i, word, j, j + 1))) + break + conn.close() + + q_col_mat[:, 0] = 'question-*-generic' + col_q_mat[0] = '*-question-generic' + q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1) + schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0) + entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist()) + + if verbose: + print('Question:', ' '.join(question_toks)) + print('Table matched: (table name, column id, \ + question span, start id, end id)') + print( + 'Exact match:', ', '.join(table_matched_pairs['exact']) + if table_matched_pairs['exact'] else 'empty') + print( + 'Partial match:', ', '.join(table_matched_pairs['partial']) + if table_matched_pairs['partial'] else 'empty') + print('Column matched: (column name, column id, \ + question span, start id, end id)') + print( + 'Exact match:', ', '.join(column_matched_pairs['exact']) + if column_matched_pairs['exact'] else 'empty') + print( + 'Partial match:', ', '.join(column_matched_pairs['partial']) + if column_matched_pairs['partial'] else 'empty') + print( + 'Value match:', ', '.join(column_matched_pairs['value']) + if column_matched_pairs['value'] else 'empty', '\n') + return entry diff --git a/modelscope/preprocessors/nlp/space_T_en/fields/parse.py b/modelscope/preprocessors/nlp/space_T_en/fields/parse.py new file mode 100644 index 00000000..02ae31a0 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/fields/parse.py @@ -0,0 +1,333 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +CLAUSE_KEYWORDS = ('SELECT', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'LIMIT', + 'INTERSECT', 'UNION', 'EXCEPT') +JOIN_KEYWORDS = ('JOIN', 'ON', 'AS') + +WHERE_OPS = ('NOT_IN', 'BETWEEN', '=', '>', '<', '>=', '<=', '!=', 'IN', + 'LIKE', 'IS', 'EXISTS') +UNIT_OPS = ('NONE', '-', '+', '*', '/') +AGG_OPS = ('', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG') +TABLE_TYPE = { + 'sql': 'sql', + 'table_unit': 'table_unit', +} +COND_OPS = ('AND', 'OR') +SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT') +ORDER_OPS = ('DESC', 'ASC') + + +def get_select_labels(select, slot, cur_nest): + for item in select[1]: + if AGG_OPS[item[0]] != '': + if slot[item[1][1][1]] == '': + slot[item[1][1][1]] += (cur_nest + ' ' + AGG_OPS[item[0]]) + else: + slot[item[1][1][1]] += (' ' + cur_nest + ' ' + + AGG_OPS[item[0]]) + else: + if slot[item[1][1][1]] == '': + slot[item[1][1][1]] += (cur_nest) + else: + slot[item[1][1][1]] += (' ' + cur_nest) + return slot + + +def get_groupby_labels(groupby, slot, cur_nest): + for item in groupby: + if slot[item[1]] == '': + slot[item[1]] += (cur_nest) + else: + slot[item[1]] += (' ' + cur_nest) + return slot + + +def get_orderby_labels(orderby, limit, slot, cur_nest): + if limit is None: + thelimit = '' + else: + thelimit = ' LIMIT' + for item in orderby[1]: + if AGG_OPS[item[1][0]] != '': + agg = ' ' + AGG_OPS[item[1][0]] + ' ' + else: + agg = ' ' + if slot[item[1][1]] == '': + slot[item[1][1]] += ( + cur_nest + agg + orderby[0].upper() + thelimit) + else: + slot[item[1][1]] += (' ' + cur_nest + agg + orderby[0].upper() + + thelimit) + + return slot + + +def get_intersect_labels(intersect, slot, cur_nest): + if isinstance(intersect, dict): + if cur_nest != '': + slot = get_labels(intersect, slot, cur_nest) + else: + slot = get_labels(intersect, slot, 'INTERSECT') + else: + return slot + return slot + + +def get_except_labels(texcept, slot, cur_nest): + if isinstance(texcept, dict): + if cur_nest != '': + slot = get_labels(texcept, slot, cur_nest) + else: + slot = get_labels(texcept, slot, 'EXCEPT') + else: + return slot + return slot + + +def get_union_labels(union, slot, cur_nest): + if isinstance(union, dict): + if cur_nest != '': + slot = get_labels(union, slot, cur_nest) + else: + slot = get_labels(union, slot, 'UNION') + else: + return slot + return slot + + +def get_from_labels(tfrom, slot, cur_nest): + if tfrom['table_units'][0][0] == 'sql': + slot = get_labels(tfrom['table_units'][0][1], slot, 'OP_SEL') + else: + return slot + return slot + + +def get_having_labels(having, slot, cur_nest): + if len(having) == 1: + item = having[0] + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if isinstance(item[3], dict): + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + else: + agg = '' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += (cur_nest + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + + WHERE_OPS[item[1]]) + else: + for index, item in enumerate(having): + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if (index + 1 < len(having) and having[index + 1]) == 'or' or ( + index - 1 >= 0 and having[index - 1] == 'or'): + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + else: + agg = '' + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + + ' ' + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + ' OR' + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + agg + + WHERE_OPS[item[1]]) + elif item == 'and' or item == 'or': + continue + else: + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + + WHERE_OPS[item[1]]) + return slot + + +def get_where_labels(where, slot, cur_nest): + if len(where) == 1: + item = where[0] + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + else: + for index, item in enumerate(where): + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if (index + 1 < len(where) and where[index + 1]) == 'or' or ( + index - 1 >= 0 and where[index - 1] == 'or'): + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + ' OR' + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + neg + + ' ' + WHERE_OPS[item[1]]) + elif item == 'and' or item == 'or': + continue + else: + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + return slot + + +def get_labels(sql_struct, slot, cur_nest): + + if len(sql_struct['select']) > 0: + if cur_nest != '': + slot = get_select_labels(sql_struct['select'], slot, + cur_nest + ' SELECT') + else: + slot = get_select_labels(sql_struct['select'], slot, 'SELECT') + + if sql_struct['from']: + if cur_nest != '': + slot = get_from_labels(sql_struct['from'], slot, 'FROM') + else: + slot = get_from_labels(sql_struct['from'], slot, 'FROM') + + if len(sql_struct['where']) > 0: + if cur_nest != '': + slot = get_where_labels(sql_struct['where'], slot, + cur_nest + ' WHERE') + else: + slot = get_where_labels(sql_struct['where'], slot, 'WHERE') + + if len(sql_struct['groupBy']) > 0: + if cur_nest != '': + slot = get_groupby_labels(sql_struct['groupBy'], slot, + cur_nest + ' GROUP_BY') + else: + slot = get_groupby_labels(sql_struct['groupBy'], slot, 'GROUP_BY') + + if len(sql_struct['having']) > 0: + if cur_nest != '': + slot = get_having_labels(sql_struct['having'], slot, + cur_nest + ' HAVING') + else: + slot = get_having_labels(sql_struct['having'], slot, 'HAVING') + + if len(sql_struct['orderBy']) > 0: + if cur_nest != '': + slot = get_orderby_labels(sql_struct['orderBy'], + sql_struct['limit'], slot, + cur_nest + ' ORDER_BY') + else: + slot = get_orderby_labels(sql_struct['orderBy'], + sql_struct['limit'], slot, 'ORDER_BY') + + if sql_struct['intersect']: + if cur_nest != '': + slot = get_intersect_labels(sql_struct['intersect'], slot, + cur_nest + ' INTERSECT') + else: + slot = get_intersect_labels(sql_struct['intersect'], slot, + 'INTERSECT') + + if sql_struct['except']: + if cur_nest != '': + slot = get_except_labels(sql_struct['except'], slot, + cur_nest + ' EXCEPT') + else: + slot = get_except_labels(sql_struct['except'], slot, 'EXCEPT') + + if sql_struct['union']: + if cur_nest != '': + slot = get_union_labels(sql_struct['union'], slot, + cur_nest + ' UNION') + else: + slot = get_union_labels(sql_struct['union'], slot, 'UNION') + return slot + + +def get_label(sql, column_len): + thelabel = [] + slot = {} + for idx in range(column_len): + slot[idx] = '' + for value in get_labels(sql, slot, '').values(): + thelabel.append(value) + return thelabel diff --git a/modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py b/modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py new file mode 100644 index 00000000..a0fd13d1 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/fields/preprocess_dataset.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas +from text2sql_lgesql.process_sql import get_sql + +from .parse import get_label + + +def preprocess_dataset(processor, dataset, output_tables, database_id, tables): + + schemas, db_names, thetables = get_schemas(tables) + intables = output_tables[database_id] + schema = schemas[database_id] + table = thetables[database_id] + sql_label = [] + if len(dataset['history']) == 0 or dataset['last_sql'] == '': + sql_label = [''] * len(intables['column_names']) + else: + schema = Schema(schema, table) + try: + sql_label = get_sql(schema, dataset['last_sql']) + except Exception: + sql_label = [''] * len(intables['column_names']) + sql_label = get_label(sql_label, len(table['column_names_original'])) + theone = {'db_id': database_id} + theone['query'] = '' + theone['query_toks_no_value'] = [] + theone['sql'] = {} + if len(dataset['history']) != 0: + theone['question'] = dataset['utterance'] + ' [CLS] ' + ' [CLS] '.join( + dataset['history'][::-1][:4]) + theone['question_toks'] = theone['question'].split() + else: + theone['question'] = dataset['utterance'] + theone['question_toks'] = dataset['utterance'].split() + + return [theone], sql_label diff --git a/modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py b/modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py new file mode 100644 index 00000000..88059351 --- /dev/null +++ b/modelscope/preprocessors/nlp/space_T_en/fields/process_dataset.py @@ -0,0 +1,59 @@ +# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql. + +import os +import pickle +import sys + +from text2sql_lgesql.asdl.asdl import ASDLGrammar +from text2sql_lgesql.asdl.transition_system import TransitionSystem + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + + +def process_example(processor, entry, db, trans, verbose=False): + # preprocess raw tokens, schema linking and subgraph extraction + entry = processor.pipeline(entry, db, verbose=verbose) + # generate target output actions + entry['ast'] = [] + entry['actions'] = [] + return entry + + +def process_tables(processor, tables_list, output_path=None, verbose=False): + tables = {} + for each in tables_list: + if verbose: + print('*************** Processing database %s **************' % + (each['db_id'])) + tables[each['db_id']] = processor.preprocess_database( + each, verbose=verbose) + print('In total, process %d databases .' % (len(tables))) + if output_path is not None: + pickle.dump(tables, open(output_path, 'wb')) + return tables + + +def process_dataset(model_dir, + processor, + dataset, + tables, + output_path=None, + skip_large=False, + verbose=False): + grammar = ASDLGrammar.from_filepath( + os.path.join(model_dir, 'sql_asdl_v2.txt')) + trans = TransitionSystem.get_class_by_lang('sql')(grammar) + processed_dataset = [] + for idx, entry in enumerate(dataset): + if skip_large and len(tables[entry['db_id']]['column_names']) > 100: + continue + if verbose: + print('*************** Processing %d-th sample **************' % + (idx)) + entry = process_example( + processor, entry, tables[entry['db_id']], trans, verbose=verbose) + processed_dataset.append(entry) + if output_path is not None: + # serialize preprocessed dataset + pickle.dump(processed_dataset, open(output_path, 'wb')) + return processed_dataset diff --git a/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py new file mode 100644 index 00000000..5693d36e --- /dev/null +++ b/modelscope/preprocessors/nlp/text2text_generation_preprocessor.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text2text_gen_preprocessor) +class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in text generation. + """ + + def __init__(self, + model_dir: str, + tokenizer=None, + mode=ModeKeys.INFERENCE, + **kwargs): + kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') + kwargs['padding'] = kwargs.get('padding', False) + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + False) + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: + text_a, _, _ = self.parse_text_and_label(data) + + inputs = self.tokenizer( + text_a, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + **self.tokenize_kwargs) + + # This is produced by tokenizers but is an invalid generate kwargs + if 'token_type_ids' in inputs: + del inputs['token_type_ids'] + return inputs diff --git a/modelscope/preprocessors/nlp/text_error_correction.py b/modelscope/preprocessors/nlp/text_error_correction.py new file mode 100644 index 00000000..4e5ba3bd --- /dev/null +++ b/modelscope/preprocessors/nlp/text_error_correction.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from .nlp_base import NLPBasePreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_error_correction) +class TextErrorCorrectionPreprocessor(NLPBasePreprocessor): + """The preprocessor used in text correction task. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + from fairseq.data import Dictionary + """preprocess the data via the vocab file from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(model_dir, *args, **kwargs) + self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt')) + + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + '随着中国经济突飞猛近,建造工业与日俱增' + Returns: + Dict[str, Any]: the preprocessed data + Example: + {'net_input': + {'src_tokens':tensor([1,2,3,4]), + 'src_lengths': tensor([4])} + } + """ + + text = ' '.join([x for x in data]) + inputs = self.vocab.encode_line( + text, append_eos=True, add_if_not_exist=False) + lengths = inputs.size() + sample = dict() + sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths} + return sample diff --git a/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py new file mode 100644 index 00000000..1e972d64 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_generation_jieba_preprocessor.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Any, Dict + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_gen_jieba_tokenizer) +class TextGenerationJiebaPreprocessor(Preprocessor): + """The jieba tokenizer preprocessor used in text generation. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + from modelscope.models.nlp.gpt3 import JiebaBPETokenizer + super().__init__(*args, **kwargs) + self.tokenizer = JiebaBPETokenizer( + osp.join(model_dir, 'tokenizer.json')) + + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + '深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地' + Returns: + Dict[str, Any]: the preprocessed data + Example: + {'net_input': + {'src_tokens':tensor([1,2,3,4]), + 'src_lengths': tensor([4])} + } + """ + import torch + + return { + 'input_ids': + torch.tensor(self.tokenizer.tokenize(data)).unsqueeze_(0) + } diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py new file mode 100644 index 00000000..238e2972 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Optional, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_gen_tokenizer) +class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in text generation. + """ + + def __init__(self, + model_dir: str, + tokenizer=None, + mode=ModeKeys.INFERENCE, + **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get('padding', 'max_length') + kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', + False) + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, mode=mode, **kwargs) + + @staticmethod + def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]: + import os + for name in os.listdir(model_dir): + full_name = os.path.join(model_dir, name) + if 'roberta' in name and os.path.isdir(full_name): + return full_name + + def build_tokenizer(self, model_dir: str): + roberta_tokenizer_dir = self.get_roberta_tokenizer_dir(model_dir) + if roberta_tokenizer_dir: + from transformers import RobertaTokenizer + return RobertaTokenizer.from_pretrained( + roberta_tokenizer_dir, do_lower_case=False) + return super().build_tokenizer(model_dir) + + def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: + if self._mode == ModeKeys.INFERENCE: + return super().__call__(data) + src_rst = super().__call__(data['src_txt']) + src_input_ids = src_rst['input_ids'] + src_attention_mask = src_rst['attention_mask'] + if 'tgt_txt' in data: + labels = super().__call__(data['tgt_txt'])['input_ids'] + else: + labels = src_input_ids[1:] + src_input_ids = src_input_ids[:-1] + src_attention_mask = src_attention_mask[:-1] + + return { + 'input_ids': src_input_ids, + 'attention_mask': src_attention_mask, + 'labels': labels, + } diff --git a/modelscope/preprocessors/nlp/text_ranking_preprocessor.py b/modelscope/preprocessors/nlp/text_ranking_preprocessor.py new file mode 100644 index 00000000..2ada6892 --- /dev/null +++ b/modelscope/preprocessors/nlp/text_ranking_preprocessor.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from transformers import AutoTokenizer + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.text_ranking) +class TextRankingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in passage ranking model. + """ + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(model_dir, mode=mode, *args, **kwargs) + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'source_sentence') + self.second_sequence = kwargs.pop('second_sequence', + 'sentences_to_compare') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, (str, tuple, Dict)) + def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]: + if isinstance(data, tuple): + sentence1, sentence2 = data + elif isinstance(data, dict): + sentence1 = data.get(self.first_sequence) + sentence2 = data.get(self.second_sequence) + if isinstance(sentence2, str): + sentence2 = [sentence2] + if isinstance(sentence1, str): + sentence1 = [sentence1] + sentence1 = sentence1 * len(sentence2) + + max_seq_length = self.sequence_length + feature = self.tokenizer( + sentence1, + sentence2, + padding='max_length', + truncation=True, + max_length=max_seq_length, + return_tensors='pt') + if 'labels' in data: + labels = data['labels'] + feature['labels'] = labels + if 'qid' in data: + qid = data['qid'] + feature['qid'] = qid + return feature diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py new file mode 100644 index 00000000..a7616736 --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -0,0 +1,280 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .nlp_base import NLPBasePreprocessor, NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, + module_name=Preprocessors.word_segment_text_to_label_preprocessor) +class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor): + """The preprocessor used to turn a single sentence to a labeled token-classification dict. + """ + + def __init__(self, **kwargs): + self.first_sequence: str = kwargs.pop('first_sequence', 'tokens') + self.label = kwargs.pop('label', OutputKeys.LABELS) + + def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]: + data = data.split(' ') + data = list(filter(lambda x: len(x) > 0, data)) + + def produce_train_sample(words): + chars = [] + labels = [] + for word in words: + chars.extend(list(word)) + if len(word) == 1: + labels.append('S-CWS') + else: + labels.extend(['B-CWS'] + ['I-CWS'] * (len(word) - 2) + + ['E-CWS']) + assert len(chars) == len(labels) + return chars, labels + + chars, labels = produce_train_sample(data) + return { + self.first_sequence: chars, + self.label: labels, + } + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.ner_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) +class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in normal NER task. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get( + 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + self.sequence_length = kwargs['max_length'] + self.label_all_tokens = kwargs.pop('label_all_tokens', False) + super().__init__(model_dir, mode=mode, **kwargs) + + if 'is_split_into_words' in kwargs: + self.tokenize_kwargs['is_split_into_words'] = kwargs.pop( + 'is_split_into_words') + else: + self.tokenize_kwargs[ + 'is_split_into_words'] = self.tokenizer.init_kwargs.get( + 'is_split_into_words', False) + if 'label2id' in kwargs: + kwargs.pop('label2id') + + @type_assert(object, (str, dict)) + def __call__(self, data: Union[dict, str]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + # preprocess the data for the model input + text = None + labels_list = None + if isinstance(data, str): + # for inference inputs without label + text = data + elif isinstance(data, dict): + # for finetune inputs with label + text = data.get(self.first_sequence) + labels_list = data.get(self.label) + if isinstance(text, list): + self.tokenize_kwargs['is_split_into_words'] = True + + if self._mode == ModeKeys.INFERENCE: + self.tokenize_kwargs['add_special_tokens'] = False + + input_ids = [] + label_mask = [] + offset_mapping = [] + token_type_ids = [] + if self.tokenize_kwargs[ + 'is_split_into_words'] and self._mode == ModeKeys.INFERENCE: + for offset, token in enumerate(list(text)): + subtoken_ids = self.tokenizer.encode(token, + **self.tokenize_kwargs) + if len(subtoken_ids) == 0: + subtoken_ids = [self.tokenizer.unk_token_id] + input_ids.extend(subtoken_ids) + label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) + offset_mapping.extend([(offset, offset + 1)]) + else: + if self.tokenizer.is_fast: + encodings = self.tokenizer( + text, return_offsets_mapping=True, **self.tokenize_kwargs) + attention_mask = encodings['attention_mask'] + if 'token_type_ids' in encodings: + token_type_ids = encodings['token_type_ids'] + input_ids = encodings['input_ids'] + word_ids = encodings.word_ids() + for i in range(len(word_ids)): + if word_ids[i] is None: + label_mask.append(0) + elif word_ids[i] == word_ids[i - 1]: + label_mask.append(0) + offset_mapping[-1] = ( + offset_mapping[-1][0], + encodings['offset_mapping'][i][1]) + else: + label_mask.append(1) + offset_mapping.append(encodings['offset_mapping'][i]) + else: + encodings = self.tokenizer(text, **self.tokenize_kwargs) + input_ids = encodings['input_ids'] + label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( + text) + + if self._mode == ModeKeys.INFERENCE: + if len(input_ids) >= self.sequence_length - 2: + input_ids = input_ids[:self.sequence_length - 2] + label_mask = label_mask[:self.sequence_length - 2] + input_ids = [self.tokenizer.cls_token_id + ] + input_ids + [self.tokenizer.sep_token_id] + label_mask = [0] + label_mask + [0] + attention_mask = [1] * len(input_ids) + offset_mapping = offset_mapping[:sum(label_mask)] + + if not self.is_transformer_based_model: + input_ids = input_ids[1:-1] + attention_mask = attention_mask[1:-1] + label_mask = label_mask[1:-1] + + input_ids = torch.tensor(input_ids).unsqueeze(0) + attention_mask = torch.tensor(attention_mask).unsqueeze(0) + label_mask = torch.tensor( + label_mask, dtype=torch.bool).unsqueeze(0) + + # the token classification + output = { + 'text': text, + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + 'offset_mapping': offset_mapping + } + else: + output = { + 'input_ids': input_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + } + + # align the labels with tokenized text + if labels_list is not None: + assert self.label2id is not None + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + label_enumerate_values = [ + k for k, v in sorted( + self.label2id.items(), key=lambda item: item[1]) + ] + for idx, label in enumerate(label_enumerate_values): + if label.startswith('B-') and label.replace( + 'B-', 'I-') in label_enumerate_values: + b_to_i_label.append( + label_enumerate_values.index( + label.replace('B-', 'I-'))) + else: + b_to_i_label.append(idx) + + label_row = [self.label2id[lb] for lb in labels_list] + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + if word_idx is None: + label_ids.append(-100) + elif word_idx != previous_word_idx: + label_ids.append(label_row[word_idx]) + else: + if self.label_all_tokens: + label_ids.append(b_to_i_label[label_row[word_idx]]) + else: + label_ids.append(-100) + previous_word_idx = word_idx + labels = label_ids + output['labels'] = labels + output = { + k: np.array(v) if isinstance(v, list) else v + for k, v in output.items() + } + return output + + def get_tokenizer_class(self): + tokenizer_class = self.tokenizer.__class__.__name__ + if tokenizer_class.endswith( + 'Fast') and tokenizer_class != 'PreTrainedTokenizerFast': + tokenizer_class = tokenizer_class[:-4] + return tokenizer_class + + def get_label_mask_and_offset_mapping(self, text): + label_mask = [] + offset_mapping = [] + tokens = self.tokenizer.tokenize(text) + offset = 0 + if self.get_tokenizer_class() == 'BertTokenizer': + for token in tokens: + is_start = (token[:2] != '##') + if is_start: + label_mask.append(True) + else: + token = token[2:] + label_mask.append(False) + start = offset + text[offset:].index(token) + end = start + len(token) + if is_start: + offset_mapping.append((start, end)) + else: + offset_mapping[-1] = (offset_mapping[-1][0], end) + offset = end + elif self.get_tokenizer_class() == 'XLMRobertaTokenizer': + last_is_blank = False + for token in tokens: + is_start = (token[0] == '▁') + if is_start: + token = token[1:] + label_mask.append(True) + if len(token) == 0: + last_is_blank = True + continue + else: + label_mask.append(False) + start = offset + text[offset:].index(token) + end = start + len(token) + if last_is_blank or is_start: + offset_mapping.append((start, end)) + else: + offset_mapping[-1] = (offset_mapping[-1][0], end) + offset = end + last_is_blank = False + else: + raise NotImplementedError + + return label_mask, offset_mapping diff --git a/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py new file mode 100644 index 00000000..a356cea7 --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_ner_tokenizer) +class NERPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pythainlp import word_tokenize + + segmented_data = ' '.join([ + w.strip(' ') for w in word_tokenize(text=data, engine='newmm') + if w.strip(' ') != '' + ]) + output = super().__call__(segmented_data) + + return output + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_wseg_tokenizer) +class WordSegmentationPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + import regex + data = regex.findall(r'\X', data) + data = ' '.join([char for char in data]) + + output = super().__call__(data) + + return output diff --git a/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py new file mode 100644 index 00000000..f8970d1a --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.viet_ner_tokenizer) +class NERPreprocessorViet(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pyvi import ViTokenizer + + seg_words = [ + t.strip(' ') for t in ViTokenizer.tokenize(data).split(' ') + if t.strip(' ') != '' + ] + raw_words = [] + for w in seg_words: + raw_words.extend(w.split('_')) + segmented_data = ' '.join(raw_words) + output = super().__call__(segmented_data) + + return output diff --git a/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py b/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py new file mode 100644 index 00000000..eb3c4b37 --- /dev/null +++ b/modelscope/preprocessors/nlp/zero_shot_classification_reprocessor.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Union + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from .nlp_base import NLPTokenizerPreprocessorBase + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) +class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in zero shot classification. + """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + self.sequence_length = kwargs.pop('sequence_length', 512) + super().__init__(model_dir, mode=mode, **kwargs) + + def __call__(self, data: Union[str, Dict], hypothesis_template: str, + candidate_labels: list) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str or dict): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + if isinstance(data, dict): + data = data.get(self.first_sequence) + + pairs = [[data, hypothesis_template.format(label)] + for label in candidate_labels] + + features = self.tokenizer( + pairs, + padding=True, + truncation=True, + max_length=self.sequence_length, + truncation_strategy='only_first', + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None) + return features diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py new file mode 100644 index 00000000..59b94b2b --- /dev/null +++ b/modelscope/preprocessors/ofa/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .image_captioning import OfaImageCaptioningPreprocessor +from .image_classification import OfaImageClassificationPreprocessor +from .ocr_recognition import OfaOcrRecognitionPreprocessor +from .summarization import OfaSummarizationPreprocessor +from .text_classification import OfaTextClassificationPreprocessor +from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor +from .visual_entailment import OfaVisualEntailmentPreprocessor +from .visual_grounding import OfaVisualGroundingPreprocessor +from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py new file mode 100644 index 00000000..55b3895d --- /dev/null +++ b/modelscope/preprocessors/ofa/base.py @@ -0,0 +1,165 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +import string +from os import path as osp + +import json +import numpy as np +import torch +from PIL import Image + +from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH +from modelscope.preprocessors.image import load_image +from modelscope.utils.trie import Trie +from .utils.constant import OFA_TASK_KEY_MAPPING +from .utils.random_help import set_torch_seed + + +class OfaBasePreprocessor: + + def __init__(self, cfg, model_dir, mode, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path + """ + self.cfg = cfg + self.mode = mode + self.language = self.cfg.model.get('language', 'en') + if self.language == 'en': + tokenizer = OFATokenizer.from_pretrained(model_dir) + elif self.language in ['zh', 'cn']: + tokenizer = OFATokenizerZH.from_pretrained(model_dir) + else: + raise NotImplementedError + # there is some diff between here and our ofa code, + # there will be no need to use param: use_bpe + tokenizer.add_tokens([''.format(i) for i in range(8192)]) + tokenizer.add_tokens([''.format(i) for i in range(1000)]) + self.tokenizer = tokenizer + self.bos_item = torch.LongTensor([tokenizer.bos_token_id]) + self.pad_item = torch.LongTensor([tokenizer.pad_token_id]) + self.eos_item = torch.LongTensor([tokenizer.eos_token_id]) + self.tgt_dict = self.src_dict = { + value: key + for key, value in tokenizer.get_vocab().items() + } + self.max_src_length = cfg.model.get('max_src_length', 256) + self.max_tgt_length = cfg.model.get('max_tgt_length', 256) + self.max_image_size = cfg.model.get('max_image_size', 512) + self.language = self.cfg.model.get('language', 'en') + self.prompt_type = self.cfg.model.get('prompt_type', 'none') + seed = self.cfg.model.get('seed', 7) + np.random.seed(seed) + set_torch_seed(seed) + imagenet_default_mean_and_std = self.cfg.model.get( + 'imagenet_default_mean_and_std', False) + if imagenet_default_mean_and_std: + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + else: + self.mean = [0.5, 0.5, 0.5] + self.std = [0.5, 0.5, 0.5] + self.patch_image_size = self.cfg.model.get('patch_image_size', 480) + self.column_map = { + key: key + for key in OFA_TASK_KEY_MAPPING[self.cfg.task] + } + if hasattr(self.cfg, + 'dataset') and self.cfg.dataset.column_map is not None: + for k, v in self.cfg.dataset.column_map.items(): + self.column_map[k] = v + self.transtab = str.maketrans( + {key: None + for key in string.punctuation}) + self.constraint_trie = None + if self.cfg.model.get('answer2label', None): + ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) + with open(ans2label_file, 'r') as reader: + ans2label_dict = json.load(reader) + self.ans2label = ans2label_dict + self.label2ans = {v: k for k, v in self.ans2label.items()} + self.constraint_trie = Trie(tokenizer.eos_token_id) + for i, answer in enumerate(ans2label_dict.keys()): + answer_item = self.tokenize_text( + ' ' + answer, add_bos=False, add_eos=False) + self.constraint_trie.insert([tokenizer.bos_token_id] + + answer_item.tolist() + + [tokenizer.eos_token_id]) + + def tokenize_text(self, text, add_bos=True, add_eos=True): + if text is None: + return None + inputs = self.tokenizer( + text, + max_length=self.max_src_length, + add_special_tokens=False, + truncation=True, + return_tensors='pt')['input_ids'].squeeze(0) + if add_bos: + inputs = torch.cat([self.bos_item, inputs]) + if add_eos: + inputs = torch.cat([inputs, self.eos_item]) + return inputs + + @staticmethod + def pre_caption(caption, max_words=None): + caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \ + .replace('/', ' ').replace('', 'person') + + caption = re.sub( + r'\s{2,}', + ' ', + caption, + ) + caption = caption.rstrip('\n') + caption = caption.strip(' ') + + # truncate caption + caption_words = caption.split(' ') + if max_words is not None and len(caption_words) > max_words: + caption = ' '.join(caption_words[:max_words]) + + return caption + + @staticmethod + def pre_question(question, max_ques_words): + question = question.lower().lstrip(',.!?*#:;~').replace('-', + ' ').replace( + '/', ' ') + + question = re.sub( + r'\s{2,}', + ' ', + question, + ) + question = question.rstrip('\n') + question = question.strip(' ') + + # truncate question + question_words = question.split(' ') + if len(question_words) > max_ques_words: + question = ' '.join(question_words[:max_ques_words]) + + return question + + def add_constraint_mask(self, sample): + target_itm = sample['target'] + len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item() + if self.constraint_trie: + constraint_mask = torch.zeros( + (len(target_itm), len(self.tgt_dict))).bool() + start_idx = len(target_itm) - len_label_itm + for i in range(start_idx, len(target_itm)): + constraint_prefix_token = self.bos_item.tolist( + ) + target_itm[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + def get_img_pil(self, path_or_url_or_pil): + image = path_or_url_or_pil if isinstance(path_or_url_or_pil, Image.Image) \ + else load_image(path_or_url_or_pil) + return image diff --git a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py new file mode 100644 index 00000000..5fb83908 --- /dev/null +++ b/modelscope/preprocessors/ofa/image_captioning.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from torchvision import transforms + +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaImageCaptioningPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = sample['label'] + target = target.translate(self.transtab).strip() + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', ' what does the image describe?') + inputs = self.tokenize_text(prompt) + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] + return sample diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py new file mode 100644 index 00000000..038a9e15 --- /dev/null +++ b/modelscope/preprocessors/ofa/image_classification.py @@ -0,0 +1,119 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import functools +from typing import Any, Dict + +import torch +from PIL import Image, ImageFile +from timm.data import create_transform +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor +from .utils.vision_helper import RandomAugment + +ImageFile.LOAD_TRUNCATED_IMAGES = True +ImageFile.MAX_IMAGE_PIXELS = None +Image.MAX_IMAGE_PIXELS = None + + +class OfaImageClassificationPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaImageClassificationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + # Initialize transform + if self.mode != ModeKeys.TRAIN: + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + else: + self.patch_resize_transform = create_transform( + input_size=self.patch_image_size, + is_training=True, + color_jitter=0.4, + auto_augment='rand-m9-mstd0.5-inc1', + interpolation='bicubic', + re_prob=0.25, + re_mode='pixel', + re_count=1, + mean=self.mean, + std=self.std) + self.patch_resize_transform = transforms.Compose( + functools.reduce(lambda x, y: x + y, [ + [ + lambda image: image.convert('RGB'), + ], + self.patch_resize_transform.transforms[:2], + [self.patch_resize_transform.transforms[2]], + [ + RandomAugment( + 2, + 7, + isPIL=True, + augs=[ + 'Identity', 'AutoContrast', 'Equalize', + 'Brightness', 'Sharpness', 'ShearX', 'ShearY', + 'TranslateX', 'TranslateY', 'Rotate' + ]), + ], + self.patch_resize_transform.transforms[3:], + ])) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = ' {}'.format(sample['label']) + sample['ref_dict'] = {sample['label']: 1.0} + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + + if self.constraint_trie is not None: + constraint_mask = torch.zeros((len(sample['prev_output_tokens']), + len(self.tgt_dict))).bool() + for i in range(len(sample['prev_output_tokens'])): + constraint_prefix_token = sample[ + 'prev_output_tokens'][:i + 1].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', ' what does the image describe?') + inputs = self.tokenize_text(prompt) + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] + return sample diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py new file mode 100644 index 00000000..e15be93f --- /dev/null +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -0,0 +1,114 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +import unicodedata2 +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as F +from zhconv import convert + +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +def ocr_resize(img, patch_image_size, is_document=False): + img = img.convert('RGB') + width, height = img.size + + if is_document: + new_height, new_width = 64, 1920 + else: + if width >= height: + new_width = max(64, patch_image_size) + new_height = max(64, int(patch_image_size * (height / width))) + top = (patch_image_size - new_height) // 2 + bottom = patch_image_size - new_height - top + left, right = 0, 0 + else: + new_height = max(64, patch_image_size) + new_width = max(64, int(patch_image_size * (width / height))) + left = (patch_image_size - new_width) // 2 + right = patch_image_size - new_width - left + top, bottom = 0, 0 + + img_new = F.resize( + img, + (new_height, new_width), + interpolation=InterpolationMode.BICUBIC, + ) + + if is_document: + img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1) + img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2)) + new_width, new_height = img_new.size + top = (patch_image_size - new_height) // 2 + bottom = patch_image_size - new_height - top + left, right = 0, 0 + + img_new = F.pad( + img_new, padding=[left, top, right, bottom], padding_mode='edge') + assert img_new.size == (patch_image_size, patch_image_size) + + return img_new + + +class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaOcrRecognitionPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + + self.patch_resize_transform = transforms.Compose([ + lambda image: ocr_resize( + image, + self.patch_image_size, + is_document=self.cfg.model.get('is_document', False)), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = sample['label'] + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + patch_image = self.patch_resize_transform(image) + prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') + inputs = self.tokenize_text(prompt) + + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]) + } + if 'text' in self.column_map and self.column_map['text'] in data: + target = data[self.column_map['text']] + sample['label'] = unicodedata2.normalize( + 'NFKC', convert(target, 'zh-hans')) + return sample diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py new file mode 100644 index 00000000..d33e9d25 --- /dev/null +++ b/modelscope/preprocessors/ofa/summarization.py @@ -0,0 +1,77 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch + +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaSummarizationPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaSummarizationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target_str = sample['label'].lower() + target = super().pre_caption(target_str, max_words=self.max_tgt_length) + target = target.replace('[unk]', 'unk').replace('', 'unk') + sample['target'] = self.tokenize_text(target, add_bos=False) + noise_target_item = self.add_noise_to_tgt( + sample['target'][:-1].clone()) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, noise_target_item]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + source = super().pre_caption( + data[self.column_map['text']], max_words=self.max_src_length) + source = source.replace('[unk]', 'unk').replace('', 'unk') + prompt = self.cfg.model.get( + 'prompt', ' " {} " Summarize the article with a title: ') + text = prompt.format(source) + inputs = self.tokenize_text(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'prev_output': + decoder_prompt = inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'decoder_prompt': decoder_prompt, + } + if 'summary' in self.column_map and self.column_map['summary'] in data: + sample['label'] = data[self.column_map['summary']] + return sample + + def add_noise_to_tgt(self, target): + noise_indices = torch.FloatTensor( + target.size(0)).uniform_() < self.cfg.model.get( + 'noise_ratio', 0.0) + target[noise_indices] = torch.randint( + 4, + len(self.src_dict) - self.cfg.model.get('num_codes', 8192) + - self.cfg.model.get('num_bins', 1000), + size=(noise_indices.sum(), )) + return target diff --git a/modelscope/preprocessors/ofa/text_classification.py b/modelscope/preprocessors/ofa/text_classification.py new file mode 100644 index 00000000..24c4f67e --- /dev/null +++ b/modelscope/preprocessors/ofa/text_classification.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch + +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaTextClassificationPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaTextClassificationPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_instruction(self, data): + text1 = ' '.join( + data['text'].lower().strip().split()[:self.max_src_length]) + text2 = ' '.join( + data['text2'].lower().strip().split()[:self.max_src_length]) + prompt = ' can text1 " {} " imply text2 " {} "?' + text = prompt.format(text1, text2) + instruction_itm = self.tokenize_text(text) + return instruction_itm + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) + assert 'label' in data, 'there must has `label` column in train phase ' + label = data['label'] + if self.label2ans: + label = self.label2ans[label] # ans + label_itm = self.tokenize_text(f' {label}', add_bos=False) + if self.prompt_type == 'none': + target_itm = label_itm + elif self.prompt_type == 'prev_output': + target_itm = torch.cat([instruction_itm[1:-1], label_itm]) + else: + raise NotImplementedError + prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]]) + target_itm[:-len(label_itm)] = self.pad_item + sample = { + 'source': instruction_itm, + 'target': target_itm, + 'prev_output_tokens': prev_output_itm, + } + self.add_constraint_mask(sample) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + instruction_itm = self._build_instruction(data) + if self.prompt_type == 'none': + prefix_token = [] + elif self.prompt_type == 'prev_output': + prefix_token = instruction_itm[:-1] # remove eos + else: + raise NotImplementedError + sample = { + 'source': instruction_itm, + 'prefix_token': prefix_token, + } + if 'label' in data: + sample['label'] = self.label2ans[data['label']] + return sample diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py new file mode 100644 index 00000000..2f6000eb --- /dev/null +++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch + +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaTextToImageSynthesisPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + self.max_src_length = 64 + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + source = ' '.join( + data['text'].lower().strip().split()[:self.max_src_length]) + source = 'what is the complete image? caption: {}'.format(source) + inputs = self.tokenize_text(source) + sample = { + 'source': inputs, + 'patch_images': None, + 'patch_masks': torch.tensor([False]), + 'code_masks': torch.tensor([False]) + } + return sample diff --git a/modelscope/preprocessors/ofa/utils/__init__.py b/modelscope/preprocessors/ofa/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py new file mode 100644 index 00000000..f7775680 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/collate.py @@ -0,0 +1,115 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch + + +def collate_fn(samples, pad_idx, eos_idx): + if len(samples) == 0: + return {} + + def merge(key): + return collate_tokens([s[key] for s in samples], + pad_idx, + eos_idx=eos_idx) + + src_tokens = merge('source') + + batch = { + 'nsentences': len(samples), + 'net_input': { + 'input_ids': src_tokens, + }, + } + if samples[0].get('id', None) is not None: + batch['id'] = np.array([s.get['id'] for s in samples]) + if samples[0].get('target', None) is not None: + batch['target'] = merge('target') + tgt_lengths = torch.LongTensor( + [s['target'].ne(pad_idx).long().sum() for s in samples]) + ntokens = tgt_lengths.sum().item() + batch['ntokens'] = ntokens + if samples[0].get('prev_output_tokens', None) is not None: + batch['net_input']['decoder_input_ids'] = merge('prev_output_tokens') + if samples[0].get('patch_image', None) is not None: + batch['net_input']['patch_images'] = torch.stack( + [sample['patch_image'] for sample in samples], dim=0) + if samples[0].get('patch_mask', None) is not None: + batch['net_input']['patch_masks'] = torch.cat( + [sample['patch_mask'] for sample in samples]) + # image generation + if samples[0].get('code_mask', None) is not None: + batch['net_input']['code_masks'] = torch.cat( + [sample['code_mask'] for sample in samples]) + if samples[0].get('code_image', None) is not None: + batch['code_images'] = torch.cat( + [sample['code_image'] for sample in samples]) + # For classification tasks (i.e., VQA, SNLI-VE, GLUE) + if samples[0].get('conf', None) is not None: + batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0) + if samples[0].get('ref_dict', None) is not None: + batch['ref_dict'] = np.array([s['ref_dict'] for s in samples]) + if samples[0].get('label', None) is not None: + batch['labels'] = np.array([s['label'] for s in samples]).tolist() + if samples[0].get('constraint_mask', None) is not None: + batch['constraint_masks'] = merge('constraint_mask') + if samples[0].get('decoder_prompt', None) is not None: + batch['decoder_prompts'] = np.array( + [s['decoder_prompt'].tolist() for s in samples]) + if samples[0].get('prefix_token', None) is not None: + batch['prefix_tokens'] = merge('prefix_token') + # For detection and visual grounding + if samples[0].get('w_resize_ratio', None) is not None: + batch['w_resize_ratios'] = torch.stack( + [s['w_resize_ratio'] for s in samples], dim=0) + if samples[0].get('h_resize_ratio', None) is not None: + batch['h_resize_ratios'] = torch.stack( + [s['h_resize_ratio'] for s in samples], dim=0) + if samples[0].get('region_coord', None) is not None: + batch['region_coords'] = torch.stack( + [s['region_coord'] for s in samples], dim=0) + if samples[0].get('sample', None) is not None: + batch['samples'] = [s['sample'] for s in samples] + return batch + + +def collate_tokens( + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, + pad_to_bsz=None, +): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + if values[0].dim() == 1: + res = values[0].new(len(values), size).fill_(pad_idx) + elif values[0].dim() == 2: + assert move_eos_to_beginning is False + res = values[0].new(len(values), size, + values[0].size(1)).fill_(pad_idx) + else: + raise NotImplementedError + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res diff --git a/modelscope/preprocessors/ofa/utils/constant.py b/modelscope/preprocessors/ofa/utils/constant.py new file mode 100644 index 00000000..102d27c0 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/constant.py @@ -0,0 +1,13 @@ +from modelscope.utils.constant import Tasks + +OFA_TASK_KEY_MAPPING = { + Tasks.ocr_recognition: ['image'], + Tasks.image_captioning: ['image'], + Tasks.image_classification: ['image'], + Tasks.text_summarization: ['text'], + Tasks.text_classification: ['text', 'text2'], + Tasks.visual_grounding: ['image', 'text'], + Tasks.visual_question_answering: ['image', 'text'], + Tasks.visual_entailment: ['image', 'text', 'text2'], + Tasks.text_to_image_synthesis: ['text'] +} diff --git a/modelscope/preprocessors/ofa/utils/random_help.py b/modelscope/preprocessors/ofa/utils/random_help.py new file mode 100644 index 00000000..e0dca54e --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/random_help.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def get_rng_state(): + state = {'torch_rng_state': torch.get_rng_state()} + if xm is not None: + state['xla_rng_state'] = xm.get_rng_state() + if torch.cuda.is_available(): + state['cuda_rng_state'] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state): + torch.set_rng_state(state['torch_rng_state']) + if xm is not None: + xm.set_rng_state(state['xla_rng_state']) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state['cuda_rng_state']) + + +class set_torch_seed(object): + + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) diff --git a/modelscope/preprocessors/ofa/utils/transforms.py b/modelscope/preprocessors/ofa/utils/transforms.py new file mode 100644 index 00000000..3fd312c6 --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/transforms.py @@ -0,0 +1,557 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. + +import random + +import numpy as np +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +from PIL import Image + + +def crop(image, target, region, delete=True): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target['size'] = torch.tensor([h, w]) + + fields = ['labels', 'area'] + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'polygons' in target: + polygons = target['polygons'] + num_polygons = polygons.shape[0] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + start_coord = torch.cat([ + torch.tensor([j, i], dtype=torch.float32) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable# + cropped_boxes = polygons - start_coord + cropped_boxes = torch.min( + cropped_boxes.reshape(num_polygons, -1, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + target['polygons'] = cropped_boxes.reshape(num_polygons, -1) + fields.append('polygons') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if delete and ('boxes' in target or 'masks' in target): + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + w, h = image.size + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target['boxes'] = boxes + + if 'polygons' in target: + polygons = target['polygons'] + num_polygons = polygons.shape[0] + polygons = polygons.reshape(num_polygons, -1, 2) * torch.as_tensor( + [-1, 1]) + torch.as_tensor([w, 0]) + target['polygons'] = polygons + + if 'masks' in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + + if (w <= h and w == size) or (h <= w and h == size): + if max_size is not None: + max_size = int(max_size) + h = min(h, max_size) + w = min(w, max_size) + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + if max_size is not None: + max_size = int(max_size) + oh = min(oh, max_size) + ow = min(ow, max_size) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size, interpolation=Image.BICUBIC) + + if target is None: + return rescaled_image + + ratios = tuple( + float(s) / float(s_orig) + for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'polygons' in target: + polygons = target['polygons'] + scaled_ratio = torch.cat([ + torch.tensor([ratio_width, ratio_height]) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable + scaled_polygons = polygons * scaled_ratio + target['polygons'] = scaled_polygons + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + h, w = size + target['size'] = torch.tensor([h, w]) + + if 'masks' in target: + assert False + + return rescaled_image, target + + +class CenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, + (crop_top, crop_left, crop_height, crop_width)) + + +class ObjectCenterCrop(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + + x0 = float(target['boxes'][0][0]) + y0 = float(target['boxes'][0][1]) + x1 = float(target['boxes'][0][2]) + y1 = float(target['boxes'][0][3]) + + center_x = (x0 + x1) / 2 + center_y = (y0 + y1) / 2 + crop_left = max( + center_x - crop_width / 2 + + min(image_width - center_x - crop_width / 2, 0), 0) + crop_top = max( + center_y - crop_height / 2 + + min(image_height - center_y - crop_height / 2, 0), 0) + + return crop( + img, + target, (crop_top, crop_left, crop_height, crop_width), + delete=False) + + +class RandomHorizontalFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize(object): + + def __init__(self, sizes, max_size=None, equal=False): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + self.equal = equal + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + if self.equal: + return resize(img, target, size, size) + else: + return resize(img, target, size, self.max_size) + + +class ToTensor(object): + + def __call__(self, img, target): + return F.to_tensor(img), target + + +class Normalize(object): + + def __init__(self, mean, std, max_image_size=512): + self.mean = mean + self.std = std + self.max_image_size = max_image_size + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + # h, w = image.shape[-2:] + h, w = target['size'][0], target['size'][1] + if 'boxes' in target: + boxes = target['boxes'] + boxes = boxes / self.max_image_size + target['boxes'] = boxes + if 'polygons' in target: + polygons = target['polygons'] + scale = torch.cat([ + torch.tensor([w, h], dtype=torch.float32) + for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable + polygons = polygons / scale + target['polygons'] = polygons + return image, target + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class LargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + """ + + def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0): + self.desired_size = torch.tensor([output_size]) + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def rescale_target(self, scaled_size, image_size, target): + # compute rescaled targets + image_scale = scaled_size / image_size + ratio_height, ratio_width = image_scale + + target = target.copy() + target['size'] = scaled_size + + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + if 'masks' in target: + assert False + masks = target['masks'] + # masks = interpolate( + # masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5 + target['masks'] = masks + return target + + def crop_target(self, region, target): + i, j, h, w = region + fields = ['labels', 'area'] + + target = target.copy() + target['size'] = torch.tensor([h, w]) + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min( + cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] + - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + return target + + def pad_target(self, padding, target): + target = target.copy() + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[1], 0, padding[0])) + return target + + def __call__(self, image, target=None): + image_size = image.size + image_size = torch.tensor(image_size[::-1]) + + random_scale = torch.rand(1) * ( + self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.maximum(scaled_size / image_size[0], + scaled_size / image_size[1]) + scaled_size = (image_size * scale).round().int() + + scaled_image = F.resize( + image, scaled_size.tolist(), interpolation=Image.BICUBIC) + + if target is not None: + target = self.rescale_target(scaled_size, image_size, target) + + # randomly crop or pad images + if random_scale >= 1: + # Selects non-zero random offset (x, y) if scaled image is larger than desired_size. + max_offset = scaled_size - self.desired_size + offset = (max_offset * torch.rand(2)).floor().int() + region = (offset[0].item(), offset[1].item(), + self.desired_size[0].item(), self.desired_size[0].item()) + output_image = F.crop(scaled_image, *region) + if target is not None: + target = self.crop_target(region, target) + else: + assert False + padding = self.desired_size - scaled_size + output_image = F.pad(scaled_image, + [0, 0, padding[1].item(), padding[0].item()]) + if target is not None: + target = self.pad_target(padding, target) + + return output_image, target + + +class OriginLargeScaleJitter(object): + """ + implementation of large scale jitter from copy_paste + """ + + def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0): + self.desired_size = torch.tensor(output_size) + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def rescale_target(self, scaled_size, image_size, target): + # compute rescaled targets + image_scale = scaled_size / image_size + ratio_height, ratio_width = image_scale + + target = target.copy() + target['size'] = scaled_size + + if 'boxes' in target: + boxes = target['boxes'] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height]) + target['boxes'] = scaled_boxes + + if 'area' in target: + area = target['area'] + scaled_area = area * (ratio_width * ratio_height) + target['area'] = scaled_area + + if 'masks' in target: + assert False + masks = target['masks'] + # masks = interpolate( + # masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5 + target['masks'] = masks + return target + + def crop_target(self, region, target): + i, j, h, w = region + fields = ['labels', 'area'] + + target = target.copy() + target['size'] = torch.tensor([h, w]) + + if 'boxes' in target: + boxes = target['boxes'] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min( + cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] + - cropped_boxes[:, 0, :]).prod(dim=1) + target['boxes'] = cropped_boxes.reshape(-1, 4) + target['area'] = area + fields.append('boxes') + + if 'masks' in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append('masks') + + # remove elements for which the boxes or masks that have zero area + if 'boxes' in target or 'masks' in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if 'boxes' in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all( + cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep.tolist()] + return target + + def pad_target(self, padding, target): + target = target.copy() + if 'masks' in target: + target['masks'] = torch.nn.functional.pad( + target['masks'], (0, padding[1], 0, padding[0])) + return target + + def __call__(self, image, target=None): + image_size = image.size + image_size = torch.tensor(image_size[::-1]) + + out_desired_size = (self.desired_size * image_size + / max(image_size)).round().int() + + random_scale = torch.rand(1) * ( + self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min + scaled_size = (random_scale * self.desired_size).round() + + scale = torch.minimum(scaled_size / image_size[0], + scaled_size / image_size[1]) + scaled_size = (image_size * scale).round().int() + + scaled_image = F.resize(image, scaled_size.tolist()) + + if target is not None: + target = self.rescale_target(scaled_size, image_size, target) + + # randomly crop or pad images + if random_scale > 1: + # Selects non-zero random offset (x, y) if scaled image is larger than desired_size. + max_offset = scaled_size - out_desired_size + offset = (max_offset * torch.rand(2)).floor().int() + region = (offset[0].item(), offset[1].item(), + out_desired_size[0].item(), out_desired_size[1].item()) + output_image = F.crop(scaled_image, *region) + if target is not None: + target = self.crop_target(region, target) + else: + padding = out_desired_size - scaled_size + output_image = F.pad(scaled_image, + [0, 0, padding[1].item(), padding[0].item()]) + if target is not None: + target = self.pad_target(padding, target) + + return output_image, target + + +class RandomDistortion(object): + """ + Distort image w.r.t hue, saturation and exposure. + """ + + def __init__(self, + brightness=0, + contrast=0, + saturation=0, + hue=0, + prob=0.5): + self.prob = prob + self.tfm = T.ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, img, target=None): + if np.random.random() < self.prob: + return self.tfm(img), target + else: + return img, target diff --git a/modelscope/preprocessors/ofa/utils/vision_helper.py b/modelscope/preprocessors/ofa/utils/vision_helper.py new file mode 100644 index 00000000..518b110a --- /dev/null +++ b/modelscope/preprocessors/ofa/utils/vision_helper.py @@ -0,0 +1,357 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. + +import cv2 +import numpy as np + + +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + ''' + same output as PIL.ImageOps.autocontrast + ''' + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + ''' + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + ''' + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + ''' + like PIL, rotate by degree, not radians + ''' + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + ''' + same output as PIL.ImageOps.posterize + ''' + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + # same output as PIL.ImageEnhance.Color + M = ( + np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], + [-0.299, -0.299, 0.701]]) * factor + + np.float32([[0.114], [0.587], [0.299]])) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = np.array([(el - mean) * factor + mean + for el in range(256)]).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def brightness_func(img, factor): + ''' + same output as PIL.ImageEnhance.Contrast + ''' + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype( + np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + ''' + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + ''' + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * ( + out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def posterize_func(img, bits): + ''' + same output as PIL.ImageOps.posterize + ''' + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine( + img, M, (W, H), borderValue=fill, + flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +# level to args +def enhance_level_to_args(MAX_LEVEL): + + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1, ) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return level, replace_value + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level, ) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level, ) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + 'Identity': identity_func, + 'AutoContrast': autocontrast_func, + 'Equalize': equalize_func, + 'Rotate': rotate_func, + 'Solarize': solarize_func, + 'Color': color_func, + 'Contrast': contrast_func, + 'Brightness': brightness_func, + 'Sharpness': sharpness_func, + 'ShearX': shear_x_func, + 'TranslateX': translate_x_func, + 'TranslateY': translate_y_func, + 'Posterize': posterize_func, + 'ShearY': shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + 'Identity': + none_level_to_args, + 'AutoContrast': + none_level_to_args, + 'Equalize': + none_level_to_args, + 'Rotate': + rotate_level_to_args(MAX_LEVEL, replace_value), + 'Solarize': + solarize_level_to_args(MAX_LEVEL), + 'Color': + enhance_level_to_args(MAX_LEVEL), + 'Contrast': + enhance_level_to_args(MAX_LEVEL), + 'Brightness': + enhance_level_to_args(MAX_LEVEL), + 'Sharpness': + enhance_level_to_args(MAX_LEVEL), + 'ShearX': + shear_level_to_args(MAX_LEVEL, replace_value), + 'TranslateX': + translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + 'TranslateY': + translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + 'Posterize': + posterize_level_to_args(MAX_LEVEL), + 'ShearY': + shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py new file mode 100644 index 00000000..fff5bbd3 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_entailment.py @@ -0,0 +1,120 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaVisualEntailmentPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = ' {}'.format(sample['label']) + sample['ref_dict'] = {sample['label']: 1.0} + tgt_item = self.tokenize_text(target, add_bos=False, add_eos=False) + + if self.prompt_type == 'none': + prev_output_item = torch.cat([self.bos_item, tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'src': + prev_output_item = torch.cat([sample['source'], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'prev_output': + prev_output_item = torch.cat([sample['source'][:-1], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + else: + raise NotImplementedError + + target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id + sample['target'] = target_item + sample['prev_output_tokens'] = prev_output_item + + if self.constraint_trie is not None: + constraint_mask = torch.zeros( + (len(target_item), len(self.tgt_dict))).bool() + start_idx = len(target_item) - len(tgt_item) - 1 + for i in range( + len(target_item) - len(tgt_item) - 1, len(target_item)): + constraint_prefix_token = [ + self.tgt_dict.bos() + ] + target_item[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + patch_image = self.patch_resize_transform(image) + if 'text2' not in data: + hypothesis = self.pre_caption(data[self.column_map['text']], + self.max_src_length) + prompt = self.cfg.model.get('prompt', + ' does the image describe " {} "?') + text = prompt.format(hypothesis) + else: + assert 'text' in data, f'text must be in the input {data.keys()}' + caption = self.pre_caption(data[self.column_map['text2']], + self.max_src_length) + hypothesis = self.pre_caption(data[self.column_map['text']], + self.max_src_length) + prompt = self.cfg.model.get( + 'prompt', ' can image and text1 " {} " imply text2 " {} "?') + text = prompt.format(caption, hypothesis) + inputs = self.tokenize_text(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'src': + decoder_prompt = inputs + elif self.prompt_type == 'prev_output': + decoder_prompt = inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'decoder_prompt': decoder_prompt, + } + if 'relation' in self.column_map and self.column_map[ + 'relation'] in data: + sample['label'] = data[self.column_map['relation']] + return sample diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py new file mode 100644 index 00000000..2da79670 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_grounding.py @@ -0,0 +1,141 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor +from .utils import transforms as T + + +class OfaVisualGroundingPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaVisualGroundingPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + + self.num_bins = self.cfg.model.get('num_bins', 1000) + if self.mode == ModeKeys.TRAIN: + # for positioning + self.positioning_transform = T.Compose([ + T.RandomResize([self.patch_image_size], + max_size=self.patch_image_size), + T.ToTensor(), + T.Normalize( + mean=self.mean, + std=self.std, + max_image_size=self.max_image_size) + ]) + else: + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + w, h = image.size + boxes_target = { + 'boxes': [], + 'labels': [], + 'area': [], + 'size': torch.tensor([h, w]) + } + x0, y0, x1, y1 = data[self.column_map['region_coord']].strip().split( + ',') + region = torch.tensor([float(x0), float(y0), float(x1), float(y1)]) + boxes_target['boxes'] = torch.tensor( + [[float(x0), float(y0), float(x1), + float(y1)]]) + boxes_target['labels'] = np.array([0]) + area = [(float(x1) - float(x0)) * (float(y1) - float(y0))] + boxes_target['area'] = torch.tensor(area) + + patch_image, patch_boxes = self.positioning_transform( + image, boxes_target) + resize_h, resize_w = patch_boxes['size'][0], patch_boxes['size'][1] + quant_x0 = ''.format( + int((patch_boxes['boxes'][0][0] * (self.num_bins - 1)).round())) + quant_y0 = ''.format( + int((patch_boxes['boxes'][0][1] * (self.num_bins - 1)).round())) + quant_x1 = ''.format( + int((patch_boxes['boxes'][0][2] * (self.num_bins - 1)).round())) + quant_y1 = ''.format( + int((patch_boxes['boxes'][0][3] * (self.num_bins - 1)).round())) + region_coord = '{} {} {} {}'.format(quant_x0, quant_y0, quant_x1, + quant_y1) + src_caption = self.pre_caption(data[self.column_map['text']], + self.max_src_length) + prompt = self.cfg.model.get( + 'prompt', ' which region does the text " {} " describe?') + text = prompt.format(src_caption) + src_item = self.tokenize_text(text) + target_item = self.tokenize_text( + region_coord, add_bos=False) # !!! use_bpe=False + prev_output_item = torch.cat([self.bos_item, target_item[:-1]]) + + sample = { + 'source': src_item, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'target': target_item, + 'prev_output_tokens': prev_output_item, + 'w_resize_ratio': resize_w / w, + 'h_resize_ratio': resize_h / h, + 'region_coord': region + } + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + w, h = image.size + patch_image = self.patch_resize_transform(image) + w_resize_ratio = torch.tensor(self.patch_image_size / w) + h_resize_ratio = torch.tensor(self.patch_image_size / h) + src_caption = self.pre_caption(data[self.column_map['text']], + self.max_src_length) + prompt = self.cfg.model.get( + 'prompt', ' which region does the text " {} " describe?') + text = prompt.format(src_caption) + src_item = self.tokenize_text(text) + sample = { + 'source': src_item, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'w_resize_ratio': w_resize_ratio, + 'h_resize_ratio': h_resize_ratio, + } + + if 'region_coord' in self.column_map and self.column_map[ + 'region_coord'] in data: + x0, y0, x1, y1 = data[ + self.column_map['region_coord']].strip().split(',') + sample['label'] = [float(x0), float(y0), float(x1), float(y1)] + return sample diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py new file mode 100644 index 00000000..b83cf935 --- /dev/null +++ b/modelscope/preprocessors/ofa/visual_question_answering.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModeKeys +from .base import OfaBasePreprocessor + + +class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor): + + def __init__(self, + cfg, + model_dir, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + cfg(modelscope.utils.config.ConfigDict) : model config + model_dir (str): model path, + mode: preprocessor mode (model mode) + """ + super(OfaVisualQuestionAnsweringPreprocessor, + self).__init__(cfg, model_dir, mode, *args, **kwargs) + # Initialize transform + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert('RGB'), + transforms.Resize( + (self.patch_image_size, self.patch_image_size), + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std), + ]) + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + tgt_item = self.tokenize_text( + ' {}'.format(sample['label']), add_bos=False, add_eos=False) + + if self.prompt_type == 'none': + prev_output_item = torch.cat([self.bos_item, tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'src': + prev_output_item = torch.cat([sample['source'], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + elif self.prompt_type == 'prev_output': + prev_output_item = torch.cat([sample['source'][:-1], tgt_item]) + target_item = torch.cat([prev_output_item[1:], self.eos_item]) + else: + raise NotImplementedError + target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id + + sample['prev_output_tokens'] = prev_output_item + sample['target'] = target_item + + if self.constraint_trie is not None: + constraint_mask = torch.zeros( + (len(target_item), len(self.tgt_dict))).bool() + start_idx = len(target_item) - len(tgt_item) - 1 + for i in range( + len(target_item) - len(tgt_item) - 1, len(target_item)): + constraint_prefix_token = [ + self.tgt_dict.bos() + ] + target_item[start_idx:i].tolist() + constraint_nodes = self.constraint_trie.get_next_layer( + constraint_prefix_token) + constraint_mask[i][constraint_nodes] = True + sample['constraint_mask'] = constraint_mask + + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) + patch_image = self.patch_resize_transform(image) + text = ' {}'.format(data[self.column_map['text']]) + inputs = self.tokenize_text(text) + if self.prompt_type == 'none': + decoder_prompt = self.bos_item + elif self.prompt_type == 'src': + decoder_prompt = inputs + elif self.prompt_type == 'prev_output': + decoder_prompt = inputs[:-1] + else: + raise NotImplementedError + sample = { + 'source': inputs, + 'patch_image': patch_image, + 'patch_mask': torch.tensor([True]), + 'decoder_prompt': decoder_prompt, + } + if 'answer' in self.column_map and self.column_map['answer'] in data: + sample['label'] = data[self.column_map['answer']] + return sample diff --git a/modelscope/preprocessors/science/__init__.py b/modelscope/preprocessors/science/__init__.py new file mode 100644 index 00000000..54b24887 --- /dev/null +++ b/modelscope/preprocessors/science/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .unifold import (UniFoldPreprocessor) + +else: + _import_structure = {'unifold': ['UniFoldPreprocessor']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/science/uni_fold.py b/modelscope/preprocessors/science/uni_fold.py new file mode 100644 index 00000000..2a44c885 --- /dev/null +++ b/modelscope/preprocessors/science/uni_fold.py @@ -0,0 +1,569 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import gzip +import hashlib +import logging +import os +import pickle +import random +import re +import tarfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from unittest import result + +import json +import numpy as np +import requests +import torch +from tqdm import tqdm + +from modelscope.metainfo import Preprocessors +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS +from modelscope.models.science.unifold.data.utils import compress_features +from modelscope.models.science.unifold.msa import parsers, pipeline, templates +from modelscope.models.science.unifold.msa.tools import hhsearch +from modelscope.models.science.unifold.msa.utils import divide_multi_chains +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + +__all__ = [ + 'UniFoldPreprocessor', +] + +TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' +DEFAULT_API_SERVER = 'https://api.colabfold.com' + + +def run_mmseqs2( + x, + prefix, + use_env=True, + use_templates=False, + use_pairing=False, + host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]: + submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa' + + def submit(seqs, mode, N=101): + n, query = N, '' + for seq in seqs: + query += f'>{n}\n{seq}\n' + n += 1 + + res = requests.post( + f'{host_url}/{submission_endpoint}', + data={ + 'q': query, + 'mode': mode + }) + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def status(ID): + res = requests.get(f'{host_url}/ticket/{ID}') + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def download(ID, path): + res = requests.get(f'{host_url}/result/download/{ID}') + with open(path, 'wb') as out: + out.write(res.content) + + # process input x + seqs = [x] if isinstance(x, str) else x + + mode = 'env' + if use_pairing: + mode = '' + use_templates = False + use_env = False + + # define path + path = f'{prefix}' + if not os.path.isdir(path): + os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f'{path}/out_{mode}.tar.gz' + N, REDO = 101, True + + # deduplicate and keep track of order + seqs_unique = [] + # TODO this might be slow for large sets + [seqs_unique.append(x) for x in seqs if x not in seqs_unique] + Ms = [N + seqs_unique.index(seq) for seq in seqs] + # lets do it! + if not os.path.isfile(tar_gz_file): + TIME_ESTIMATE = 150 * len(seqs_unique) + with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: + while REDO: + pbar.set_description('SUBMIT') + + # Resubmit job until it goes through + out = submit(seqs_unique, mode, N) + while out['status'] in ['UNKNOWN', 'RATELIMIT']: + sleep_time = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") + # resubmit + time.sleep(sleep_time) + out = submit(seqs_unique, mode, N) + + if out['status'] == 'ERROR': + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' + raise Exception(error) + + if out['status'] == 'MAINTENANCE': + raise Exception( + 'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.' + ) + + # wait for job to finish + ID, TIME = out['id'], 0 + pbar.set_description(out['status']) + while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']: + t = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {t}s. Reason: {out['status']}") + time.sleep(t) + out = status(ID) + pbar.set_description(out['status']) + if out['status'] == 'RUNNING': + TIME += t + pbar.update(n=t) + + if out['status'] == 'COMPLETE': + if TIME < TIME_ESTIMATE: + pbar.update(n=(TIME_ESTIMATE - TIME)) + REDO = False + + if out['status'] == 'ERROR': + REDO = False + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' + raise Exception(error) + + # Download results + download(ID, tar_gz_file) + + # prep list of a3m files + if use_pairing: + a3m_files = [f'{path}/pair.a3m'] + else: + a3m_files = [f'{path}/uniref.a3m'] + if use_env: + a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m') + + # extract a3m files + if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + + with open(f'{path}/pdb70.m8', 'r') as f: + lines = f.readlines() + for line in lines: + p = line.rstrip().split() + M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value + M = int(M) + if M not in templates: + templates[M] = [] + templates[M].append(pdb) + + template_paths = {} + for k, TMPL in templates.items(): + TMPL_PATH = f'{prefix}/templates_{k}' + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ','.join(TMPL[:20]) + os.system( + f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/' + ) + os.system( + f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex' + ) + os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata') + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M, M = True, None + with open(a3m_file, 'r') as f: + lines = f.readlines() + for line in lines: + if len(line) > 0: + if '\x00' in line: + line = line.replace('\x00', '') + update_M = True + if line.startswith('>') and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + + a3m_lines = [''.join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + # print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + return (a3m_lines, template_paths) if use_templates else a3m_lines + + +def get_null_template(query_sequence: Union[List[str], str], + num_temp: int = 1) -> Dict[str, Any]: + ln = ( + len(query_sequence) if isinstance(query_sequence, str) else sum( + len(s) for s in query_sequence)) + output_templates_sequence = 'A' * ln + # output_confidence_scores = np.full(ln, 1.0) + + templates_all_atom_positions = np.zeros( + (ln, templates.residue_constants.atom_type_num, 3)) + templates_all_atom_masks = np.zeros( + (ln, templates.residue_constants.atom_type_num)) + templates_aatype = templates.residue_constants.sequence_to_onehot( + output_templates_sequence, + templates.residue_constants.HHBLITS_AA_TO_ID) + template_features = { + 'template_all_atom_positions': + np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]), + 'template_all_atom_masks': + np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]), + 'template_sequence': ['none'.encode()] * num_temp, + 'template_aatype': + np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]), + 'template_domain_names': ['none'.encode()] * num_temp, + 'template_sum_probs': + np.zeros([num_temp], dtype=np.float32), + } + return template_features + + +def get_template(a3m_lines: str, template_path: str, + query_sequence: str) -> Dict[str, Any]: + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=template_path, + max_template_date='2100-01-01', + max_hits=20, + kalign_binary_path='kalign', + release_dates_path=None, + obsolete_pdbs_path=None, + ) + + hhsearch_pdb70_runner = hhsearch.HHSearch( + binary_path='hhsearch', databases=[f'{template_path}/pdb70']) + + hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines) + hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result) + templates_result = template_featurizer.get_templates( + query_sequence=query_sequence, hits=hhsearch_hits) + return dict(templates_result.features) + + +@PREPROCESSORS.register_module( + Fields.science, module_name=Preprocessors.unifold_preprocessor) +class UniFoldPreprocessor(Preprocessor): + + def __init__(self, **cfg): + self.symmetry_group = cfg['symmetry_group'] # "C1" + if not self.symmetry_group: + self.symmetry_group = None + self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg + self.MAX_SINGLE_SEQUENCE_LENGTH = 1000 + self.MAX_MULTIMER_LENGTH = 1000 + self.jobname = 'unifold' + self.output_dir_base = './unifold-predictions' + os.makedirs(self.output_dir_base, exist_ok=True) + + def clean_and_validate_sequence(self, input_sequence: str, min_length: int, + max_length: int) -> str: + clean_sequence = input_sequence.translate( + str.maketrans('', '', ' \n\t')).upper() + aatypes = set(residue_constants.restypes) # 20 standard aatypes. + if not set(clean_sequence).issubset(aatypes): + raise ValueError( + f'Input sequence contains non-amino acid letters: ' + f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' + 'amino acids as inputs.') + if len(clean_sequence) < min_length: + raise ValueError( + f'Input sequence is too short: {len(clean_sequence)} amino acids, ' + f'while the minimum is {min_length}') + if len(clean_sequence) > max_length: + raise ValueError( + f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' + f'the maximum is {max_length}. You may be able to run it with the full ' + f'Uni-Fold system depending on your resources (system memory, ' + f'GPU memory).') + return clean_sequence + + def validate_input(self, input_sequences: Sequence[str], + symmetry_group: str, min_length: int, max_length: int, + max_multimer_length: int) -> Tuple[Sequence[str], bool]: + """Validates and cleans input sequences and determines which model to use.""" + sequences = [] + + for input_sequence in input_sequences: + if input_sequence.strip(): + input_sequence = self.clean_and_validate_sequence( + input_sequence=input_sequence, + min_length=min_length, + max_length=max_length) + sequences.append(input_sequence) + + if symmetry_group is not None and symmetry_group != 'C1': + if symmetry_group.startswith( + 'C') and symmetry_group[1:].isnumeric(): + print( + f'Using UF-Symmetry with group {symmetry_group}. If you do not ' + f'want to use UF-Symmetry, please use `C1` and copy the AU ' + f'sequences to the count in the assembly.') + is_multimer = (len(sequences) > 1) + return sequences, is_multimer, symmetry_group + else: + raise ValueError( + f'UF-Symmetry does not support symmetry group ' + f'{symmetry_group} currently. Cyclic groups (Cx) are ' + f'supported only.') + + elif len(sequences) == 1: + print('Using the single-chain model.') + return sequences, False, None + + elif len(sequences) > 1: + total_multimer_length = sum([len(seq) for seq in sequences]) + if total_multimer_length > max_multimer_length: + raise ValueError( + f'The total length of multimer sequences is too long: ' + f'{total_multimer_length}, while the maximum is ' + f'{max_multimer_length}. Please use the full AlphaFold ' + f'system for long multimers.') + print(f'Using the multimer model with {len(sequences)} sequences.') + return sequences, True, None + + else: + raise ValueError( + 'No input amino acid sequence provided, please provide at ' + 'least one sequence.') + + def add_hash(self, x, y): + return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5] + + def get_msa_and_templates( + self, + jobname: str, + query_seqs_unique: Union[str, List[str]], + result_dir: Path, + msa_mode: str, + use_templates: bool, + homooligomers_num: int = 1, + host_url: str = DEFAULT_API_SERVER, + ) -> Tuple[Optional[List[str]], Optional[List[str]], List[str], List[int], + List[Dict[str, Any]]]: + + use_env = msa_mode == 'MMseqs2' + + template_features = [] + if use_templates: + a3m_lines_mmseqs2, template_paths = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_templates=True, + host_url=host_url, + ) + if template_paths is None: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + if template_paths[index] is not None: + template_feature = get_template( + a3m_lines_mmseqs2[index], + template_paths[index], + query_seqs_unique[index], + ) + if len(template_feature['template_domain_names']) == 0: + template_feature = get_null_template( + query_seqs_unique[index]) + else: + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template(query_seqs_unique[index]) + template_features.append(template_feature) + + if msa_mode == 'single_sequence': + a3m_lines = [] + num = 101 + for i, seq in enumerate(query_seqs_unique): + a3m_lines.append('>' + str(num + i) + '\n' + seq) + else: + # find normal a3ms + a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=False, + host_url=host_url, + ) + if len(query_seqs_unique) > 1: + # find paired a3m if not a homooligomers + paired_a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=True, + host_url=host_url, + ) + else: + num = 101 + paired_a3m_lines = [] + for i in range(0, homooligomers_num): + paired_a3m_lines.append('>' + str(num + i) + '\n' + + query_seqs_unique[0] + '\n') + + return ( + a3m_lines, + paired_a3m_lines, + template_features, + ) + + def __call__(self, data: Union[str, Tuple]): + if isinstance(data, str): + data = [data, '', '', ''] + basejobname = ''.join(data) + basejobname = re.sub(r'\W+', '', basejobname) + target_id = self.add_hash(self.jobname, basejobname) + + sequences, is_multimer, _ = self.validate_input( + input_sequences=data, + symmetry_group=self.symmetry_group, + min_length=self.MIN_SINGLE_SEQUENCE_LENGTH, + max_length=self.MAX_SINGLE_SEQUENCE_LENGTH, + max_multimer_length=self.MAX_MULTIMER_LENGTH) + + descriptions = [ + '> ' + target_id + ' seq' + str(ii) + for ii in range(len(sequences)) + ] + + if is_multimer: + divide_multi_chains(target_id, self.output_dir_base, sequences, + descriptions) + + s = [] + for des, seq in zip(descriptions, sequences): + s += [des, seq] + + unique_sequences = [] + [ + unique_sequences.append(x) for x in sequences + if x not in unique_sequences + ] + + if len(unique_sequences) == 1: + homooligomers_num = len(sequences) + else: + homooligomers_num = 1 + + with open(f'{self.jobname}.fasta', 'w') as f: + f.write('\n'.join(s)) + + result_dir = Path(self.output_dir_base) + output_dir = os.path.join(self.output_dir_base, target_id) + + # msa_mode = 'single_sequence' + msa_mode = 'MMseqs2' + use_templates = True + + unpaired_msa, paired_msa, template_results = self.get_msa_and_templates( + target_id, + unique_sequences, + result_dir=result_dir, + msa_mode=msa_mode, + use_templates=use_templates, + homooligomers_num=homooligomers_num) + + features = [] + pair_features = [] + + for idx, seq in enumerate(unique_sequences): + chain_id = PDB_CHAIN_IDS[idx] + sequence_features = pipeline.make_sequence_features( + sequence=seq, + description=f'> {self.jobname} seq {chain_id}', + num_res=len(seq)) + monomer_msa = parsers.parse_a3m(unpaired_msa[idx]) + msa_features = pipeline.make_msa_features([monomer_msa]) + template_features = template_results[idx] + feature_dict = { + **sequence_features, + **msa_features, + **template_features + } + feature_dict = compress_features(feature_dict) + features_output_path = os.path.join( + output_dir, '{}.feature.pkl.gz'.format(chain_id)) + pickle.dump( + feature_dict, + gzip.GzipFile(features_output_path, 'wb'), + protocol=4) + features.append(feature_dict) + + if is_multimer: + multimer_msa = parsers.parse_a3m(paired_msa[idx]) + pair_features = pipeline.make_msa_features([multimer_msa]) + pair_feature_dict = compress_features(pair_features) + uniprot_output_path = os.path.join( + output_dir, '{}.uniprot.pkl.gz'.format(chain_id)) + pickle.dump( + pair_feature_dict, + gzip.GzipFile(uniprot_output_path, 'wb'), + protocol=4, + ) + pair_features.append(pair_feature_dict) + + # return features, pair_features, target_id + return { + 'features': features, + 'pair_features': pair_features, + 'target_id': target_id, + 'is_multimer': is_multimer, + } + + +if __name__ == '__main__': + proc = UniFoldPreprocessor() + protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \ + 'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI' + features, pair_features = proc.__call__(protein_example) + import ipdb + ipdb.set_trace() diff --git a/modelscope/preprocessors/video.py b/modelscope/preprocessors/video.py new file mode 100644 index 00000000..794033b5 --- /dev/null +++ b/modelscope/preprocessors/video.py @@ -0,0 +1,354 @@ +import math +import os +import random +import uuid +from os.path import exists +from tempfile import TemporaryDirectory +from urllib.parse import urlparse + +import numpy as np +import torch +import torch.utils.data +import torch.utils.dlpack as dlpack +import torchvision.transforms._transforms_video as transforms +from decord import VideoReader +from torchvision.transforms import Compose + +from modelscope.hub.file_download import http_get_file +from modelscope.metainfo import Preprocessors +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .base import Preprocessor +from .builder import PREPROCESSORS + + +def ReadVideoData(cfg, + video_path, + num_spatial_crops_override=None, + num_temporal_views_override=None): + """ simple interface to load video frames from file + + Args: + cfg (Config): The global config object. + video_path (str): video file path + num_spatial_crops_override (int): the spatial crops per clip + num_temporal_views_override (int): the temporal clips per video + Returns: + data (Tensor): the normalized video clips for model inputs + """ + url_parsed = urlparse(video_path) + if url_parsed.scheme in ('file', '') and exists( + url_parsed.path): # Possibly a local file + data = _decode_video(cfg, video_path, num_temporal_views_override) + else: + with TemporaryDirectory() as temporary_cache_dir: + random_str = uuid.uuid4().hex + http_get_file( + url=video_path, + local_dir=temporary_cache_dir, + file_name=random_str, + cookies=None) + temp_file_path = os.path.join(temporary_cache_dir, random_str) + data = _decode_video(cfg, temp_file_path, + num_temporal_views_override) + + if num_spatial_crops_override is not None: + num_spatial_crops = num_spatial_crops_override + transform = kinetics400_tranform(cfg, num_spatial_crops_override) + else: + num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS + transform = kinetics400_tranform(cfg, cfg.TEST.NUM_SPATIAL_CROPS) + data_list = [] + for i in range(data.size(0)): + for j in range(num_spatial_crops): + transform.transforms[1].set_spatial_index(j) + data_list.append(transform(data[i])) + return torch.stack(data_list, dim=0) + + +def kinetics400_tranform(cfg, num_spatial_crops): + """ + Configs the transform for the kinetics-400 dataset. + We apply controlled spatial cropping and normalization. + Args: + cfg (Config): The global config object. + num_spatial_crops (int): the spatial crops per clip + Returns: + transform_function (Compose): the transform function for input clips + """ + resize_video = KineticsResizedCrop( + short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE], + crop_size=cfg.DATA.TEST_CROP_SIZE, + num_spatial_crops=num_spatial_crops) + std_transform_list = [ + transforms.ToTensorVideo(), resize_video, + transforms.NormalizeVideo( + mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True) + ] + return Compose(std_transform_list) + + +def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, + num_clips, num_frames, interval, minus_interval): + """ + Generates the frame index list using interval based sampling. + Args: + vid_length (int): the length of the whole video (valid selection range). + vid_fps (int): the original video fps + target_fps (int): the normalized video fps + clip_idx (int): -1 for random temporal sampling, and positive values for sampling specific + clip from the video + num_clips (int): the total clips to be sampled from each video. + combined with clip_idx, the sampled video is the "clip_idx-th" video from + "num_clips" videos. + num_frames (int): number of frames in each sampled clips. + interval (int): the interval to sample each frame. + minus_interval (bool): control the end index + Returns: + index (tensor): the sampled frame indexes + """ + if num_frames == 1: + index = [random.randint(0, vid_length - 1)] + else: + # transform FPS + clip_length = num_frames * interval * vid_fps / target_fps + + max_idx = max(vid_length - clip_length, 0) + if num_clips == 1: + start_idx = max_idx / 2 + else: + start_idx = clip_idx * math.floor(max_idx / (num_clips - 1)) + if minus_interval: + end_idx = start_idx + clip_length - interval + else: + end_idx = start_idx + clip_length - 1 + + index = torch.linspace(start_idx, end_idx, num_frames) + index = torch.clamp(index, 0, vid_length - 1).long() + + return index + + +def _decode_video_frames_list(cfg, + frames_list, + vid_fps, + num_temporal_views_override=None): + """ + Decodes the video given the numpy frames. + Args: + cfg (Config): The global config object. + frames_list (list): all frames for a video, the frames should be numpy array. + vid_fps (int): the fps of this video. + num_temporal_views_override (int): the temporal clips per video + Returns: + frames (Tensor): video tensor data + """ + assert isinstance(frames_list, list) + if num_temporal_views_override is not None: + num_clips_per_video = num_temporal_views_override + else: + num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS + + frame_list = [] + for clip_idx in range(num_clips_per_video): + # for each clip in the video, + # a list is generated before decoding the specified frames from the video + list_ = _interval_based_sampling( + len(frames_list), + vid_fps, + cfg.DATA.TARGET_FPS, + clip_idx, + num_clips_per_video, + cfg.DATA.NUM_INPUT_FRAMES, + cfg.DATA.SAMPLING_RATE, + cfg.DATA.MINUS_INTERVAL, + ) + frames = None + frames = torch.from_numpy( + np.stack([frames_list[index] for index in list_.tolist()], axis=0)) + frame_list.append(frames) + frames = torch.stack(frame_list) + del vr + return frames + + +def _decode_video(cfg, path, num_temporal_views_override=None): + """ + Decodes the video given the numpy frames. + Args: + cfg (Config): The global config object. + path (str): video file path. + num_temporal_views_override (int): the temporal clips per video + Returns: + frames (Tensor): video tensor data + """ + vr = VideoReader(path) + if num_temporal_views_override is not None: + num_clips_per_video = num_temporal_views_override + else: + num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS + + frame_list = [] + for clip_idx in range(num_clips_per_video): + # for each clip in the video, + # a list is generated before decoding the specified frames from the video + list_ = _interval_based_sampling( + len(vr), + vr.get_avg_fps(), + cfg.DATA.TARGET_FPS, + clip_idx, + num_clips_per_video, + cfg.DATA.NUM_INPUT_FRAMES, + cfg.DATA.SAMPLING_RATE, + cfg.DATA.MINUS_INTERVAL, + ) + frames = None + if path.endswith('.avi'): + append_list = torch.arange(0, list_[0], 4) + frames = dlpack.from_dlpack( + vr.get_batch(torch.cat([append_list, + list_])).to_dlpack()).clone() + frames = frames[append_list.shape[0]:] + else: + frames = dlpack.from_dlpack( + vr.get_batch(list_).to_dlpack()).clone() + frame_list.append(frames) + frames = torch.stack(frame_list) + del vr + return frames + + +class KineticsResizedCrop(object): + """Perform resize and crop for kinetics-400 dataset + Args: + short_side_range (list): The length of short side range. In inference, this shoudle be [256, 256] + crop_size (int): The cropped size for frames. + num_spatial_crops (int): The number of the cropped spatial regions in each video. + """ + + def __init__( + self, + short_side_range, + crop_size, + num_spatial_crops=1, + ): + self.idx = -1 + self.short_side_range = short_side_range + self.crop_size = int(crop_size) + self.num_spatial_crops = num_spatial_crops + + def _get_controlled_crop(self, clip): + """Perform controlled crop for video tensor. + Args: + clip (Tensor): the video data, the shape is [T, C, H, W] + """ + _, _, clip_height, clip_width = clip.shape + + length = self.short_side_range[0] + + if clip_height < clip_width: + new_clip_height = int(length) + new_clip_width = int(clip_width / clip_height * new_clip_height) + new_clip = torch.nn.functional.interpolate( + clip, size=(new_clip_height, new_clip_width), mode='bilinear') + else: + new_clip_width = int(length) + new_clip_height = int(clip_height / clip_width * new_clip_width) + new_clip = torch.nn.functional.interpolate( + clip, size=(new_clip_height, new_clip_width), mode='bilinear') + x_max = int(new_clip_width - self.crop_size) + y_max = int(new_clip_height - self.crop_size) + if self.num_spatial_crops == 1: + x = x_max // 2 + y = y_max // 2 + elif self.num_spatial_crops == 3: + if self.idx == 0: + if new_clip_width == length: + x = x_max // 2 + y = 0 + elif new_clip_height == length: + x = 0 + y = y_max // 2 + elif self.idx == 1: + x = x_max // 2 + y = y_max // 2 + elif self.idx == 2: + if new_clip_width == length: + x = x_max // 2 + y = y_max + elif new_clip_height == length: + x = x_max + y = y_max // 2 + return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] + + def _get_random_crop(self, clip): + _, _, clip_height, clip_width = clip.shape + + short_side = min(clip_height, clip_width) + long_side = max(clip_height, clip_width) + new_short_side = int(random.uniform(*self.short_side_range)) + new_long_side = int(long_side / short_side * new_short_side) + if clip_height < clip_width: + new_clip_height = new_short_side + new_clip_width = new_long_side + else: + new_clip_height = new_long_side + new_clip_width = new_short_side + + new_clip = torch.nn.functional.interpolate( + clip, size=(new_clip_height, new_clip_width), mode='bilinear') + + x_max = int(new_clip_width - self.crop_size) + y_max = int(new_clip_height - self.crop_size) + x = int(random.uniform(0, x_max)) + y = int(random.uniform(0, y_max)) + return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] + + def set_spatial_index(self, idx): + """Set the spatial cropping index for controlled cropping.. + Args: + idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively. + """ + self.idx = idx + + def __call__(self, clip): + return self._get_controlled_crop(clip) + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor) +class MovieSceneSegmentationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """ + movie scene segmentation preprocessor + """ + super().__init__(*args, **kwargs) + + self.is_train = kwargs.pop('is_train', True) + self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None) + self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None) + self.num_keyframe = kwargs.pop('num_keyframe', 3) + + from .movie_scene_segmentation import get_transform + self.train_transform = get_transform(self.preprocessor_train_cfg) + self.test_transform = get_transform(self.preprocessor_test_cfg) + + def train(self): + self.is_train = True + return + + def eval(self): + self.is_train = False + return + + @type_assert(object, object) + def __call__(self, results): + if self.is_train: + transforms = self.train_transform + else: + transforms = self.test_transform + + results = torch.stack(transforms(results), dim=0) + results = results.view(-1, self.num_keyframe, 3, 224, 224) + return results diff --git a/modelscope/tools/eval.py b/modelscope/tools/eval.py new file mode 100644 index 00000000..ca39932d --- /dev/null +++ b/modelscope/tools/eval.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse + +from modelscope.trainers import build_trainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='evaluate a model') + parser.add_argument('config', help='config file path', type=str) + parser.add_argument( + '--trainer_name', help='name for trainer', type=str, default=None) + parser.add_argument( + '--checkpoint_path', + help='checkpoint to be evaluated', + type=str, + default=None) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + kwargs = dict(cfg_file=args.config) + trainer = build_trainer(args.trainer_name, kwargs) + trainer.evaluate(args.checkpoint_path) + + +if __name__ == '__main__': + main() diff --git a/modelscope/tools/train.py b/modelscope/tools/train.py new file mode 100644 index 00000000..c6f1ef5f --- /dev/null +++ b/modelscope/tools/train.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse + +from modelscope.trainers import build_trainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='config file path', type=str) + parser.add_argument( + 'trainer_name', help='name for trainer', type=str, default=None) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + kwargs = dict(cfg_file=args.config) + trainer = build_trainer(args.trainer_name, kwargs) + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py new file mode 100644 index 00000000..37fdcc12 --- /dev/null +++ b/modelscope/trainers/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .audio.ans_trainer import ANSTrainer + from .base import DummyTrainer + from .builder import build_trainer + from .cv import (ImageInstanceSegmentationTrainer, + ImagePortraitEnhancementTrainer, + MovieSceneSegmentationTrainer, ImageInpaintingTrainer, + ReferringVideoObjectSegmentationTrainer) + from .multi_modal import CLIPTrainer + from .nlp import SequenceClassificationTrainer, TextRankingTrainer + from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments + from .trainer import EpochBasedTrainer + +else: + _import_structure = { + 'audio.ans_trainer': ['ANSTrainer'], + 'base': ['DummyTrainer'], + 'builder': ['build_trainer'], + 'cv': [ + 'ImageInstanceSegmentationTrainer', + 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer', + 'ImageInpaintingTrainer' + ], + 'multi_modal': ['CLIPTrainer'], + 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'], + 'nlp_trainer': + ['NlpEpochBasedTrainer', 'VecoTrainer', 'NlpTrainerArguments'], + 'trainer': ['EpochBasedTrainer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/audio/__init__.py b/modelscope/trainers/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/audio/ans_trainer.py b/modelscope/trainers/audio/ans_trainer.py new file mode 100644 index 00000000..37b201ce --- /dev/null +++ b/modelscope/trainers/audio/ans_trainer.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import TrainerStages +from modelscope.utils.data_utils import to_device +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@TRAINERS.register_module(module_name=Trainers.speech_frcrn_ans_cirm_16k) +class ANSTrainer(EpochBasedTrainer): + """ + A trainer is used for acoustic noise suppression. + Override train_loop() to use dataset just one time. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train_loop(self, data_loader): + """ + Update epoch by step number, based on super method. + """ + self.invoke_hook(TrainerStages.before_run) + self._epoch = 0 + kwargs = {} + self.model.train() + enumerated = enumerate(data_loader) + for _ in range(self._epoch, self._max_epochs): + self.invoke_hook(TrainerStages.before_train_epoch) + self._inner_iter = 0 + for i, data_batch in enumerated: + data_batch = to_device(data_batch, self.device) + self.data_batch = data_batch + self._inner_iter += 1 + self.invoke_hook(TrainerStages.before_train_iter) + self.train_step(self.model, data_batch, **kwargs) + self.invoke_hook(TrainerStages.after_train_iter) + del self.data_batch + self._iter += 1 + if self._inner_iter >= self.iters_per_epoch: + break + + self.invoke_hook(TrainerStages.after_train_epoch) + self._epoch += 1 + + self.invoke_hook(TrainerStages.after_run) + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py new file mode 100644 index 00000000..85c1a496 --- /dev/null +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -0,0 +1,282 @@ +import datetime +import math +import os +from typing import Callable, Dict, Optional + +import numpy as np +import torch +from torch import nn as nn +from torch import optim as optim + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models import Model, TorchModel +from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist, is_master) + +logger = get_logger() + +BASETRAIN_CONF_EASY = 'basetrain_easy' +BASETRAIN_CONF_NORMAL = 'basetrain_normal' +BASETRAIN_CONF_HARD = 'basetrain_hard' +FINETUNE_CONF_EASY = 'finetune_easy' +FINETUNE_CONF_NORMAL = 'finetune_normal' +FINETUNE_CONF_HARD = 'finetune_hard' + +EASY_RATIO = 0.1 +NORMAL_RATIO = 0.6 +HARD_RATIO = 0.3 +BASETRAIN_RATIO = 0.5 + + +@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield) +class KWSFarfieldTrainer(BaseTrainer): + DEFAULT_WORK_DIR = './work_dir' + conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY, + BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL, + BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD) + + def __init__(self, + model: str, + work_dir: str, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + custom_conf: Optional[dict] = None, + **kwargs): + + if isinstance(model, str): + if os.path.exists(model): + self.model_dir = model if os.path.isdir( + model) else os.path.dirname(model) + else: + self.model_dir = snapshot_download( + model, revision=model_revision) + if cfg_file is None: + cfg_file = os.path.join(self.model_dir, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' + self.model_dir = os.path.dirname(cfg_file) + + super().__init__(cfg_file, arg_parse_fn) + + # the number of model output dimension + # should update config outside the trainer, if user need more wake word + num_syn = kwargs.get('num_syn', None) + if num_syn: + self.cfg.model.num_syn = num_syn + self._num_classes = self.cfg.model.num_syn + self.model = self.build_model() + self.work_dir = work_dir + + if kwargs.get('launcher', None) is not None: + init_dist(kwargs['launcher']) + + _, world_size = get_dist_info() + self._dist = world_size > 1 + + device_name = kwargs.get('device', 'gpu') + if self._dist: + local_rank = get_local_rank() + device_name = f'cuda:{local_rank}' + + self.device = create_device(device_name) + # model placement + if self.device.type == 'cuda': + self.model.to(self.device) + + if 'max_epochs' not in kwargs: + assert hasattr( + self.cfg.train, 'max_epochs' + ), 'max_epochs is missing from the configuration file' + self._max_epochs = self.cfg.train.max_epochs + else: + self._max_epochs = kwargs['max_epochs'] + self._train_iters = kwargs.get('train_iters_per_epoch', None) + self._val_iters = kwargs.get('val_iters_per_epoch', None) + if self._train_iters is None: + self._train_iters = self.cfg.train.train_iters_per_epoch + if self._val_iters is None: + self._val_iters = self.cfg.evaluation.val_iters_per_epoch + dataloader_config = self.cfg.train.dataloader + self._threads = kwargs.get('workers', None) + if self._threads is None: + self._threads = dataloader_config.workers_per_gpu + self._single_rate = BASETRAIN_RATIO + if 'single_rate' in kwargs: + self._single_rate = kwargs['single_rate'] + self._batch_size = dataloader_config.batch_size_per_gpu + if 'model_bin' in kwargs: + model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) + checkpoint = torch.load(model_bin_file) + self.model.load_state_dict(checkpoint) + # build corresponding optimizer and loss function + lr = self.cfg.train.optimizer.lr + self.optimizer = optim.Adam(self.model.parameters(), lr) + self.loss_fn = nn.CrossEntropyLoss() + self.data_val = None + self.json_log_path = os.path.join(self.work_dir, + '{}.log.json'.format(self.timestamp)) + self.conf_files = [] + for conf_key in self.conf_keys: + template_file = os.path.join(self.model_dir, conf_key) + conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') + update_conf(template_file, conf_file, custom_conf[conf_key]) + self.conf_files.append(conf_file) + self._current_epoch = 0 + self.stages = (math.floor(self._max_epochs * EASY_RATIO), + math.floor(self._max_epochs * NORMAL_RATIO), + math.floor(self._max_epochs * HARD_RATIO)) + + def build_model(self) -> nn.Module: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. + + """ + model = Model.from_pretrained( + self.model_dir, cfg_dict=self.cfg, training=True) + if isinstance(model, TorchModel) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def train(self, *args, **kwargs): + if not self.data_val: + self.gen_val() + logger.info('Start training...') + totaltime = datetime.datetime.now() + + for stage, num_epoch in enumerate(self.stages): + self.run_stage(stage, num_epoch) + + # total time spent + totaltime = datetime.datetime.now() - totaltime + logger.info('Total time spent: {:.2f} hours\n'.format( + totaltime.total_seconds() / 3600.0)) + + def run_stage(self, stage, num_epoch): + """ + Run training stages with correspond data + + Args: + stage: id of stage + num_epoch: the number of epoch to run in this stage + """ + if num_epoch <= 0: + logger.warning(f'Invalid epoch number, stage {stage} exit!') + return + logger.info(f'Starting stage {stage}...') + dataset, dataloader = self.create_dataloader( + self.conf_files[stage * 2], self.conf_files[stage * 2 + 1]) + it = iter(dataloader) + for _ in range(num_epoch): + self._current_epoch += 1 + epochtime = datetime.datetime.now() + logger.info('Start epoch %d...', self._current_epoch) + loss_train_epoch = 0.0 + validbatchs = 0 + for bi in range(self._train_iters): + # prepare data + feat, label = next(it) + label = torch.reshape(label, (-1, )) + feat = to_device(feat, self.device) + label = to_device(label, self.device) + # apply model + self.optimizer.zero_grad() + predict = self.model(feat) + # calculate loss + loss = self.loss_fn( + torch.reshape(predict, (-1, self._num_classes)), label) + if not np.isnan(loss.item()): + loss.backward() + self.optimizer.step() + loss_train_epoch += loss.item() + validbatchs += 1 + train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format( + self._current_epoch, self._max_epochs, bi + 1, + self._train_iters, loss.item()) + logger.info(train_result) + self._dump_log(train_result) + + # average training loss in one epoch + loss_train_epoch /= validbatchs + loss_val_epoch = self.evaluate('') + val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format( + self._current_epoch, loss_train_epoch, loss_val_epoch) + logger.info(val_result) + self._dump_log(val_result) + # check point + ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( + self._current_epoch, loss_train_epoch, loss_val_epoch) + torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) + # time spent per epoch + epochtime = datetime.datetime.now() - epochtime + logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( + self._current_epoch, + epochtime.total_seconds() / 3600.0)) + dataloader.stop() + dataset.release() + logger.info(f'Stage {stage} is finished.') + + def gen_val(self): + """ + generate validation set + """ + logger.info('Start generating validation set...') + dataset, dataloader = self.create_dataloader(self.conf_files[2], + self.conf_files[3]) + it = iter(dataloader) + + self.data_val = [] + for bi in range(self._val_iters): + logger.info('Iterating validation data %d', bi) + feat, label = next(it) + label = torch.reshape(label, (-1, )) + self.data_val.append([feat, label]) + + dataloader.stop() + dataset.release() + logger.info('Finish generating validation set!') + + def create_dataloader(self, base_path, finetune_path): + dataset = KWSDataset(base_path, finetune_path, self._threads, + self._single_rate, self._num_classes) + dataloader = KWSDataLoader( + dataset, batchsize=self._batch_size, numworkers=self._threads) + dataloader.start() + return dataset, dataloader + + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + logger.info('Start validation...') + loss_val_epoch = 0.0 + + with torch.no_grad(): + for feat, label in self.data_val: + feat = to_device(feat, self.device) + label = to_device(label, self.device) + # apply model + predict = self.model(feat) + # calculate loss + loss = self.loss_fn( + torch.reshape(predict, (-1, self._num_classes)), label) + loss_val_epoch += loss.item() + logger.info('Finish validation.') + return loss_val_epoch / self._val_iters + + def _dump_log(self, msg): + if is_master(): + with open(self.json_log_path, 'a+') as f: + f.write(msg) + f.write('\n') diff --git a/modelscope/trainers/base.py b/modelscope/trainers/base.py new file mode 100644 index 00000000..c0bf51f3 --- /dev/null +++ b/modelscope/trainers/base.py @@ -0,0 +1,90 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import time +from abc import ABC, abstractmethod +from typing import Callable, Dict, List, Optional, Tuple, Union + +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.config import Config +from .utils.log_buffer import LogBuffer + + +class BaseTrainer(ABC): + """ Base class for trainer which can not be instantiated. + + BaseTrainer defines necessary interface + and provide default implementation for basic initialization + such as parsing config file and parsing commandline args. + """ + + def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None): + """ Trainer basic init, should be called in derived class + + Args: + cfg_file: Path to configuration file. + arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`. + """ + self.cfg = Config.from_file(cfg_file) + if arg_parse_fn: + self.args = self.cfg.to_args(arg_parse_fn) + else: + self.args = None + self.log_buffer = LogBuffer() + self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + @abstractmethod + def train(self, *args, **kwargs): + """ Train (and evaluate) process + + Train process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + pass + + @abstractmethod + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + """ Evaluation process + + Evaluation process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + pass + + +@TRAINERS.register_module(module_name='dummy') +class DummyTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ Dummy Trainer. + + Args: + cfg_file: Path to configuration file. + """ + super().__init__(cfg_file) + + def train(self, *args, **kwargs): + """ Train (and evaluate) process + + Train process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + cfg = self.cfg.train + print(f'train cfg {cfg}') + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + """ Evaluation process + + Evaluation process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + cfg = self.cfg.evaluation + print(f'eval cfg {cfg}') + print(f'checkpoint_path {checkpoint_path}') diff --git a/modelscope/trainers/builder.py b/modelscope/trainers/builder.py new file mode 100644 index 00000000..87e99b30 --- /dev/null +++ b/modelscope/trainers/builder.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import Tasks +from modelscope.utils.registry import Registry, build_from_cfg + +TRAINERS = Registry('trainers') +HOOKS = Registry('hooks') + + +def build_trainer(name: str = Trainers.default, default_args: dict = None): + """ build trainer given a trainer name + + Args: + name (str, optional): Trainer name, if None, default trainer + will be used. + default_args (dict, optional): Default initialization arguments. + """ + cfg = dict(type=name) + return build_from_cfg(cfg, TRAINERS, default_args=default_args) diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py new file mode 100644 index 00000000..32c38de2 --- /dev/null +++ b/modelscope/trainers/cv/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_instance_segmentation_trainer import \ + ImageInstanceSegmentationTrainer + from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer + from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer + from .image_inpainting_trainer import ImageInpaintingTrainer + from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer + +else: + _import_structure = { + 'image_instance_segmentation_trainer': + ['ImageInstanceSegmentationTrainer'], + 'image_portrait_enhancement_trainer': + ['ImagePortraitEnhancementTrainer'], + 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], + 'image_inpainting_trainer': ['ImageInpaintingTrainer'], + 'referring_video_object_segmentation_trainer': + ['ReferringVideoObjectSegmentationTrainer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/cv/card_detection_scrfd_trainer.py b/modelscope/trainers/cv/card_detection_scrfd_trainer.py new file mode 100644 index 00000000..e1f81bcf --- /dev/null +++ b/modelscope/trainers/cv/card_detection_scrfd_trainer.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.cv.face_detection_scrfd_trainer import \ + FaceDetectionScrfdTrainer + + +@TRAINERS.register_module(module_name=Trainers.card_detection_scrfd) +class CardDetectionScrfdTrainer(FaceDetectionScrfdTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. + """ + # card/face dataset use different img folder names + super().__init__(cfg_file, imgdir_name='', **kwargs) diff --git a/modelscope/trainers/cv/face_detection_scrfd_trainer.py b/modelscope/trainers/cv/face_detection_scrfd_trainer.py new file mode 100644 index 00000000..9cfae7dd --- /dev/null +++ b/modelscope/trainers/cv/face_detection_scrfd_trainer.py @@ -0,0 +1,154 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import os +import os.path as osp +import time +from typing import Callable, Dict, Optional + +from modelscope.metainfo import Trainers +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS + + +@TRAINERS.register_module(module_name=Trainers.face_detection_scrfd) +class FaceDetectionScrfdTrainer(BaseTrainer): + + def __init__(self, + cfg_file: str, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + """ High-level finetune api for SCRFD. + + Args: + cfg_file: Path to configuration file. + cfg_modify_fn: An input fn which is used to modify the cfg read out of the file. + """ + import mmcv + from mmcv.runner import get_dist_info, init_dist + from mmcv.utils import get_git_hash + from mmdet.utils import collect_env, get_root_logger + from mmdet.apis import set_random_seed + from mmdet.models import build_detector + from mmdet.datasets import build_dataset + from mmdet import __version__ + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets import RetinaFaceDataset + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import DefaultFormatBundleV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import LoadAnnotationsV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RotateV2 + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.datasets.pipelines import RandomSquareCrop + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.backbones import ResNetV1e + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.dense_heads import SCRFDHead + from modelscope.models.cv.face_detection.scrfd.mmdet_patch.models.detectors import SCRFD + super().__init__(cfg_file) + cfg = self.cfg + if 'work_dir' in kwargs: + cfg.work_dir = kwargs['work_dir'] + else: + # use config filename as default work_dir if work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(cfg_file))[0]) + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + + if 'resume_from' in kwargs: # pretrain model for finetune + cfg.resume_from = kwargs['resume_from'] + cfg.device = 'cuda' + if 'gpu_ids' in kwargs: + cfg.gpu_ids = kwargs['gpu_ids'] + else: + cfg.gpu_ids = range(1) + labelfile_name = kwargs.pop('labelfile_name', 'labelv2.txt') + imgdir_name = kwargs.pop('imgdir_name', 'images/') + if 'train_root' in kwargs: + cfg.data.train.ann_file = kwargs['train_root'] + labelfile_name + cfg.data.train.img_prefix = kwargs['train_root'] + imgdir_name + if 'val_root' in kwargs: + cfg.data.val.ann_file = kwargs['val_root'] + labelfile_name + cfg.data.val.img_prefix = kwargs['val_root'] + imgdir_name + if 'total_epochs' in kwargs: + cfg.total_epochs = kwargs['total_epochs'] + if cfg_modify_fn is not None: + cfg = cfg_modify_fn(cfg) + if 'launcher' in kwargs: + distributed = True + init_dist(kwargs['launcher'], **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + else: + distributed = False + # no_validate=True will not evaluate checkpoint during training + cfg.no_validate = kwargs.get('no_validate', False) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if 'seed' in kwargs: + cfg.seed = kwargs['seed'] + _deterministic = kwargs.get('deterministic', False) + logger.info(f'Set random seed to {kwargs["seed"]}, ' + f'deterministic: {_deterministic}') + set_random_seed(kwargs['seed'], deterministic=_deterministic) + else: + cfg.seed = None + meta['seed'] = cfg.seed + meta['exp_name'] = osp.basename(cfg_file) + + model = build_detector(cfg.model) + model.init_weights() + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=__version__ + get_git_hash()[:7], + CLASSES=datasets[0].CLASSES) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + + self.cfg = cfg + self.datasets = datasets + self.model = model + self.distributed = distributed + self.timestamp = timestamp + self.meta = meta + self.logger = logger + + def train(self, *args, **kwargs): + from mmdet.apis import train_detector + train_detector( + self.model, + self.datasets, + self.cfg, + distributed=self.distributed, + validate=(not self.cfg.no_validate), + timestamp=self.timestamp, + meta=self.meta) + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + cfg = self.cfg.evaluation + logger.info(f'eval cfg {cfg}') + logger.info(f'checkpoint_path {checkpoint_path}') diff --git a/modelscope/trainers/cv/image_inpainting_trainer.py b/modelscope/trainers/cv/image_inpainting_trainer.py new file mode 100644 index 00000000..74d1ed9f --- /dev/null +++ b/modelscope/trainers/cv/image_inpainting_trainer.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import time +from collections.abc import Mapping + +from torch import distributed as dist + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ConfigKeys, Hubs, ModeKeys, ModelFile, + Tasks, TrainerStages) +from modelscope.utils.data_utils import to_device +from modelscope.utils.file_utils import func_receive_dict_inputs + + +@TRAINERS.register_module(module_name=Trainers.image_inpainting) +class ImageInpaintingTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass + + def train_loop(self, data_loader): + """ Training loop used by `EpochBasedTrainer.train()` + """ + self.invoke_hook(TrainerStages.before_run) + self._epoch = 0 + self.model.train() + for _ in range(self._epoch, self._max_epochs): + self.invoke_hook(TrainerStages.before_train_epoch) + for i, data_batch in enumerate(data_loader): + data_batch = to_device(data_batch, self.device) + self.data_batch = data_batch + self._inner_iter = i + for idx in range(2): + self.invoke_hook(TrainerStages.before_train_iter) + self.train_step(self.model, data_batch, idx) + self.invoke_hook(TrainerStages.after_train_iter) + del self.data_batch + self._iter += 1 + self._mode = ModeKeys.TRAIN + + if i + 1 >= self.iters_per_epoch: + break + + self.invoke_hook(TrainerStages.after_train_epoch) + self._epoch += 1 + + self.invoke_hook(TrainerStages.after_run) + + def train_step(self, model, inputs, idx): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + model.train() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + if isinstance(inputs, + Mapping) and not func_receive_dict_inputs(model.forward): + train_outputs = model.model._do_step(**inputs, optimizer_idx=idx) + else: + train_outputs = model.model._do_step(inputs, optimizer_idx=idx) + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs diff --git a/modelscope/trainers/cv/image_instance_segmentation_trainer.py b/modelscope/trainers/cv/image_instance_segmentation_trainer.py new file mode 100644 index 00000000..a777bde1 --- /dev/null +++ b/modelscope/trainers/cv/image_instance_segmentation_trainer.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer + + +@TRAINERS.register_module(module_name=Trainers.image_instance_segmentation) +class ImageInstanceSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def collate_fn(self, data): + # we skip this func due to some special data type, e.g., BitmapMasks + return data + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/cv/image_portrait_enhancement_trainer.py b/modelscope/trainers/cv/image_portrait_enhancement_trainer.py new file mode 100644 index 00000000..0941d1cd --- /dev/null +++ b/modelscope/trainers/cv/image_portrait_enhancement_trainer.py @@ -0,0 +1,148 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections.abc import Mapping + +import torch +from torch import distributed as dist + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys +from modelscope.utils.logger import get_logger + + +@TRAINERS.register_module(module_name=Trainers.image_portrait_enhancement) +class ImagePortraitEnhancementTrainer(EpochBasedTrainer): + + def train_step(self, model, inputs): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + self.d_reg_every = self.cfg.train.get('d_reg_every', 16) + self.g_reg_every = self.cfg.train.get('g_reg_every', 4) + self.path_regularize = self.cfg.train.get('path_regularize', 2) + self.r1 = self.cfg.train.get('r1', 10) + + train_outputs = dict() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + if isinstance(inputs, Mapping): + d_loss = model._train_forward_d(**inputs) + else: + d_loss = model._train_forward_d(inputs) + train_outputs['d_loss'] = d_loss + + model.discriminator.zero_grad() + d_loss.backward() + self.optimizer_d.step() + + if self._iter % self.d_reg_every == 0: + + if isinstance(inputs, Mapping): + r1_loss = model._train_forward_d_r1(**inputs) + else: + r1_loss = model._train_forward_d_r1(inputs) + train_outputs['r1_loss'] = r1_loss + + model.discriminator.zero_grad() + (self.r1 / 2 * r1_loss * self.d_reg_every).backward() + + self.optimizer_d.step() + + if isinstance(inputs, Mapping): + g_loss = model._train_forward_g(**inputs) + else: + g_loss = model._train_forward_g(inputs) + train_outputs['g_loss'] = g_loss + + model.generator.zero_grad() + g_loss.backward() + self.optimizer.step() + + path_loss = 0 + if self._iter % self.g_reg_every == 0: + if isinstance(inputs, Mapping): + path_loss = model._train_forward_g_path(**inputs) + else: + path_loss = model._train_forward_g_path(inputs) + train_outputs['path_loss'] = path_loss + + model.generator.zero_grad() + weighted_path_loss = self.path_regularize * self.g_reg_every * path_loss + + weighted_path_loss.backward() + + self.optimizer.step() + + model.accumulate() + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs + + def create_optimizer_and_scheduler(self): + """ Create optimizer and lr scheduler + + We provide a default implementation, if you want to customize your own optimizer + and lr scheduler, you can either pass a tuple through trainer init function or + subclass this class and override this method. + + + """ + optimizer, lr_scheduler = self.optimizers + if optimizer is None: + optimizer_cfg = self.cfg.train.get('optimizer', None) + else: + optimizer_cfg = None + optimizer_d_cfg = self.cfg.train.get('optimizer_d', None) + + optim_options = {} + if optimizer_cfg is not None: + optim_options = optimizer_cfg.pop('options', {}) + optimizer = build_optimizer( + self.model.generator, cfg=optimizer_cfg) + if optimizer_d_cfg is not None: + optimizer_d = build_optimizer( + self.model.discriminator, cfg=optimizer_d_cfg) + + lr_options = {} + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.optimizer_d = optimizer_d + return self.optimizer, self.lr_scheduler, optim_options, lr_options diff --git a/modelscope/trainers/cv/movie_scene_segmentation_trainer.py b/modelscope/trainers/cv/movie_scene_segmentation_trainer.py new file mode 100644 index 00000000..7645f9f3 --- /dev/null +++ b/modelscope/trainers/cv/movie_scene_segmentation_trainer.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer + + +@TRAINERS.register_module(module_name=Trainers.movie_scene_segmentation) +class MovieSceneSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def train(self, *args, **kwargs): + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..c15df3a5 --- /dev/null +++ b/modelscope/trainers/cv/referring_video_object_segmentation_trainer.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +import torch + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys + + +@TRAINERS.register_module( + module_name=Trainers.referring_video_object_segmentation) +class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model.set_postprocessor(self.cfg.dataset.name) + self.train_data_collator = self.train_dataset.collator + self.eval_data_collator = self.eval_dataset.collator + + device_name = kwargs.get('device', 'gpu') + self.model.set_device(self.device, device_name) + + def train(self, *args, **kwargs): + self.model.criterion.train() + super().train(*args, **kwargs) + + def evaluate(self, checkpoint_path=None): + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + CheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + if self.eval_dataset is None: + self.eval_dataloader = self.get_eval_data_loader() + else: + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.eval_data_collator, + **self.cfg.evaluation.get('dataloader', {})) + self.data_loader = self.eval_dataloader + + from modelscope.metrics import build_metric + ann_file = self.eval_dataset.ann_file + metric_classes = [] + for metric in self.metrics: + metric.update({'ann_file': ann_file}) + metric_classes.append(build_metric(metric)) + + for m in metric_classes: + m.trainer = self + + metric_values = self.evaluation_loop(self.eval_dataloader, + metric_classes) + + self._metric_values = metric_values + return metric_values + + def prediction_step(self, model, inputs): + pass diff --git a/modelscope/trainers/default_config.py b/modelscope/trainers/default_config.py new file mode 100644 index 00000000..a02478b9 --- /dev/null +++ b/modelscope/trainers/default_config.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.utils.config import Config + +DEFAULT_CONFIG = { + 'train': { + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 10 + }, { + 'type': 'IterTimerHook' + }] + } +} + + +def merge_cfg(cfg: Config): + """Merge the default config into the input cfg. + + This function will pop the default CheckpointHook when the BestCkptSaverHook exists in the input cfg. + + Aegs: + cfg: The input cfg to be merged into. + """ + cfg.merge_from_dict(DEFAULT_CONFIG, force=False) + # pop duplicate hook + + if any(['BestCkptSaverHook' == hook['type'] for hook in cfg.train.hooks]): + cfg.train.hooks = list( + filter(lambda hook: hook['type'] != 'CheckpointHook', + cfg.train.hooks)) diff --git a/modelscope/trainers/easycv/__init__.py b/modelscope/trainers/easycv/__init__.py new file mode 100644 index 00000000..b1b8fc15 --- /dev/null +++ b/modelscope/trainers/easycv/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .utils import AddLrLogHook, EasyCVMetric +else: + _import_structure = {'utils': ['AddLrLogHook', 'EasyCVMetric']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/easycv/trainer.py b/modelscope/trainers/easycv/trainer.py new file mode 100644 index 00000000..3c869495 --- /dev/null +++ b/modelscope/trainers/easycv/trainer.py @@ -0,0 +1,167 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import TorchModel +from modelscope.msdatasets import MsDataset +from modelscope.preprocessors import Preprocessor +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.base import TRAINERS +from modelscope.trainers.easycv.utils import register_util +from modelscope.trainers.hooks import HOOKS +from modelscope.trainers.parallel.builder import build_parallel +from modelscope.trainers.parallel.utils import is_parallel +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.import_utils import LazyImportModule +from modelscope.utils.registry import default_group + + +@TRAINERS.register_module(module_name=Trainers.easycv) +class EasyCVEpochBasedTrainer(EpochBasedTrainer): + """Epoch based Trainer for EasyCV. + + Args: + cfg_file(str): The config file of EasyCV. + model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir + or a model id. If model is None, build_model method will be called. + train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): + The dataset to use for training. + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation. + preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor. + NOTE: If the preprocessor has been called before the dataset fed into this trainer by user's custom code, + this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file. + Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and + this preprocessing action will be executed every time the dataset's __getitem__ is called. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple + containing the optimizer and the scheduler to use. + max_epochs: (int, optional): Total training epochs. + """ + + def __init__( + self, + cfg_file: Optional[str] = None, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + arg_parse_fn: Optional[Callable] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Preprocessor] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + **kwargs): + + register_util.register_parallel() + register_util.register_part_mmcv_hooks_to_ms() + + super(EasyCVEpochBasedTrainer, self).__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + preprocessor=preprocessor, + optimizers=optimizers, + model_revision=model_revision, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + **kwargs) + + # reset data_collator + from mmcv.parallel import collate + + self.train_data_collator = partial( + collate, + samples_per_gpu=self.cfg.train.dataloader.batch_size_per_gpu) + self.eval_data_collator = partial( + collate, + samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu) + + # Register easycv hooks dynamicly. If the hook already exists in modelscope, + # the hook in modelscope will be used, otherwise register easycv hook into ms. + # We must manually trigger lazy import to detect whether the hook is in modelscope. + # TODO: use ast index to detect whether the hook is in modelscope + for h_i in self.cfg.train.get('hooks', []): + sig = ('HOOKS', default_group, h_i['type']) + LazyImportModule.import_module(sig) + if h_i['type'] not in HOOKS._modules[default_group]: + if h_i['type'] in [ + 'TensorboardLoggerHookV2', 'WandbLoggerHookV2' + ]: + raise ValueError( + 'Not support hook %s now, we will support it in the future!' + % h_i['type']) + register_util.register_hook_to_ms(h_i['type'], self.logger) + + # reset parallel + if not self._dist: + assert not is_parallel( + self.model + ), 'Not support model wrapped by custom parallel if not in distributed mode!' + dp_cfg = dict( + type='MMDataParallel', + module=self.model, + device_ids=[torch.cuda.current_device()]) + self.model = build_parallel(dp_cfg) + + def create_optimizer_and_scheduler(self): + """ Create optimizer and lr scheduler + """ + optimizer, lr_scheduler = self.optimizers + if optimizer is None: + optimizer_cfg = self.cfg.train.get('optimizer', None) + else: + optimizer_cfg = None + + optim_options = {} + if optimizer_cfg is not None: + optim_options = optimizer_cfg.pop('options', {}) + from easycv.apis.train import build_optimizer + optimizer = build_optimizer(self.model, optimizer_cfg) + + if lr_scheduler is None: + lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) + else: + lr_scheduler_cfg = None + + lr_options = {} + # Adapt to mmcv lr scheduler hook. + # Please refer to: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py + if lr_scheduler_cfg is not None: + assert optimizer is not None + lr_options = lr_scheduler_cfg.pop('options', {}) + assert 'policy' in lr_scheduler_cfg + policy_type = lr_scheduler_cfg.pop('policy') + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'LrUpdaterHook' + lr_scheduler_cfg['type'] = hook_type + + self.cfg.train.lr_scheduler_hook = lr_scheduler_cfg + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + return self.optimizer, self.lr_scheduler, optim_options, lr_options + + def to_parallel(self, model) -> Union[nn.Module, TorchModel]: + if self.cfg.get('parallel', None) is not None: + self.cfg.parallel.update( + dict(module=model, device_ids=[torch.cuda.current_device()])) + return build_parallel(self.cfg.parallel) + + dp_cfg = dict( + type='MMDistributedDataParallel', + module=model, + device_ids=[torch.cuda.current_device()]) + + return build_parallel(dp_cfg) diff --git a/modelscope/trainers/easycv/utils/__init__.py b/modelscope/trainers/easycv/utils/__init__.py new file mode 100644 index 00000000..23cfa36a --- /dev/null +++ b/modelscope/trainers/easycv/utils/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hooks import AddLrLogHook + from .metric import EasyCVMetric + +else: + _import_structure = {'hooks': ['AddLrLogHook'], 'metric': ['EasyCVMetric']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/easycv/utils/hooks.py b/modelscope/trainers/easycv/utils/hooks.py new file mode 100644 index 00000000..62bc6d1e --- /dev/null +++ b/modelscope/trainers/easycv/utils/hooks.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.trainers.hooks import HOOKS, Priority +from modelscope.trainers.hooks.lr_scheduler_hook import LrSchedulerHook +from modelscope.utils.constant import LogKeys + + +@HOOKS.register_module(module_name='AddLrLogHook') +class AddLrLogHook(LrSchedulerHook): + """For EasyCV to adapt to ModelScope, the lr log of EasyCV is added in the trainer, + but the trainer of ModelScope does not and it is added in the lr scheduler hook. + But The lr scheduler hook used by EasyCV is the hook of mmcv, and there is no lr log. + It will be deleted in the future. + """ + PRIORITY = Priority.NORMAL + + def __init__(self): + pass + + def before_run(self, trainer): + pass + + def before_train_iter(self, trainer): + trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) + + def before_train_epoch(self, trainer): + trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) + + def after_train_epoch(self, trainer): + pass diff --git a/modelscope/trainers/easycv/utils/metric.py b/modelscope/trainers/easycv/utils/metric.py new file mode 100644 index 00000000..53937b67 --- /dev/null +++ b/modelscope/trainers/easycv/utils/metric.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import itertools +from typing import Dict + +import numpy as np +import torch + +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS + + +@METRICS.register_module(module_name='EasyCVMetric') +class EasyCVMetric(Metric): + """Adapt to ModelScope Metric for EasyCV evaluator. + """ + + def __init__(self, trainer=None, evaluators=None, *args, **kwargs): + from easycv.core.evaluation.builder import build_evaluator + + self.trainer = trainer + self.evaluators = build_evaluator(evaluators) + self.preds = [] + self.grountruths = [] + + def add(self, outputs: Dict, inputs: Dict): + self.preds.append(outputs) + del inputs + + def evaluate(self): + results = {} + for _, batch in enumerate(self.preds): + for k, v in batch.items(): + if k not in results: + results[k] = [] + results[k].append(v) + + for k, v in results.items(): + if len(v) == 0: + raise ValueError(f'empty result for {k}') + + if isinstance(v[0], torch.Tensor): + results[k] = torch.cat(v, 0) + elif isinstance(v[0], (list, np.ndarray)): + results[k] = list(itertools.chain.from_iterable(v)) + else: + raise ValueError( + f'value of batch prediction dict should only be tensor or list, {k} type is {v[0]}' + ) + + metric_values = self.trainer.eval_dataset.evaluate( + results, self.evaluators) + return metric_values diff --git a/modelscope/trainers/easycv/utils/register_util.py b/modelscope/trainers/easycv/utils/register_util.py new file mode 100644 index 00000000..04bf719b --- /dev/null +++ b/modelscope/trainers/easycv/utils/register_util.py @@ -0,0 +1,97 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect +import logging + +from modelscope.trainers.hooks import HOOKS +from modelscope.trainers.parallel.builder import PARALLEL +from modelscope.utils.registry import default_group + + +class _RegisterManager: + + def __init__(self): + self.registries = {} + + def add(self, module, name, group_key=default_group): + if module.name not in self.registries: + self.registries[module.name] = {} + if group_key not in self.registries[module.name]: + self.registries[module.name][group_key] = [] + + self.registries[module.name][group_key].append(name) + + def exists(self, module, name, group_key=default_group): + if self.registries.get(module.name, None) is None: + return False + if self.registries[module.name].get(group_key, None) is None: + return False + if name in self.registries[module.name][group_key]: + return True + + return False + + +_dynamic_register = _RegisterManager() + + +def register_parallel(): + from mmcv.parallel import MMDistributedDataParallel, MMDataParallel + + mmddp = 'MMDistributedDataParallel' + mmdp = 'MMDataParallel' + + if not _dynamic_register.exists(PARALLEL, mmddp): + _dynamic_register.add(PARALLEL, mmddp) + PARALLEL.register_module( + module_name=mmddp, module_cls=MMDistributedDataParallel) + if not _dynamic_register.exists(PARALLEL, mmdp): + _dynamic_register.add(PARALLEL, mmdp) + PARALLEL.register_module(module_name=mmdp, module_cls=MMDataParallel) + + +def register_hook_to_ms(hook_name, logger=None): + """Register EasyCV hook to ModelScope.""" + from easycv.hooks import HOOKS as _EV_HOOKS + + if hook_name not in _EV_HOOKS._module_dict: + raise ValueError( + f'Not found hook "{hook_name}" in EasyCV hook registries!') + + if _dynamic_register.exists(HOOKS, hook_name): + return + _dynamic_register.add(HOOKS, hook_name) + + obj = _EV_HOOKS._module_dict[hook_name] + HOOKS.register_module(module_name=hook_name, module_cls=obj) + + log_str = f'Register hook "{hook_name}" to modelscope hooks.' + logger.info(log_str) if logger is not None else logging.info(log_str) + + +def register_part_mmcv_hooks_to_ms(): + """Register required mmcv hooks to ModelScope. + Currently we only registered all lr scheduler hooks in EasyCV and mmcv. + Please refer to: + EasyCV: https://github.com/alibaba/EasyCV/blob/master/easycv/hooks/lr_update_hook.py + mmcv: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py + """ + from mmcv.runner.hooks import lr_updater + from mmcv.runner.hooks import HOOKS as _MMCV_HOOKS + from easycv.hooks import StepFixCosineAnnealingLrUpdaterHook, YOLOXLrUpdaterHook + + mmcv_hooks_in_easycv = [('StepFixCosineAnnealingLrUpdaterHook', + StepFixCosineAnnealingLrUpdaterHook), + ('YOLOXLrUpdaterHook', YOLOXLrUpdaterHook)] + + members = inspect.getmembers(lr_updater) + members.extend(mmcv_hooks_in_easycv) + + for name, obj in members: + if name in _MMCV_HOOKS._module_dict: + if _dynamic_register.exists(HOOKS, name): + continue + _dynamic_register.add(HOOKS, name) + HOOKS.register_module( + module_name=name, + module_cls=obj, + ) diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py new file mode 100644 index 00000000..a2e0cf4b --- /dev/null +++ b/modelscope/trainers/hooks/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .builder import HOOKS, build_hook + from .checkpoint_hook import BestCkptSaverHook, CheckpointHook + from .compression import SparsityHook + from .evaluation_hook import EvaluationHook + from .hook import Hook + from .iter_timer_hook import IterTimerHook + from .logger import TensorboardHook, TextLoggerHook + from .lr_scheduler_hook import LrSchedulerHook + from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook, + OptimizerHook, TorchAMPOptimizerHook) + from .priority import Priority, get_priority + +else: + _import_structure = { + 'builder': ['HOOKS', 'build_hook'], + 'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'], + 'compression': ['SparsityHook'], + 'evaluation_hook': ['EvaluationHook'], + 'hook': ['Hook'], + 'iter_timer_hook': ['IterTimerHook'], + 'logger': ['TensorboardHook', 'TextLoggerHook'], + 'lr_scheduler_hook': ['LrSchedulerHook'], + 'optimizer_hook': [ + 'ApexAMPOptimizerHook', 'NoneOptimizerHook', 'OptimizerHook', + 'TorchAMPOptimizerHook' + ], + 'priority': ['Priority', 'get'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/hooks/builder.py b/modelscope/trainers/hooks/builder.py new file mode 100644 index 00000000..1948e481 --- /dev/null +++ b/modelscope/trainers/hooks/builder.py @@ -0,0 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.utils.registry import Registry, build_from_cfg, default_group + +HOOKS = Registry('hooks') + + +def build_hook(cfg, default_args=None): + return build_from_cfg( + cfg, HOOKS, group_key=default_group, default_args=default_args) diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py new file mode 100644 index 00000000..89aa39ba --- /dev/null +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -0,0 +1,319 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import random + +import numpy as np +import torch + +from modelscope import __version__ +from modelscope.metainfo import Hooks +from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint +from modelscope.utils.constant import LogKeys, ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import get_dist_info, is_master +from .builder import HOOKS +from .hook import Hook +from .priority import Priority + + +@HOOKS.register_module(module_name=Hooks.CheckpointHook) +class CheckpointHook(Hook): + """Save checkpoints periodically. + + Args: + interval (int): The frequency to save model. If `by_epoch=True`, + it means the number of epochs, else means the number of iterations + by_epoch (bool): Saving checkpoints by epoch or by iteration. + save_optimizer (bool): Whether to save optimizer state dict. Default: True. + save_dir (str): The directory to save checkpoints. If is None, use `trainer.work_dir` + save_last (bool): Whether to save the last checkpoint. Default: True. + checkpoint_file (str): The checkpoint file to be loaded. + """ + + PRIORITY = Priority.LOW + + def __init__(self, + interval=0, + by_epoch=True, + save_optimizer=True, + save_dir=None, + save_last=True, + checkpoint_file=None): + self.interval = interval + self.by_epoch = by_epoch + self.save_optimizer = save_optimizer + self.save_dir = save_dir + self.checkpoint_file = checkpoint_file + self.save_last = save_last + self.rng_state = None + self.need_load_rng_state = False + + def before_run(self, trainer): + if not self.save_dir: + self.save_dir = trainer.work_dir + + if not os.path.exists(self.save_dir) and is_master(): + os.makedirs(self.save_dir) + + if not hasattr(trainer, 'logger'): + self.logger = get_logger(__name__) + else: + self.logger = trainer.logger + + if is_master(): + self.logger.info(f'Checkpoints will be saved to {self.save_dir}') + + if self.checkpoint_file is not None and os.path.isfile( + self.checkpoint_file): + meta = self.load_checkpoint(self.checkpoint_file, trainer) + self.rng_state = meta.get('rng_state') + self.need_load_rng_state = True + + def before_train_iter(self, trainer): + if self.need_load_rng_state: + if self.rng_state is not None: + random.setstate(self.rng_state['random']) + np.random.set_state(self.rng_state['numpy']) + torch.random.set_rng_state(self.rng_state['cpu']) + if torch.cuda.is_available(): + torch.cuda.random.set_rng_state_all(self.rng_state['cuda']) + self.need_load_rng_state = False + else: + self.logger.warn( + 'Random state cannot be found in checkpoint file, ' + 'this may cause a random data order or model initialization.' + ) + + def after_train_epoch(self, trainer): + if not self.by_epoch: + return + + if self._should_save(trainer): + if is_master(): + self.logger.info( + f'Saving checkpoint at {trainer.epoch + 1} epoch') + self._save_checkpoint(trainer) + + @classmethod + def load_checkpoint(cls, filename, trainer): + from modelscope.trainers.parallel.utils import is_parallel + if is_parallel(trainer.model): + model = trainer.model.module + else: + model = trainer.model + meta = load_checkpoint(filename, model, + getattr(trainer, 'optimizer', None), + getattr(trainer, 'lr_scheduler', None)) + trainer._epoch = meta.get('epoch', trainer._epoch) + trainer._iter = meta.get('iter', trainer._iter) + trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) + + for i, hook in enumerate(trainer.hooks): + # hook: Hook + key = f'{hook.__class__}-{i}' + if key in meta and hasattr(hook, 'load_state_dict'): + hook.load_state_dict(meta.get(key, {})) + else: + trainer.logger.warn( + f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' + ) + + version = meta.get('modelscope') + if version != __version__: + trainer.logger.warn( + f'The modelscope version of loaded checkpoint does not match the runtime version. ' + f'The saved version: {version}, runtime version: {__version__}' + ) + trainer.logger.info( + f'Checkpoint {filename} saving time: {meta.get("time")}') + return meta + + def _save_checkpoint(self, trainer): + if self.by_epoch: + cur_save_name = os.path.join( + self.save_dir, f'{LogKeys.EPOCH}_{trainer.epoch + 1}.pth') + else: + cur_save_name = os.path.join( + self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') + + self.rng_state = { + 'random': random.getstate(), + 'numpy': np.random.get_state(), + 'cpu': torch.random.get_rng_state(), + 'cuda': torch.cuda.get_rng_state_all(), + } + meta = { + 'epoch': trainer.epoch, + 'iter': trainer.iter + 1, + 'inner_iter': trainer.inner_iter + 1, + 'rng_state': self.rng_state, + } + for i, hook in enumerate(trainer.hooks): + if hasattr(hook, 'state_dict'): + meta[f'{hook.__class__}-{i}'] = hook.state_dict() + + save_checkpoint( + trainer.model, + cur_save_name, + trainer.optimizer, + trainer.lr_scheduler, + meta=meta) + if (self.is_last_epoch(trainer) + and self.by_epoch) or (self.is_last_iter(trainer) + and not self.by_epoch): + self._save_pretrained(trainer) + + def _save_pretrained(self, trainer): + output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR) + from modelscope.trainers.parallel.utils import is_parallel + + if is_parallel(trainer.model): + model = trainer.model.module + else: + model = trainer.model + + config = trainer.cfg.to_dict() + # override pipeline by tasks name after finetune done, + # avoid case like fill mask pipeline with a text cls task + config['pipeline'] = {'type': config['task']} + + if hasattr(model, 'save_pretrained'): + model.save_pretrained( + output_dir, + ModelFile.TORCH_MODEL_BIN_FILE, + save_function=save_checkpoint, + config=config, + with_meta=False) + + def after_train_iter(self, trainer): + if self.by_epoch: + return + + if self._should_save(trainer): + if is_master(): + self.logger.info( + f'Saving checkpoint at {trainer.iter + 1} iterations') + self._save_checkpoint(trainer) + + def _should_save(self, trainer): + if self.by_epoch: + check_last = self.is_last_epoch + check_frequency = self.every_n_epochs + else: + check_last = self.is_last_iter + check_frequency = self.every_n_iters + + if check_frequency(trainer, + self.interval) or (self.save_last + and check_last(trainer)): + return True + return False + + +@HOOKS.register_module(module_name=Hooks.BestCkptSaverHook) +class BestCkptSaverHook(CheckpointHook): + """Save best checkpoints hook. + Args: + metric_key (str): Metric key to compare rule for best score. + rule (str): Comparison rule for best score. + Support "max" and "min". If rule is "max", the checkpoint at the maximum `metric_key` + will be saved, If rule is "min", the checkpoint at the minimum `metric_key` will be saved. + by_epoch (bool): Save best checkpoints by epoch or by iteration. + save_optimizer (bool): Whether to save optimizer state dict. Default: True. + save_dir (str): Output directory to save best checkpoint. + restore_best (bool): Whether to restore the best checkpoint after training. + """ + + PRIORITY = Priority.LOW + rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y} + + def __init__(self, + metric_key, + rule='max', + by_epoch=True, + save_optimizer=True, + save_dir=None, + save_file_name=None, + restore_best=False, + interval=0): + assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.' + super().__init__( + interval=interval, + by_epoch=by_epoch, + save_optimizer=save_optimizer, + save_dir=save_dir, + ) + self.metric_key = metric_key + self.rule = rule + self._best_metric = None + self._best_ckpt_file = None + self.save_file_name = save_file_name + self.restore_best = restore_best + + def _should_save(self, trainer): + return self._is_best_metric(trainer.metric_values) + + def _is_best_metric(self, metric_values): + if metric_values is None: + return False + + if self.metric_key not in metric_values: + raise ValueError( + f'Not find metric_key: {self.metric_key} in {metric_values}') + + if self._best_metric is None: + self._best_metric = metric_values[self.metric_key] + return True + else: + compare_fn = self.rule_map[self.rule] + if compare_fn(metric_values[self.metric_key], self._best_metric): + self._best_metric = metric_values[self.metric_key] + return True + return False + + def _save_checkpoint(self, trainer): + cur_save_name = self.save_file_name + if cur_save_name is None: + if self.by_epoch: + cur_save_name = os.path.join( + self.save_dir, + f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth' + ) + else: + cur_save_name = os.path.join( + self.save_dir, + f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth' + ) + + meta = { + 'epoch': trainer.epoch, + 'iter': trainer.iter + 1, + 'inner_iter': trainer.inner_iter + 1, + 'rng_state': self.rng_state, + } + for i, hook in enumerate(trainer.hooks): + meta[f'{hook.__class__}-{i}'] = hook.state_dict() + + if os.path.isfile(cur_save_name): + os.remove(cur_save_name) + save_checkpoint(trainer.model, cur_save_name, trainer.optimizer, + trainer.lr_scheduler, meta) + self._best_ckpt_file = cur_save_name + self._save_pretrained(trainer) + + def state_dict(self): + return { + 'best_metric': self._best_metric, + } + + def load_state_dict(self, state_dict): + if state_dict is not None and len(state_dict) > 0: + self._best_metric = state_dict.get('best_metric') + else: + self.logger.warn( + 'The state_dict is not available, the best metric value will be affected.' + ) + + def after_run(self, trainer): + if self.restore_best: + self.load_checkpoint(self._best_ckpt_file, trainer) diff --git a/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py b/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py new file mode 100644 index 00000000..ce98e6c9 --- /dev/null +++ b/modelscope/trainers/hooks/clip_clamp_logit_scale_hook.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch + +from modelscope.metainfo import Hooks +from modelscope.trainers.multi_modal.clip.clip_trainer import CLIPTrainer +from .builder import HOOKS +from .hook import Hook + + +@HOOKS.register_module(module_name=Hooks.ClipClampLogitScaleHook) +class ClipClampLogitScaleHook(Hook): + """ClipClampLogitScaleHook hook which performs clamp on CLIP logit scale parameter after update""" + + def after_train_iter(self, trainer: CLIPTrainer): + """Called after every training iter to evaluate the results.""" + unwrapped_model = getattr(trainer.model, 'module', trainer.model) + logit_scale = unwrapped_model.clip_model.logit_scale + logit_scale.data = torch.clamp(logit_scale.data, 0, 4.6052) diff --git a/modelscope/trainers/hooks/compression/__init__.py b/modelscope/trainers/hooks/compression/__init__.py new file mode 100644 index 00000000..f755b2ca --- /dev/null +++ b/modelscope/trainers/hooks/compression/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sparsity_hook import SparsityHook + from .utils import SparseLinear, convert_sparse_network + +else: + _import_structure = { + 'sparsity_hook': ['SparsityHook'], + 'utils': ['convert_sparse_network', 'SparseLinear'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/hooks/compression/sparsity_hook.py b/modelscope/trainers/hooks/compression/sparsity_hook.py new file mode 100644 index 00000000..993488d8 --- /dev/null +++ b/modelscope/trainers/hooks/compression/sparsity_hook.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +from modelscope import __version__ +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.priority import Priority +from modelscope.utils.checkpoint import save_checkpoint +from modelscope.utils.torch_utils import is_master + + +@HOOKS.register_module(module_name=Hooks.SparsityHook) +class SparsityHook(Hook): + + PRIORITY = Priority.HIGHEST + + def __init__(self, pruning_method, config={}, save_dir=None): + self.pruning_method = pruning_method + self.save_dir = save_dir + + self.compress_module = config.get('compress_module', []) + self.weight_rank = config.get('weight_rank', 8) + self.weight_beta = config.get('weight_beta', 1) + self.mask_rank = config.get('mask_rank', 8) + self.mask_alpha1 = config.get('mask_alpha1', 1) + self.mask_alpha2 = config.get('mask_alpha2', 1) + + self.step = 0 + self.total_step = 0 + self.frequency = config.get('frequency', 1) + self.initial_warmup = config.get('initial_warmup', 0.1) + self.final_warmup = config.get('final_warmup', 0.3) + self.initial_sparsity = config.get('initial_sparsity', 0.0) + self.final_sparsity = config.get('final_sparsity', 0.0) + + def before_run(self, trainer): + import torch + + from .utils import SparseLinear, convert_sparse_network + + if self.save_dir is None: + self.save_dir = trainer.work_dir + + if len(self.compress_module) == 0: + convert_sparse_network( + trainer.model, + pruning_method=self.pruning_method, + weight_rank=self.weight_rank, + weight_beta=self.weight_beta, + mask_rank=self.mask_rank, + mask_alpha1=self.mask_alpha1, + mask_alpha2=self.mask_alpha2, + logger=trainer.logger, + ) + else: + for cm in self.compress_module: + for name, module in trainer.model.named_modules(): + if name != cm: + continue + convert_sparse_network( + module, + pruning_method=self.pruning_method, + weight_rank=self.weight_rank, + weight_beta=self.weight_beta, + mask_rank=self.mask_rank, + mask_alpha1=self.mask_alpha1, + mask_alpha2=self.mask_alpha2, + logger=trainer.logger, + ) + + for i in range(len(trainer.optimizer.param_groups)): + new_train_params = [] + for param in trainer.optimizer.param_groups[i]['params']: + is_find = False + for name, module in trainer.model.named_modules(): + if isinstance(module, SparseLinear): + if torch.equal(param.half(), + module.weight.data.half()): + is_find = True + break + + if not is_find: + new_train_params.append(param) + + trainer.optimizer.param_groups[i]['params'] = new_train_params + + new_params = [] + for name, module in trainer.model.named_modules(): + if isinstance(module, SparseLinear): + new_params.extend( + [p for p in module.parameters() if p.requires_grad]) + + trainer.optimizer.add_param_group({'params': new_params}) + + self.total_step = trainer.iters_per_epoch * trainer._max_epochs + + def before_train_iter(self, trainer): + from .utils import schedule_sparsity_ratio, update_network_sparsity + + cur_sparsity = schedule_sparsity_ratio( + self.step, + self.total_step, + self.frequency, + self.initial_warmup, + self.final_warmup, + self.initial_sparsity, + self.final_sparsity, + ) + + update_network_sparsity(trainer.model, cur_sparsity) + + if is_master(): + trainer.logger.info( + f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}' + ) + + self.step += 1 + + def after_run(self, trainer): + from .utils import generate_sparse_model + + generate_sparse_model(trainer.model, logger=trainer.logger) + + self._save_checkpoint(trainer) + + def _save_checkpoint(self, trainer): + if is_master(): + trainer.logger.info('Saving checkpoint at final compress') + cur_save_name = os.path.join(self.save_dir, 'compress_model.pth') + save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) diff --git a/modelscope/trainers/hooks/compression/utils.py b/modelscope/trainers/hooks/compression/utils.py new file mode 100644 index 00000000..59418201 --- /dev/null +++ b/modelscope/trainers/hooks/compression/utils.py @@ -0,0 +1,208 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn + +from modelscope.utils.torch_utils import is_master + + +class SparseBinarizer(torch.autograd.Function): + + @staticmethod + def forward(ctx, mask_scores, sparsity): + num_prune = int(mask_scores.numel() * sparsity) + prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune] + mask = mask_scores.clone().fill_(1) + mask.reshape(-1)[prune_indices] = 0.0 + return mask + + @staticmethod + def backward(ctx, gradOutput): + return gradOutput, None + + +class SparseLinear(nn.Module): + """ + Fully Connected layer with on the fly adaptive mask. + """ + + def __init__( + self, + module, + pruning_method='pst', + weight_rank=8, + weight_beta=1.0, + mask_rank=8, + mask_alpha1=1.0, + mask_alpha2=1.0, + ): + super(SparseLinear, self).__init__() + self.module = module + out_features = self.module.weight.shape[0] + in_features = self.module.weight.shape[1] + + self.weight = self.module.weight + self.module.weight = None + self.module._parameters.pop('weight') + + self.pruning_method = pruning_method + + self.cur_sparsity = 0.0 + + if self.pruning_method == 'pst': + self.weight_rank = weight_rank + self.weight_beta = weight_beta + self.mask_rank = mask_rank + self.mask_alpha1 = mask_alpha1 + self.mask_alpha2 = mask_alpha2 + + # create trainable params + self.weight_U = nn.Parameter( + torch.randn(out_features, self.weight_rank).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.weight_V = nn.Parameter( + torch.zeros(self.weight_rank, in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + + self.mask_scores_A = nn.Parameter( + torch.randn(out_features, self.mask_rank).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_B = nn.Parameter( + torch.zeros(self.mask_rank, in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_R = nn.Parameter( + torch.zeros(out_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + self.mask_scores_C = nn.Parameter( + torch.zeros(in_features).to( + device=self.weight.device, dtype=self.weight.dtype)) + + self.weight.requires_grad = False + if self.module.bias is not None: + self.module.bias.requires_grad = False + + def forward(self, *inputs): + if self.pruning_method == 'pst': + weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V + mask_scores = ( + weight.abs() + + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B + + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1) + + self.mask_scores_C.unsqueeze(0))) + + mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity) + masked_weight = mask * weight + + self.module.weight = masked_weight + return self.module(*inputs) + else: + return self.module(*inputs) + + def convert(self): + if self.pruning_method == 'pst': + weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V + mask_scores = ( + weight.abs() + + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B + + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1) + + self.mask_scores_C.unsqueeze(0))) + + mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity) + + masked_weight = mask * weight + self.module.weight = nn.Parameter(masked_weight.data) + + +def _setattr(model, name, module): + name_list = name.split('.') + for name in name_list[:-1]: + model = getattr(model, name) + setattr(model, name_list[-1], module) + + +def convert_sparse_network( + model, + pruning_method, + weight_rank, + weight_beta, + mask_rank, + mask_alpha1, + mask_alpha2, + logger=None, +): + compress_module = [nn.Linear] + try: + from megatron import mpu + compress_module.extend( + [mpu.RowParallelLinear, mpu.ColumnParallelLinear]) + except ImportError: + pass + + for name, module in model.named_modules(): + if type(module) in compress_module: + new_module = SparseLinear( + module, + pruning_method, + weight_rank, + weight_beta, + mask_rank, + mask_alpha1, + mask_alpha2, + ) + + # replace original module by new sparse module + _setattr(model, name, new_module) + + if is_master(): + if logger: + logger.info(f'convert {name} to sparse module.') + else: + print(f'convert {name} to sparse module.') + + +def update_network_sparsity(model, sparsity): + for name, module in model.named_modules(): + if isinstance(module, SparseLinear): + module.cur_sparsity = sparsity + + +def schedule_sparsity_ratio( + step, + total_step, + frequency, + initial_warmup, + final_warmup, + initial_sparsity, + final_sparsity, +): + if step <= initial_warmup * total_step: + sparsity = initial_sparsity + elif step > (total_step - final_warmup * total_step): + sparsity = final_sparsity + else: + spars_warmup_steps = initial_warmup * total_step + spars_schedu_steps = (final_warmup + initial_warmup) * total_step + step = (step - spars_warmup_steps) // frequency * frequency + mul_coeff = 1 - step / (total_step - spars_schedu_steps) + sparsity = final_sparsity + (initial_sparsity - final_sparsity) * ( + mul_coeff**3) + return sparsity + + +def generate_sparse_model(model, logger=None): + # generate sparse weight for saving + for name, module in model.named_modules(): + if isinstance(module, SparseLinear): + module.convert() + + _setattr(model, name, module.module) + + if is_master(): + if logger: + logger.info(f'convert {name} weight to sparse weight, \ + sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.' + ) + else: + print(f'convert {name} weight to sparse, \ + sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.' + ) diff --git a/modelscope/trainers/hooks/evaluation_hook.py b/modelscope/trainers/hooks/evaluation_hook.py new file mode 100644 index 00000000..4479fa23 --- /dev/null +++ b/modelscope/trainers/hooks/evaluation_hook.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Hooks +from .builder import HOOKS +from .hook import Hook + + +@HOOKS.register_module(module_name=Hooks.EvaluationHook) +class EvaluationHook(Hook): + """Evaluation hook. + Args: + interval (int): Evaluation interval. + by_epoch (bool): Evaluate by epoch or by iteration. + start_idx (int | None, optional): The epoch/iterations validation begins. + Default: None, validate every interval epochs/iterations from scratch. + """ + + def __init__(self, interval=1, by_epoch=True, start_idx=None): + assert interval > 0, 'interval must be a positive number' + self.interval = interval + self.start_idx = start_idx + self.by_epoch = by_epoch + + def after_train_iter(self, trainer): + """Called after every training iter to evaluate the results.""" + if not self.by_epoch and self._should_evaluate(trainer): + self.do_evaluate(trainer) + + def after_train_epoch(self, trainer): + """Called after every training epoch to evaluate the results.""" + if self.by_epoch and self._should_evaluate(trainer): + self.do_evaluate(trainer) + + def do_evaluate(self, trainer): + """Evaluate the results.""" + eval_res = trainer.evaluate() + for name, val in eval_res.items(): + trainer.log_buffer.output[name] = val + + trainer.log_buffer.ready = True + + def _should_evaluate(self, trainer): + """Judge whether to perform evaluation. + + Here is the rule to judge whether to perform evaluation: + 1. It will not perform evaluation during the epoch/iteration interval, + which is determined by ``self.interval``. + 2. It will not perform evaluation if the ``start_idx`` is larger than + current epochs/iters. + 3. It will not perform evaluation when current epochs/iters is larger than + the ``start_idx`` but during epoch/iteration interval. + + Returns: + bool: The flag indicating whether to perform evaluation. + """ + if self.by_epoch: + current = trainer.epoch + check_time = self.every_n_epochs + else: + current = trainer.iter + check_time = self.every_n_iters + + if self.start_idx is None: + if not check_time(trainer, self.interval): + return False + elif (current + 1) < self.start_idx: + return False + else: + if (current + 1 - self.start_idx) % self.interval: + return False + return True diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py new file mode 100644 index 00000000..d3805be8 --- /dev/null +++ b/modelscope/trainers/hooks/hook.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.utils.constant import TrainerStages +from modelscope.utils.import_utils import is_method_overridden +from .priority import Priority + + +class Hook: + """ + The Hook base class of any modelscope trainer. You can build your own hook inherited from this class. + """ + + stages = (TrainerStages.before_run, TrainerStages.before_train_epoch, + TrainerStages.before_train_iter, TrainerStages.after_train_iter, + TrainerStages.after_train_epoch, TrainerStages.before_val_epoch, + TrainerStages.before_val_iter, TrainerStages.after_val_iter, + TrainerStages.after_val_epoch, TrainerStages.after_run) + + PRIORITY = Priority.NORMAL + + def before_run(self, trainer): + """ + Will be called before any loop begins. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def after_run(self, trainer): + """ + Will be called after all loops end. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def before_epoch(self, trainer): + """ + Will be called before every epoch begins. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def after_epoch(self, trainer): + """ + Will be called after every epoch ends. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def before_iter(self, trainer): + """ + Will be called before every loop begins. + Args: + trainer: The trainer instance. + + Returns: None + """ + pass + + def after_iter(self, trainer): + """ + Will be called after every loop ends. + Args: + trainer: The trainer instance. + + Returns: None + """ + pass + + def before_train_epoch(self, trainer): + """ + Will be called before every train epoch begins. Default call ``self.before_epoch`` + Args: + trainer: The trainer instance. + + Returns: None + + """ + self.before_epoch(trainer) + + def before_val_epoch(self, trainer): + """ + Will be called before every validation epoch begins. Default call ``self.before_epoch`` + Args: + trainer: The trainer instance. + + Returns: None + + """ + self.before_epoch(trainer) + + def after_train_epoch(self, trainer): + """ + Will be called after every train epoch ends. Default call ``self.after_epoch`` + Args: + trainer: The trainer instance. + + Returns: None + + """ + self.after_epoch(trainer) + + def after_val_epoch(self, trainer): + """ + Will be called after every validation epoch ends. Default call ``self.after_epoch`` + Args: + trainer: The trainer instance. + + Returns: None + + """ + self.after_epoch(trainer) + + def before_train_iter(self, trainer): + """ + Will be called before every train loop begins. Default call ``self.before_iter`` + Args: + trainer: The trainer instance. + + Returns: None + """ + self.before_iter(trainer) + + def before_val_iter(self, trainer): + """ + Will be called before every validation loop begins. Default call ``self.before_iter`` + Args: + trainer: The trainer instance. + + Returns: None + """ + self.before_iter(trainer) + + def after_train_iter(self, trainer): + """ + Will be called after every train loop ends. Default call ``self.after_iter`` + Args: + trainer: The trainer instance. + + Returns: None + """ + self.after_iter(trainer) + + def after_val_iter(self, trainer): + """ + Will be called after every validation loop ends. Default call ``self.after_iter`` + Args: + trainer: The trainer instance. + + Returns: None + """ + self.after_iter(trainer) + + def every_n_epochs(self, trainer, n): + """ + Whether to reach every ``n`` epochs + Returns: bool + """ + return (trainer.epoch + 1) % n == 0 if n > 0 else False + + def every_n_inner_iters(self, runner, n): + """ + Whether to reach every ``n`` iterations at every epoch + Returns: bool + """ + return (runner.inner_iter + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, trainer, n): + """ + Whether to reach every ``n`` iterations + Returns: bool + """ + return (trainer.iter + 1) % n == 0 if n > 0 else False + + def end_of_epoch(self, trainer): + """ + Whether to reach the end of every epoch + Returns: bool + """ + return trainer.inner_iter + 1 == trainer.iters_per_epoch + + def is_last_epoch(self, trainer): + """ + Whether to reach the last epoch + Returns: bool + """ + return trainer.epoch + 1 == trainer.max_epochs + + def is_last_iter(self, trainer): + """ + Whether to reach the last iteration in the entire training process + Returns: bool + """ + return trainer.iter + 1 == trainer.max_iters + + def get_triggered_stages(self): + trigger_stages = set() + for stage in Hook.stages: + if is_method_overridden(stage, Hook, self): + trigger_stages.add(stage) + + return [stage for stage in Hook.stages if stage in trigger_stages] + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass diff --git a/modelscope/trainers/hooks/iter_timer_hook.py b/modelscope/trainers/hooks/iter_timer_hook.py new file mode 100644 index 00000000..6af78235 --- /dev/null +++ b/modelscope/trainers/hooks/iter_timer_hook.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import time + +from modelscope.metainfo import Hooks +from modelscope.utils.constant import LogKeys +from .builder import HOOKS +from .hook import Hook +from .priority import Priority + + +@HOOKS.register_module(module_name=Hooks.IterTimerHook) +class IterTimerHook(Hook): + PRIORITY = Priority.LOW + + def before_epoch(self, trainer): + self.start_time = time.time() + + def before_iter(self, trainer): + trainer.log_buffer.update( + {LogKeys.DATA_LOAD_TIME: time.time() - self.start_time}) + + def after_iter(self, trainer): + trainer.log_buffer.update( + {LogKeys.ITER_TIME: time.time() - self.start_time}) + self.start_time = time.time() diff --git a/modelscope/trainers/hooks/logger/__init__.py b/modelscope/trainers/hooks/logger/__init__.py new file mode 100644 index 00000000..583cd32b --- /dev/null +++ b/modelscope/trainers/hooks/logger/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.trainers.utils.log_buffer import LogBuffer +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .base import LoggerHook + from .tensorboard_hook import TensorboardHook + from .text_logger_hook import TextLoggerHook + +else: + _import_structure = { + 'base': ['LoggerHook'], + 'tensorboard_hook': ['TensorboardHook'], + 'text_logger_hook': ['TextLoggerHook'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/hooks/logger/base.py b/modelscope/trainers/hooks/logger/base.py new file mode 100644 index 00000000..684c4a8c --- /dev/null +++ b/modelscope/trainers/hooks/logger/base.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +import numbers +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch + +from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.priority import Priority +from modelscope.utils.constant import ModeKeys + + +class LoggerHook(Hook): + """Base class for logger hooks. + + Args: + interval (int): Logging interval (every k iterations). It is interval of iterations even by_epoch is true. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging. + by_epoch (bool): Whether EpochBasedtrainer is used. + """ + + __metaclass__ = ABCMeta + PRIORITY = Priority.VERY_LOW + + def __init__(self, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + self.interval = interval + self.ignore_last = ignore_last + self.reset_flag = reset_flag + self.by_epoch = by_epoch + + @abstractmethod + def log(self, trainer): + pass + + @staticmethod + def is_scalar(val, include_np=True, include_torch=True): + """Tell the input variable is a scalar or not. + + Args: + val: Input variable. + include_np (bool): Whether to treat 0-d np.ndarray as a scalar. + include_torch (bool): Whether to treat 0-d torch.Tensor as a scalar. + + Returns: + bool: True or False. + """ + if isinstance(val, numbers.Number): + return True + elif include_np and isinstance(val, np.ndarray) and val.ndim == 0: + return True + elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1: + return True + else: + return False + + def fetch_tensor(self, trainer, n=0): + """Fetch latest n values or all values, process tensor type, convert to numpy for dump logs.""" + assert n >= 0 + for key in trainer.log_buffer.val_history: + values = trainer.log_buffer.val_history[key][-n:] + + for i, v in enumerate(values): + if isinstance(v, torch.Tensor): + values[i] = v.clone().detach().cpu().numpy() + + trainer.log_buffer.val_history[key][-n:] = values + + def get_epoch(self, trainer): + if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]: + epoch = trainer.epoch + 1 + else: + raise ValueError( + f'trainer mode should be {ModeKeys.TRAIN} or {ModeKeys.EVAL}, ' + f'but got {trainer.mode}') + return epoch + + def get_iter(self, trainer, inner_iter=False): + """Get the current training iteration step.""" + if self.by_epoch and inner_iter: + current_iter = trainer.inner_iter + 1 + else: + current_iter = trainer.iter + 1 + return current_iter + + def before_run(self, trainer): + for hook in trainer.hooks[::-1]: + if isinstance(hook, LoggerHook): + hook.reset_flag = True + break + + def before_epoch(self, trainer): + trainer.log_buffer.clear() # clear logs of last epoch + + def after_train_iter(self, trainer): + if self.by_epoch and self.every_n_inner_iters(trainer, self.interval): + self.fetch_tensor(trainer, self.interval) + trainer.log_buffer.average(self.interval) + elif not self.by_epoch and self.every_n_iters(trainer, self.interval): + self.fetch_tensor(trainer, self.interval) + trainer.log_buffer.average(self.interval) + elif self.end_of_epoch(trainer) and not self.ignore_last: + # not precise but more stable + self.fetch_tensor(trainer, self.interval) + trainer.log_buffer.average(self.interval) + + if trainer.log_buffer.ready: + self.log(trainer) + if self.reset_flag: + trainer.log_buffer.clear_output() + + def after_train_epoch(self, trainer): + if trainer.log_buffer.ready: + self.log(trainer) + if self.reset_flag: + trainer.log_buffer.clear_output() + + def after_val_epoch(self, trainer): + self.fetch_tensor(trainer) + trainer.log_buffer.average() + self.log(trainer) + if self.reset_flag: + trainer.log_buffer.clear_output() diff --git a/modelscope/trainers/hooks/logger/tensorboard_hook.py b/modelscope/trainers/hooks/logger/tensorboard_hook.py new file mode 100644 index 00000000..a12f7ae7 --- /dev/null +++ b/modelscope/trainers/hooks/logger/tensorboard_hook.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.utils.constant import LogKeys +from modelscope.utils.torch_utils import master_only +from .base import LoggerHook + + +@HOOKS.register_module(module_name=Hooks.TensorboardHook) +class TensorboardHook(LoggerHook): + """TensorBoard hook for visualization. + Args: + out_dir: output directory to save tensorboard files + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging. + by_epoch (bool): Whether EpochBasedtrainer is used. + skip_keys (list): list of keys which will not add to tensorboard + """ + + def __init__(self, + out_dir=None, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True, + skip_keys=[LogKeys.ITER_TIME, LogKeys.DATA_LOAD_TIME]): + super(TensorboardHook, self).__init__( + interval=interval, + ignore_last=ignore_last, + reset_flag=reset_flag, + by_epoch=by_epoch) + self.out_dir = out_dir + self.skip_keys = skip_keys + + @master_only + def before_run(self, trainer): + super(TensorboardHook, self).before_run(trainer) + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError as e: + raise ImportError( + e.msg + ' ' + 'Please pip install tensorboard by ``pip install future tensorboard`` ' + 'or upgrade version by ``pip install future tensorboard --upgrade``.' + ) + + if self.out_dir is None: + self.out_dir = os.path.join(trainer.work_dir, 'tensorboard_output') + self.writer = SummaryWriter(self.out_dir) + + @master_only + def log(self, trainer): + for key, val in trainer.log_buffer.output.items(): + if key in self.skip_keys: + continue + if isinstance(val, str): + self.writer.add_text(key, val, self.get_iter(trainer)) + elif self.is_scalar(val): + self.writer.add_scalar(key, val, self.get_iter(trainer)) + else: + pass + + @master_only + def after_run(self, trainer): + self.writer.close() diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py new file mode 100644 index 00000000..95644783 --- /dev/null +++ b/modelscope/trainers/hooks/logger/text_logger_hook.py @@ -0,0 +1,171 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import datetime +import os +import os.path as osp +from collections import OrderedDict + +import json +import torch +from torch import distributed as dist + +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.logger.base import LoggerHook +from modelscope.utils.constant import LogKeys, ModeKeys +from modelscope.utils.json_utils import EnhancedEncoder +from modelscope.utils.torch_utils import get_dist_info, is_master + + +@HOOKS.register_module(module_name=Hooks.TextLoggerHook) +class TextLoggerHook(LoggerHook): + """Logger hook in text, Output log to both console and local json file. + + Args: + by_epoch (bool, optional): Whether EpochBasedtrainer is used. + Default: True. + interval (int, optional): Logging interval (every k iterations). + It is interval of iterations even by_epoch is true. Default: 10. + ignore_last (bool, optional): Ignore the log of last iterations in each + epoch if less than :attr:`interval`. Default: True. + reset_flag (bool, optional): Whether to clear the output buffer after + logging. Default: False. + out_dir (str): The directory to save log. If is None, use `trainer.work_dir` + """ + + def __init__(self, + by_epoch=True, + interval=10, + ignore_last=True, + reset_flag=False, + out_dir=None): + super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, + by_epoch) + self.by_epoch = by_epoch + self.time_sec_tot = 0 + self.out_dir = out_dir + self._logged_keys = [] # store the key has been logged + + def before_run(self, trainer): + super(TextLoggerHook, self).before_run(trainer) + + if self.out_dir is None: + self.out_dir = trainer.work_dir + + if not osp.exists(self.out_dir) and is_master(): + os.makedirs(self.out_dir) + + trainer.logger.info('Text logs will be saved to {}'.format( + self.out_dir)) + + self.start_iter = trainer.iter + self.json_log_path = osp.join(self.out_dir, + '{}.log.json'.format(trainer.timestamp)) + if hasattr(trainer, 'meta') and trainer.meta is not None: + self._dump_log(trainer.meta) + + def _get_max_memory(self, trainer): + device = getattr(trainer.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([mem / (1024 * 1024)], + dtype=torch.int, + device=device) + _, world_size = get_dist_info() + if world_size > 1: + dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) + return mem_mb.item() + + def _log_info(self, log_dict, trainer): + lr_key = LogKeys.LR + epoch_key = LogKeys.EPOCH + iter_key = LogKeys.ITER + mode_key = LogKeys.MODE + iter_time_key = LogKeys.ITER_TIME + data_load_time_key = LogKeys.DATA_LOAD_TIME + eta_key = LogKeys.ETA + + if log_dict[mode_key] == ModeKeys.TRAIN: + if isinstance(log_dict[lr_key], dict): + lr_str = [] + for k, val in log_dict[lr_key].items(): + lr_str.append(f'{lr_key}_{k}: {val:.3e}') + lr_str = ' '.join(lr_str) + else: + lr_str = f'{lr_key}: {log_dict[lr_key]:.3e}' + + if self.by_epoch: + log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{trainer.iters_per_epoch}]\t' + else: + log_str = f'{iter_key} [{log_dict[iter_key]}/{trainer.max_iters}]\t' + log_str += f'{lr_str}, ' + self._logged_keys.extend([lr_key, mode_key, iter_key, epoch_key]) + + if iter_time_key in log_dict.keys(): + self.time_sec_tot += (log_dict[iter_time_key] * self.interval) + time_sec_avg = self.time_sec_tot / ( + trainer.iter - self.start_iter + 1) + eta_sec = time_sec_avg * (trainer.max_iters - trainer.iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + log_str += f'{eta_key}: {eta_str}, ' + log_str += f'{iter_time_key}: {log_dict[iter_time_key]:.3f}, ' + log_str += f'{data_load_time_key}: {log_dict[data_load_time_key]:.3f}, ' + self._logged_keys.extend([ + iter_time_key, + data_load_time_key, + ]) + else: + # val/test time + # here 1000 is the length of the val dataloader + # by epoch: epoch[val] [4][1000] + # by iter: iter[val] [1000] + if self.by_epoch: + log_str = f'{epoch_key}({log_dict[mode_key]}) [{log_dict[epoch_key]}][{log_dict[iter_key]}]\t' + else: + log_str = f'{iter_key}({log_dict[mode_key]}) [{log_dict[iter_key]}]\t' + self._logged_keys.extend([mode_key, iter_key, epoch_key]) + + log_items = [] + for name, val in log_dict.items(): + if name in self._logged_keys: + continue + if isinstance(val, float): + val = f'{val:.4f}' + log_items.append(f'{name}: {val}') + log_str += ', '.join(log_items) + + if is_master(): + trainer.logger.info(log_str) + + def _dump_log(self, log_dict): + # dump log in json format + json_log = OrderedDict() + for k, v in log_dict.items(): + json_log[k] = self._round_float(v) + + if is_master(): + with open(self.json_log_path, 'a+') as f: + json.dump(json_log, f, cls=EnhancedEncoder) + f.write('\n') + + def _round_float(self, items, ndigits=5): + if isinstance(items, list): + return [self._round_float(item) for item in items] + elif isinstance(items, float): + return round(items, ndigits) + else: + return items + + def log(self, trainer): + cur_iter = self.get_iter(trainer, inner_iter=True) + + log_dict = OrderedDict( + mode=trainer.mode, epoch=self.get_epoch(trainer), iter=cur_iter) + + # statistic memory + if torch.cuda.is_available(): + log_dict[LogKeys.MEMORY] = self._get_max_memory(trainer) + + log_dict = dict(log_dict, **trainer.log_buffer.output) + + self._log_info(log_dict, trainer) + self._dump_log(log_dict) + return log_dict diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py new file mode 100644 index 00000000..ed018fef --- /dev/null +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -0,0 +1,136 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Hooks +from modelscope.trainers.lrscheduler.builder import build_lr_scheduler +from modelscope.utils.constant import LogKeys +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import is_master +from .builder import HOOKS +from .hook import Hook +from .priority import Priority + + +@HOOKS.register_module(module_name=Hooks.LrSchedulerHook) +class LrSchedulerHook(Hook): + """Lr scheduler. + + Args: + by_epoch (bool): Whether lr changes by epoch + warmup (dict): warm up config + """ + PRIORITY = Priority.VERY_HIGH + + def __init__(self, by_epoch=True, warmup=None) -> None: + super().__init__() + self.by_epoch = by_epoch + self.warmup = warmup + self.warmup_lr_scheduler = None + + def before_run(self, trainer): + if self.warmup is not None: + assert isinstance(self.warmup, dict) and 'type' in self.warmup + self.warmup_lr_scheduler = build_lr_scheduler( + cfg=self.warmup, + default_args={'base_scheduler': trainer.lr_scheduler}) + + def get_current_lr(self, trainer): + import torch + + if isinstance(trainer.optimizer, torch.optim.Optimizer): + lr = [group['lr'] for group in trainer.optimizer.param_groups] + elif isinstance(trainer.optimizer, dict): + lr = dict() + for name, optim in trainer.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + + def before_train_iter(self, trainer): + if not self.by_epoch and trainer.iter >= getattr( + trainer, 'cumulative_iters', 1): + if self.warmup_lr_scheduler is not None: + self.warmup_lr_scheduler.step() + else: + trainer.lr_scheduler.step() + trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) + + def before_train_epoch(self, trainer): + trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) + + def after_train_epoch(self, trainer): + if self.by_epoch: + if self.warmup_lr_scheduler is not None: + self.warmup_lr_scheduler.step() + else: + trainer.lr_scheduler.step() + + def _get_log_lr(self, trainer): + cur_lr = self.get_current_lr(trainer) + # only record lr of the first param group + if isinstance(cur_lr, list): + lr = cur_lr[0] + else: + assert isinstance(cur_lr, dict) + lr = {} + for k, lr_ in cur_lr.items(): + assert isinstance(lr_, list) + lr.update({k: lr_[0]}) + + return lr + + +@HOOKS.register_module(module_name=Hooks.PlateauLrSchedulerHook) +class PlateauLrSchedulerHook(LrSchedulerHook): + """Lr scheduler hook for `ReduceLROnPlateau`. + + Args: + metric_key (str): Metric key returned from `trainer.metric_values`, + get the value of metric key and pass it to `ReduceLROnPlateau.step`. + by_epoch (bool): Whether lr changes by epoch + warmup (dict): warm up config + """ + PRIORITY = Priority.LOW # should be after EvaluationHook + + def __init__(self, metric_key, by_epoch=True, warmup=None) -> None: + super().__init__(by_epoch=by_epoch, warmup=warmup) + self.metric_key = metric_key + + def before_run(self, trainer): + super().before_run(trainer) + if not hasattr(trainer, 'logger'): + self.logger = get_logger(__name__) + else: + self.logger = trainer.logger + + def after_train_epoch(self, trainer): + # adapt to evaluation intervel is greater than 1 + if trainer.metric_values is None: + if is_master(): + self.logger.warning( + f'Current epoch {trainer.epoch} has no evaluation metric values, skip lr_scheduler.step() !' + ) + return + + metrics = trainer.metric_values[self.metric_key] + + if self.by_epoch: + if self.warmup_lr_scheduler is not None: + self.warmup_lr_scheduler.step(metrics=metrics) + else: + trainer.lr_scheduler.step(metrics=metrics) + + +@HOOKS.register_module(module_name=Hooks.NoneLrSchedulerHook) +class NoneLrSchedulerHook(LrSchedulerHook): + + PRIORITY = Priority.LOW # should be after EvaluationHook + + def __init__(self, by_epoch=True, warmup=None) -> None: + super().__init__(by_epoch=by_epoch, warmup=warmup) + + def before_run(self, trainer): + return + + def after_train_epoch(self, trainer): + return diff --git a/modelscope/trainers/hooks/optimizer/__init__.py b/modelscope/trainers/hooks/optimizer/__init__.py new file mode 100644 index 00000000..d7c8c862 --- /dev/null +++ b/modelscope/trainers/hooks/optimizer/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .apex_optimizer_hook import ApexAMPOptimizerHook + from .base import OptimizerHook, NoneOptimizerHook + from .torch_optimizer_hook import TorchAMPOptimizerHook + +else: + _import_structure = { + 'apex_optimizer_hook': ['ApexAMPOptimizerHook'], + 'base': ['OptimizerHook', 'NoneOptimizerHook'], + 'torch_optimizer_hook': ['TorchAMPOptimizerHook'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/hooks/optimizer/apex_optimizer_hook.py b/modelscope/trainers/hooks/optimizer/apex_optimizer_hook.py new file mode 100644 index 00000000..f87ae849 --- /dev/null +++ b/modelscope/trainers/hooks/optimizer/apex_optimizer_hook.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import logging + +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from .base import OptimizerHook + + +@HOOKS.register_module(module_name=Hooks.ApexAMPOptimizerHook) +class ApexAMPOptimizerHook(OptimizerHook): + """Fp16 optimizer, if torch version is less than 1.6.0, + you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default + Args: + cumulative_iters (int): interval of gradients accumulation. Default: 1 + grad_clip (dict): Default None. Containing keys: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` + loss_keys (str | list): keys list of loss + opt_level (str): "O0" and "O3" are not true mixed precision, + but they are useful for establishing accuracy and speed baselines, respectively. + "O1" and "O2" are different implementations of mixed precision. + Try both, and see what gives the best speedup and accuracy for your model. + """ + + def __init__(self, + cumulative_iters=1, + grad_clip=None, + loss_keys='loss', + opt_level='O1'): + + super(ApexAMPOptimizerHook, self).__init__( + grad_clip=grad_clip, loss_keys=loss_keys) + self.cumulative_iters = cumulative_iters + self.opt_level = opt_level + + try: + from apex import amp + except ImportError: + raise ValueError( + 'apex not installed, please install apex from https://www.github.com/nvidia/apex.' + ) + + def before_run(self, trainer): + from apex import amp + + logging.info('open fp16') + # TODO: fix it should initialze amp with model not wrapper by DDP or DP + if hasattr(trainer.model, 'module'): + trainer.model, trainer.optimizer = amp.initialize( + trainer.model.module, + trainer.optimizer, + opt_level=self.opt_level) + else: + trainer.model, trainer.optimizer = amp.initialize( + trainer.model, trainer.optimizer, opt_level=self.opt_level) + + trainer.optimizer.zero_grad() + + def after_train_iter(self, trainer): + for k in self.loss_keys: + trainer.train_outputs[k] /= self.cumulative_iters + + from apex import amp + for k in self.loss_keys: + with amp.scale_loss(trainer.train_outputs[k], + trainer.optimizer) as scaled_loss: + scaled_loss.backward() + + if self.every_n_iters(trainer, self.cumulative_iters): + if self.grad_clip is not None: + self.clip_grads(trainer.model.parameters(), **self.grad_clip) + + trainer.optimizer.step() + trainer.optimizer.zero_grad() diff --git a/modelscope/trainers/hooks/optimizer/base.py b/modelscope/trainers/hooks/optimizer/base.py new file mode 100644 index 00000000..0f38c67a --- /dev/null +++ b/modelscope/trainers/hooks/optimizer/base.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import logging + +from torch.nn.utils import clip_grad + +from modelscope.metainfo import Hooks +from modelscope.outputs import OutputKeys +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.priority import Priority + + +@HOOKS.register_module(module_name=Hooks.OptimizerHook) +class OptimizerHook(Hook): + """Optimizer hook + + Args: + cumulative_iters (int): interval of gradients accumulation. Default: 1 + grad_clip (dict): Default None. Containing keys: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` + loss_keys (str | list): keys list of loss + """ + + PRIORITY = Priority.ABOVE_NORMAL + + def __init__(self, + cumulative_iters=1, + grad_clip=None, + loss_keys=OutputKeys.LOSS) -> None: + if isinstance(loss_keys, str): + loss_keys = [loss_keys] + assert isinstance(loss_keys, (tuple, list)) + self.loss_keys = loss_keys + self.cumulative_iters = cumulative_iters + self.grad_clip = grad_clip + + def clip_grads(self, params, **clip_args): + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return clip_grad.clip_grad_norm_(params, **clip_args) + + def before_run(self, trainer): + trainer.optimizer.zero_grad() + trainer.cumulative_iters = self.cumulative_iters + + def after_train_iter(self, trainer): + for k in self.loss_keys: + trainer.train_outputs[k] /= self.cumulative_iters + trainer.train_outputs[k].backward() + + if self.every_n_iters(trainer, self.cumulative_iters): + if self.grad_clip is not None: + self.clip_grads(trainer.model.parameters(), **self.grad_clip) + + trainer.optimizer.step() + trainer.optimizer.zero_grad() + + +@HOOKS.register_module(module_name=Hooks.NoneOptimizerHook) +class NoneOptimizerHook(OptimizerHook): + + def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'): + + super(NoneOptimizerHook, self).__init__( + grad_clip=grad_clip, loss_keys=loss_keys) + self.cumulative_iters = cumulative_iters + + def before_run(self, trainer): + return + + def after_train_iter(self, trainer): + return diff --git a/modelscope/trainers/hooks/optimizer/torch_optimizer_hook.py b/modelscope/trainers/hooks/optimizer/torch_optimizer_hook.py new file mode 100644 index 00000000..30ea88a2 --- /dev/null +++ b/modelscope/trainers/hooks/optimizer/torch_optimizer_hook.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import logging + +from modelscope.metainfo import Hooks +from modelscope.trainers.hooks.builder import HOOKS +from .base import OptimizerHook + + +@HOOKS.register_module(module_name=Hooks.TorchAMPOptimizerHook) +class TorchAMPOptimizerHook(OptimizerHook): + """Fp16 optimizer, if torch version is less than 1.6.0, + you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default + Args: + cumulative_iters (int): interval of gradients accumulation. Default: 1 + grad_clip (dict): Default None. Containing keys: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` + loss_keys (str | list): keys list of loss + loss_scale (float | dict): grade scale config. If loss_scale is a float, + static loss scaling will be used with the specified scale. + It can also be a dict containing arguments of GradScalar. For Pytorch >= 1.6, + we use official torch.cuda.amp.GradScaler. + please refer to: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler for the parameters. + """ + + def __init__(self, + cumulative_iters=1, + grad_clip=None, + loss_keys='loss', + loss_scale={}): + + super(TorchAMPOptimizerHook, self).__init__( + grad_clip=grad_clip, loss_keys=loss_keys) + self.cumulative_iters = cumulative_iters + self._scale_update_param = None + + from torch.cuda import amp + + if isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.scaler = amp.GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.scaler = amp.GradScaler(**loss_scale) + else: + raise ValueError( + '`loss_scale` type must be in [float, dict], but got {loss_scale}' + ) + + def before_run(self, trainer): + logging.info('open fp16') + trainer.optimizer.zero_grad() + + if hasattr(trainer.model, 'module'): + self._ori_model_forward = trainer.model.module.forward + self._model = trainer.model.module + else: + self._ori_model_forward = trainer.model.forward + self._model = trainer.model + + self.ori_model_forward = trainer.model.forward + + def before_train_iter(self, trainer): + from torch.cuda import amp + setattr(self._model, 'forward', amp.autocast()(self._model.forward)) + + def after_train_iter(self, trainer): + for k in self.loss_keys: + trainer.train_outputs[k] /= self.cumulative_iters + + for k in self.loss_keys: + self.scaler.scale(trainer.train_outputs[k]).backward() + + if self.every_n_iters(trainer, self.cumulative_iters): + self.scaler.unscale_(trainer.optimizer) + if self.grad_clip is not None: + self.clip_grads(trainer.model.parameters(), **self.grad_clip) + + self.scaler.step(trainer.optimizer) + self.scaler.update(self._scale_update_param) + trainer.optimizer.zero_grad() + + setattr(self._model, 'forward', self._ori_model_forward) diff --git a/modelscope/trainers/hooks/priority.py b/modelscope/trainers/hooks/priority.py new file mode 100644 index 00000000..db749652 --- /dev/null +++ b/modelscope/trainers/hooks/priority.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +from enum import Enum +from typing import Union + + +class Priority(Enum): + """Hook priority levels. + + +--------------+------------+ + | Level | Value | + +==============+============+ + | HIGHEST | 0 | + +--------------+------------+ + | VERY_HIGH | 10 | + +--------------+------------+ + | HIGH | 30 | + +--------------+------------+ + | ABOVE_NORMAL | 40 | + +--------------+------------+ + | NORMAL | 50 | + +--------------+------------+ + | BELOW_NORMAL | 60 | + +--------------+------------+ + | LOW | 70 | + +--------------+------------+ + | VERY_LOW | 90 | + +--------------+------------+ + | LOWEST | 100 | + +--------------+------------+ + """ + + HIGHEST = 0 + VERY_HIGH = 10 + HIGH = 30 + ABOVE_NORMAL = 40 + NORMAL = 50 + BELOW_NORMAL = 60 + LOW = 70 + VERY_LOW = 90 + LOWEST = 100 + + +def get_priority(priority: Union[int, str, Priority]) -> int: + """Get priority value. + + Args: + priority (int or str or :obj:`Priority`): Priority. + + Returns: + int: The priority value. + """ + if isinstance(priority, int): + if priority < 0 or priority > 100: + raise ValueError('priority must be between 0 and 100') + return priority + elif isinstance(priority, Priority): + return priority.value + elif isinstance(priority, str): + return Priority[priority.upper()].value + else: + raise TypeError('priority must be an integer or Priority enum value') diff --git a/modelscope/trainers/lrscheduler/__init__.py b/modelscope/trainers/lrscheduler/__init__.py new file mode 100644 index 00000000..54576353 --- /dev/null +++ b/modelscope/trainers/lrscheduler/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .builder import LR_SCHEDULER, build_lr_scheduler + from .warmup import BaseWarmup, ConstantWarmup, ExponentialWarmup, LinearWarmup + +else: + _import_structure = { + 'builder': ['LR_SCHEDULER', 'build_lr_scheduler'], + 'warmup': + ['BaseWarmup', 'ConstantWarmup', 'ExponentialWarmup', 'LinearWarmup'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/lrscheduler/builder.py b/modelscope/trainers/lrscheduler/builder.py new file mode 100644 index 00000000..3a892001 --- /dev/null +++ b/modelscope/trainers/lrscheduler/builder.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect + +from modelscope.utils.config import ConfigDict +from modelscope.utils.registry import Registry, build_from_cfg, default_group + +LR_SCHEDULER = Registry('lr_scheduler') + + +def build_lr_scheduler(cfg: ConfigDict, default_args: dict = None): + """ build lr scheduler from given lr scheduler config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for lr scheduler object. + default_args (dict, optional): Default initialization arguments. + """ + if cfg['type'].lower().endswith('warmup'): + # build warmup lr scheduler + if not hasattr(cfg, 'base_scheduler'): + if default_args is None or ('base_scheduler' not in default_args): + raise ValueError( + 'Must provide ``base_scheduler`` which is an instance of ``torch.optim.lr_scheduler._LRScheduler`` ' + 'for build warmup lr scheduler.') + else: + # build lr scheduler without warmup + if not hasattr(cfg, 'optimizer'): + if default_args is None or ('optimizer' not in default_args): + raise ValueError( + 'Must provide ``optimizer`` which is an instance of ``torch.optim.Optimizer`` ' + 'for build lr scheduler') + + return build_from_cfg( + cfg, LR_SCHEDULER, group_key=default_group, default_args=default_args) + + +def register_torch_lr_scheduler(): + from torch.optim import lr_scheduler + from torch.optim.lr_scheduler import _LRScheduler + + members = inspect.getmembers(lr_scheduler) + + for name, obj in members: + if (inspect.isclass(obj) and issubclass( + obj, _LRScheduler)) or name in ['ReduceLROnPlateau']: + LR_SCHEDULER.register_module(module_name=name, module_cls=obj) + + +register_torch_lr_scheduler() diff --git a/modelscope/trainers/lrscheduler/warmup/__init__.py b/modelscope/trainers/lrscheduler/warmup/__init__.py new file mode 100644 index 00000000..5263f2ff --- /dev/null +++ b/modelscope/trainers/lrscheduler/warmup/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .base import BaseWarmup + from .warmup import ConstantWarmup, ExponentialWarmup, LinearWarmup + +else: + _import_structure = { + 'base': ['BaseWarmup'], + 'warmup': ['ConstantWarmup', 'ExponentialWarmup', 'LinearWarmup'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/lrscheduler/warmup/base.py b/modelscope/trainers/lrscheduler/warmup/base.py new file mode 100644 index 00000000..4b066281 --- /dev/null +++ b/modelscope/trainers/lrscheduler/warmup/base.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from torch.optim.lr_scheduler import _LRScheduler + + +class BaseWarmup(_LRScheduler): + """Base warmup scheduler + + Args: + base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type + warmup_iters (int | list): Warmup iterations + last_epoch (int): The index of last epoch. + """ + + def __init__(self, + base_scheduler, + warmup_iters, + last_epoch=-1, + verbose=False): + self.base_scheduler = base_scheduler + self.warmup_iters = warmup_iters + optimizer = self.base_scheduler.optimizer + self._is_init_step = True + + super(BaseWarmup, self).__init__( + optimizer, last_epoch=last_epoch, verbose=verbose) + + def get_lr(self): + return self.base_scheduler.get_lr() + + def state_dict(self): + return self.base_scheduler.state_dict() + + def load_state_dict(self, state_dict): + return self.base_scheduler.load_state_dict(state_dict) + + def scale(self): + """Scale the learning rates. + """ + scale_value = self.get_warmup_scale(self.base_scheduler._step_count + - 1) + if isinstance(scale_value, (int, float)): + scale_value = [ + scale_value for _ in range(len(self.optimizer.param_groups)) + ] + else: + assert isinstance( + scale_value, (list, tuple)), 'Only support list or tuple type!' + assert len(scale_value) == len( + self.optimizer.param_groups), ('Size mismatch {} != {}'.format( + len(scale_value), len(self.optimizer.param_groups))) + + for i, group in enumerate(self.optimizer.param_groups): + group['lr'] *= scale_value[i] + + def step(self, *args, **kwargs): + """ + When ``self.base_scheduler._step_count`` is less than ``self.warmup_iters``, multiply lr by scale + """ + if self.base_scheduler._step_count > self.warmup_iters: + return self.base_scheduler.step(*args, **kwargs) + + for group, lr in zip(self.optimizer.param_groups, self.base_lrs): + group['lr'] = lr + + # `base_scheduler` has done step() at init when build + if self._is_init_step: + self._is_init_step = False + else: + self.base_scheduler.step(*args, **kwargs) + + self.scale() + + @classmethod + def get_warmup_scale(self, cur_iter): + pass diff --git a/modelscope/trainers/lrscheduler/warmup/warmup.py b/modelscope/trainers/lrscheduler/warmup/warmup.py new file mode 100644 index 00000000..777796ef --- /dev/null +++ b/modelscope/trainers/lrscheduler/warmup/warmup.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import LR_Schedulers +from modelscope.trainers.lrscheduler.builder import LR_SCHEDULER +from .base import BaseWarmup + + +@LR_SCHEDULER.register_module(module_name=LR_Schedulers.ConstantWarmup) +class ConstantWarmup(BaseWarmup): + """Linear warmup scheduler. + + Args: + base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type + warmup_ratio (float): Lr used at warmup stage equals to warmup_ratio * initial_lr + warmup_iters (int | list): Warmup iterations + last_epoch (int): The index of last epoch. + """ + + def __init__(self, + base_scheduler, + warmup_iters, + warmup_ratio=0.1, + last_epoch=-1): + self.warmup_ratio = warmup_ratio + super(ConstantWarmup, self).__init__( + base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch) + + def get_warmup_scale(self, cur_iter): + if cur_iter >= self.warmup_iters: + return 1.0 + return self.warmup_ratio + + +@LR_SCHEDULER.register_module(module_name=LR_Schedulers.LinearWarmup) +class LinearWarmup(BaseWarmup): + """Linear warmup scheduler. + + Args: + base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type + warmup_iters (int | list): Warmup iterations + warmup_ratio (float): Lr used at the beginning of warmup equals to warmup_ratio * initial_lr + last_epoch (int): The index of last epoch. + """ + + def __init__(self, + base_scheduler, + warmup_iters, + warmup_ratio=0.1, + last_epoch=-1): + self.warmup_ratio = warmup_ratio + super(LinearWarmup, self).__init__( + base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch) + + def get_warmup_scale(self, cur_iter): + k = (1 - cur_iter / self.warmup_iters) * (1 - self.warmup_ratio) + return 1 - k + + +@LR_SCHEDULER.register_module(module_name=LR_Schedulers.ExponentialWarmup) +class ExponentialWarmup(BaseWarmup): + """Exponential warmup scheduler. + + Args: + base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type + warmup_iters (int | list): Warmup iterations + warmup_ratio (float): Lr used at the beginning of warmup equals to warmup_ratio * initial_lr + last_epoch (int): The index of last epoch. + """ + + def __init__(self, + base_scheduler, + warmup_iters, + warmup_ratio=0.1, + last_epoch=-1): + self.warmup_ratio = warmup_ratio + super(ExponentialWarmup, self).__init__( + base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch) + + def get_warmup_scale(self, cur_iter): + k = self.warmup_ratio**(1 - cur_iter / self.warmup_iters) + return k diff --git a/modelscope/trainers/multi_modal/__init__.py b/modelscope/trainers/multi_modal/__init__.py new file mode 100644 index 00000000..6840b573 --- /dev/null +++ b/modelscope/trainers/multi_modal/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .clip import CLIPTrainer + from .team import TEAMImgClsTrainer + from .ofa import OFATrainer + from .mplug import MPlugTrainer + +else: + _import_structure = { + 'clip': ['CLIPTrainer'], + 'team': ['TEAMImgClsTrainer'], + 'ofa': ['OFATrainer'], + 'mplug': ['MPlugTrainer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/multi_modal/clip/__init__.py b/modelscope/trainers/multi_modal/clip/__init__.py new file mode 100644 index 00000000..61a6664b --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .clip_trainer import CLIPTrainer diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer.py b/modelscope/trainers/multi_modal/clip/clip_trainer.py new file mode 100644 index 00000000..40c524ac --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/clip_trainer.py @@ -0,0 +1,206 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.models.multi_modal.clip.model import convert_models_to_fp32 +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.multi_modal import CLIPPreprocessor +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, + ModeKeys) +from .clip_trainer_utils import get_loss, get_optimizer_params, get_schedule + + +def exclude(n): + return 'bn' in n or 'ln' in n or 'bias' in n or 'logit_scale' in n + + +def include(n): + return not exclude(n) + + +@TRAINERS.register_module(module_name=Trainers.clip_multi_modal_embedding) +class CLIPTrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + model = Model.from_pretrained(model, revision=model_revision) + # for training & eval, we convert the model from FP16 back to FP32 + # to compatible with modelscope amp training + convert_models_to_fp32(model) + cfg = Config.from_file(cfg_file) + if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: + work_dir = cfg.train.work_dir + else: + work_dir = kwargs['work_dir'] + + # fetch the model name of CLIP model (base, large or large-336) + model_name = cfg.pretrained_model.model_name + + # world size + world_size = int(os.environ.get('WORLD_SIZE', 1)) + + # train step, optimizer and lr_scheduler + epoch_steps = math.ceil( + len(train_dataset) / # noqa + (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa + cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs + + if optimizers[0] is None: + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [ + p for n, p in named_parameters + if exclude(n) and p.requires_grad + ] + rest_params = [ + p for n, p in named_parameters + if include(n) and p.requires_grad + ] + optimizer_hparams = get_optimizer_params( + model_name, cfg) # lr, wd, beta1, beta2, eps + optimizer_args = { + 'params': [ + { + 'params': gain_or_bias_params, + 'weight_decay': 0. + }, + { + 'params': rest_params, + 'weight_decay': optimizer_hparams['weight_decay'] + }, + ], + 'lr': + optimizer_hparams['lr'], + 'betas': + (optimizer_hparams['beta1'], optimizer_hparams['beta2']), + 'eps': + optimizer_hparams['eps'], + } + optimizer = build_optimizer( + model, cfg=cfg.train.optimizer, default_args=optimizer_args) + else: + optimizer = optimizers[0] + + if optimizers[1] is None: + lr_scheduler = get_schedule(optimizer, cfg.train.lr_scheduler) + else: + lr_scheduler = optimizers[1] + optimizers = (optimizer, lr_scheduler) + + # loss module + loss_img = nn.CrossEntropyLoss() + loss_txt = nn.CrossEntropyLoss() + self.loss_img = loss_img.cuda(int(os.environ.get('LOCAL_RANK', 0))) + self.loss_txt = loss_txt.cuda(int(os.environ.get('LOCAL_RANK', 0))) + self.loss_cfg = cfg.train.loss_cfg + + # launcher and use_fp16 + if 'launcher' not in kwargs and cfg.train.get('launcher', None): + kwargs['launcher'] = cfg.train.launcher + if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): + kwargs['use_fp16'] = cfg.train.use_fp16 + + # preprocessor + if preprocessor is None: + preprocessor = { + ConfigKeys.train: + CLIPPreprocessor( + model_dir=work_dir, + mode=ModeKeys.TRAIN, + tokenizer=model.tokenizer, + resolution=model.model_info['image_resolution']), + ConfigKeys.val: + CLIPPreprocessor( + model_dir=work_dir, + mode=ModeKeys.EVAL, + tokenizer=model.tokenizer, + resolution=model.model_info['image_resolution']), + } + + # dataset related + self.dataset_cfg = cfg.dataset + if hasattr(self.dataset_cfg, 'column_map'): + # cases where dataset key names are not "img" and "text" + img_key_name = getattr(self.dataset_cfg.column_map, 'img', 'img') + preprocessor[ConfigKeys.train].set_input_img_key(img_key_name) + preprocessor[ConfigKeys.val].set_input_img_key(img_key_name) + text_key_name = getattr(self.dataset_cfg.column_map, 'text', + 'text') + preprocessor[ConfigKeys.train].set_input_text_key(text_key_name) + preprocessor[ConfigKeys.val].set_input_text_key(text_key_name) + self.global_batch_size = cfg.train.dataloader.batch_size_per_gpu * world_size + + super().__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + seed=seed, + **kwargs, + ) + + def train_step(self, model, inputs): + model.train() + inputs['mode'] = ModeKeys.TRAIN + model_outputs = model.forward( + inputs + ) # {OutputKeys.IMG_EMBEDDING: Tensor(batch_size, dim), OutputKeys.TEXT_EMBEDDING: Tensor(batch_size, dim)} + loss = get_loss(model_outputs, self.loss_img, self.loss_txt, + self.loss_cfg) + train_outputs = {'loss': loss} + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + unwrapped_model = getattr(model, 'module', model) + log_vars[ + 'logit_scale'] = unwrapped_model.clip_model.logit_scale.data.clone( + ).item() # noqa + log_vars['global_batch_size'] = int(self.global_batch_size) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py new file mode 100644 index 00000000..fed255de --- /dev/null +++ b/modelscope/trainers/multi_modal/clip/clip_trainer_utils.py @@ -0,0 +1,125 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. + +import math +import os +from functools import partial +from inspect import unwrap + +import torch +import torch.distributed as dist +from torch.optim.lr_scheduler import LambdaLR + +from modelscope.outputs import OutputKeys + + +def get_optimizer_params(model_name, cfg): + # get default params + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + # base model + if model_name in ['damo/multi-modal_clip-vit-base-patch16_zh']: + params = { + 'lr': 5.0e-4, + 'beta1': 0.9, + 'beta2': 0.98, + 'eps': 1.0e-6, + 'weight_decay': 0.0 + } + # large models + elif model_name in [ + 'damo/multi-modal_clip-vit-large-patch14_zh', + 'damo/multi-modal_clip-vit-large-patch14_336_zh' + ]: + params = { + 'lr': 4.0e-4, + 'beta1': 0.9, + 'beta2': 0.98, + 'eps': 1.0e-6, + 'weight_decay': 0.0 + } + else: + params = { + 'lr': 5.0e-4, + 'beta1': 0.9, + 'beta2': 0.999, + 'eps': 1.0e-8, + 'weight_decay': 0.0 + } + # override with config params + for key in ['lr', 'beta1', 'beta2', 'eps', 'weight_decay']: + if hasattr(cfg.train, 'optimizer_hparams'): + params[key] = getattr(cfg.train.optimizer_hparams, key, + params[key]) + return params + + +def get_loss(model_outputs, loss_img, loss_txt, loss_cfg): + image_features = model_outputs[OutputKeys.IMG_EMBEDDING] + text_features = model_outputs[OutputKeys.TEXT_EMBEDDING] + logit_scale = model_outputs['logit_scale'] + logit_scale = logit_scale.mean() + if loss_cfg.aggregate and int(os.environ.get('WORLD_SIZE', 1)) > 1: + world_size = dist.get_world_size() + rank = dist.get_rank() + + # We gather tensors from all gpus to get more negatives to contrast with. + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + + all_image_features = torch.cat([image_features] + + gathered_image_features[:rank] + + gathered_image_features[rank + 1:]) + all_text_features = torch.cat([text_features] + + gathered_text_features[:rank] + + gathered_text_features[rank + 1:]) + + # this is needed to send gradients back everywhere. + logits_per_image = logit_scale * all_image_features @ all_text_features.t( + ) + logits_per_text = logits_per_image.t() + + else: + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + ground_truth = torch.arange(len(logits_per_image)).long() + ground_truth = ground_truth.cuda( + int(os.environ.get('LOCAL_RANK', 0)), non_blocking=True) + + total_loss = (loss_img(logits_per_image, ground_truth) + + loss_txt(logits_per_text, ground_truth)) / 2 + + return total_loss + + +def lr_lambda(num_warmup_steps, num_training_steps, num_cycles, current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps)) + return max( + 0.0, + 0.5 * # noqa + (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) # noqa + + +def get_schedule(optimizer, + scheduler, + num_cycles: float = 0.5, + last_epoch: int = -1): + num_warmup_steps = int(scheduler.warmup_proportion + * scheduler.num_train_steps) + num_training_steps = scheduler.num_train_steps + + return LambdaLR( + optimizer, + partial(lr_lambda, num_warmup_steps, num_training_steps, num_cycles), + last_epoch) diff --git a/modelscope/trainers/multi_modal/mplug/__init__.py b/modelscope/trainers/multi_modal/mplug/__init__.py new file mode 100644 index 00000000..caf7e3f0 --- /dev/null +++ b/modelscope/trainers/multi_modal/mplug/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .mplug_trainer import MPlugTrainer diff --git a/modelscope/trainers/multi_modal/mplug/mplug_trainer.py b/modelscope/trainers/multi_modal/mplug/mplug_trainer.py new file mode 100644 index 00000000..def66220 --- /dev/null +++ b/modelscope/trainers/multi_modal/mplug/mplug_trainer.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from collections.abc import Mapping + +import torch + +from modelscope.metainfo import Trainers +from modelscope.outputs import OutputKeys +from modelscope.trainers import NlpEpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.file_utils import func_receive_dict_inputs + + +@TRAINERS.register_module(module_name=Trainers.mplug) +class MPlugTrainer(NlpEpochBasedTrainer): + + def _decode(self, tokens): + tokenizer = self.eval_preprocessor.tokenizer + return tokenizer.decode(tokens, skip_special_tokens=True) + + def evaluation_step(self, data): + model = self.model.module if self._dist else self.model + model.eval() + + with torch.no_grad(): + if isinstance( + data, + Mapping) and not func_receive_dict_inputs(model.forward): + result = model.forward(**data) + else: + result = model.forward(data) + + result[OutputKeys.TEXT] = [ + self._decode(seq) for seq in result['sequences'] + ] + data[OutputKeys.LABELS] = [ + self._decode(seq) for seq in data['answer_input_ids'] + ] + + return result diff --git a/modelscope/trainers/multi_modal/ofa/__init__.py b/modelscope/trainers/multi_modal/ofa/__init__.py new file mode 100644 index 00000000..34e4ec7a --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .ofa_trainer import OFATrainer diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py new file mode 100644 index 00000000..71494768 --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -0,0 +1,161 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import shutil +from functools import partial +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.multi_modal import OfaPreprocessor +from modelscope.preprocessors.ofa.utils.collate import collate_fn +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, + ModeKeys) +from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, + get_schedule) + + +@TRAINERS.register_module(module_name=Trainers.ofa) +class OFATrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + model = Model.from_pretrained(model, revision=model_revision) + model_dir = model.model_dir + self.cfg_modify_fn = cfg_modify_fn + cfg = self.rebuild_config(Config.from_file(cfg_file)) + if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0: + work_dir = cfg.train.work_dir + else: + work_dir = kwargs['work_dir'] + tokenizer_files = { + 'zh': [ + 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', + 'config.json', 'ans2label.json' + ], + 'en': [ + 'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json', + 'ans2label.json' + ], + } + for filename in tokenizer_files[cfg.model.get('language', 'en')]: + finetune_file = os.path.join(work_dir, filename) + pretrain_file = os.path.join(model_dir, filename) + if os.path.exists(finetune_file): + continue + if os.path.exists(pretrain_file): + shutil.copy(pretrain_file, finetune_file) + + if preprocessor is None: + preprocessor = { + ConfigKeys.train: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.TRAIN, no_collate=True), + ConfigKeys.val: + OfaPreprocessor( + model_dir=work_dir, mode=ModeKeys.EVAL, no_collate=True), + } + # use torchrun launch + world_size = int(os.environ.get('WORLD_SIZE', 1)) + epoch_steps = math.ceil( + len(train_dataset) / # noqa + (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa + cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs + cfg.train.criterion.tokenizer = model.tokenizer + self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( + cfg.train.criterion) + if optimizers[0] is None: + optimizer = build_optimizer(model, cfg=cfg.train.optimizer) + else: + optimizer = optimizers[0] + if optimizers[1] is None: + scheduler_class, scheduler_args = get_schedule( + cfg.train.lr_scheduler) + if scheduler_class is not None: + lr_scheduler = scheduler_class(**{'optimizer': optimizer}, + **scheduler_args) + else: + lr_scheduler = None + else: + lr_scheduler = optimizers[1] + optimizers = (optimizer, lr_scheduler) + if data_collator is None: + data_collator = partial( + collate_fn, + pad_idx=model.tokenizer.pad_token_id, + eos_idx=model.tokenizer.eos_token_id, + ) + if 'launcher' not in kwargs and cfg.train.get('launcher', None): + kwargs['launcher'] = cfg.train.launcher + if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False): + kwargs['use_fp16'] = cfg.train.use_fp16 + kwargs['to_tensor'] = False + super().__init__( + model=model, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + seed=seed, + **kwargs, + ) + + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + cfg = self.cfg_modify_fn(cfg) + return cfg + + def train_step(self, model, inputs): + model.train() + loss, sample_size, logging_output = self.criterion(model, inputs) + train_outputs = {'loss': loss} + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py new file mode 100644 index 00000000..3930febb --- /dev/null +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -0,0 +1,247 @@ +# Copyright 2022 The OFA-Sys Team. +# All rights reserved. +# This source code is licensed under the Apache 2.0 license +# found in the LICENSE file in the root directory. +import math + +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from torch.nn.modules.loss import _Loss + + +def construct_rdrop_sample(x): + if isinstance(x, dict): + for key in x: + x[key] = construct_rdrop_sample(x[key]) + return x + elif isinstance(x, torch.Tensor): + return x.repeat(2, *([1] * (x.dim() - 1))) + elif isinstance(x, int): + return x * 2 + elif isinstance(x, np.ndarray): + return x.repeat(2) + else: + raise NotImplementedError + + +def kl_loss(p, q): + p_loss = F.kl_div(p, torch.exp(q), reduction='sum') + q_loss = F.kl_div(q, torch.exp(p), reduction='sum') + loss = (p_loss + q_loss) / 2 + return loss + + +def label_smoothed_nll_loss(lprobs, + target, + epsilon, + update_num, + reduce=True, + drop_worst_ratio=0.0, + drop_worst_after=0, + use_rdrop=False, + reg_alpha=1.0, + constraint_masks=None, + constraint_start=None, + constraint_end=None): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1) + if constraint_masks is not None: + smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6) + elif constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + smooth_loss = -lprobs[:, constraint_range].sum( + dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (len(constraint_range) - 1 + 1e-6) + else: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1) + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + if drop_worst_ratio > 0 and update_num > drop_worst_after: + if use_rdrop: + true_batch_size = loss.size(0) // 2 + _, indices = torch.topk( + loss[:true_batch_size], + k=int(true_batch_size * (1 - drop_worst_ratio)), + largest=False) + loss = torch.cat([loss[indices], loss[indices + true_batch_size]]) + nll_loss = torch.cat( + [nll_loss[indices], nll_loss[indices + true_batch_size]]) + lprobs = torch.cat( + [lprobs[indices], lprobs[indices + true_batch_size]]) + else: + loss, indices = torch.topk( + loss, + k=int(loss.shape[0] * (1 - drop_worst_ratio)), + largest=False) + nll_loss = nll_loss[indices] + lprobs = lprobs[indices] + + ntokens = loss.numel() + nll_loss = nll_loss.sum() / ntokens # 后面在grads里面处理 + loss = loss.sum() / ntokens # 后面在grads里面处理 + if use_rdrop: + true_batch_size = lprobs.size(0) // 2 + p = lprobs[:true_batch_size] + q = lprobs[true_batch_size:] + if constraint_start is not None and constraint_end is not None: + constraint_range = [0, 1, 2, 3] + list( + range(constraint_start, constraint_end)) + p = p[:, constraint_range] + q = q[:, constraint_range] + loss += kl_loss(p, q) * reg_alpha + + return loss, nll_loss, ntokens + + +class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): + + def __init__(self, args): + super().__init__() + self.sentence_avg = args.get('sentence_avg', False) + self.eps = args.get('label_smoothing', 0.1) + self.ignore_prefix_size = args.get('ignore_prefix_size', 0) + self.ignore_eos = args.get('ignore_eos', False) + self.report_accuracy = args.get('report_accuracy', False) + self.drop_worst_ratio = args.get('drop_worst_ratio', 0.0) + self.drop_worst_after = args.get('drop_worst_after', 0) + self.use_rdrop = args.get('use_rdrop', False) + self.reg_alpha = args.get('reg_alpha', 1.0) + self.sample_patch_num = args.get('sample_patch_num', 196) + + self.constraint_start = None + self.constraint_end = None + if args.get('constraint_range', None): + constraint_start, constraint_end = args.constraint_range.split(',') + self.constraint_start = int(constraint_start) + self.constraint_end = int(constraint_end) + self.padding_idx = args.tokenizer.pad_token_id + self.args = args + + def forward(self, model, sample, update_num=0, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if 'labels' in sample: + del sample['labels'] + if 'samples' in sample: + del sample['samples'] + + if self.use_rdrop: + construct_rdrop_sample(sample) + output = model.model(**sample['net_input']) + loss, nll_loss, ntokens = self.compute_loss( + output.logits, sample, update_num, reduce=reduce) + sample_size = ( + sample['target'].size(0) if self.sentence_avg else ntokens) + logging_output = { + 'loss': loss.data, + 'nll_loss': nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['nsentences'], + 'sample_size': sample_size, + } + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, logits, sample): + conf = sample['conf'][:, None, None] if 'conf' in sample and sample[ + 'conf'] is not None else 1 + constraint_masks = None + if 'constraint_masks' in sample and sample[ + 'constraint_masks'] is not None: + constraint_masks = sample['constraint_masks'] + logits.masked_fill_(~constraint_masks, -math.inf) + if self.constraint_start is not None and self.constraint_end is not None: + logits[:, :, 4:self.constraint_start] = -math.inf + logits[:, :, self.constraint_end:] = -math.inf + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf + target = sample['target'] + if self.ignore_prefix_size > 0: + lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous() + target = target[:, self.ignore_prefix_size:].contiguous() + if constraint_masks is not None: + constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous() # yapf: disable + if self.ignore_eos: + bsz, seq_len, embed_dim = lprobs.size() + eos_indices = target.eq(self.task.tgt_dict.eos()) + lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim) + target = target[~eos_indices].reshape(bsz, seq_len - 1) + if constraint_masks is not None: + constraint_masks = constraint_masks[~eos_indices].reshape( + bsz, seq_len - 1, embed_dim) + if constraint_masks is not None: + constraint_masks = constraint_masks.view(-1, + constraint_masks.size(-1)) + return lprobs.view(-1, + lprobs.size(-1)), target.view(-1), constraint_masks + + def compute_loss(self, logits, sample, update_num, reduce=True): + lprobs, target, constraint_masks = self.get_lprobs_and_target( + logits, sample) + if constraint_masks is not None: + constraint_masks = constraint_masks[target != self.padding_idx] + lprobs = lprobs[target != self.padding_idx] + target = target[target != self.padding_idx] + loss, nll_loss, ntokens = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + update_num, + reduce=reduce, + drop_worst_ratio=self.drop_worst_ratio, + drop_worst_after=self.drop_worst_after, + use_rdrop=self.use_rdrop, + reg_alpha=self.reg_alpha, + constraint_masks=constraint_masks, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end) + return loss, nll_loss, ntokens + + +def get_schedule(scheduler): + + if scheduler.name == 'const': + scheduler_class = transformers.get_constant_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps) + } + elif scheduler.name == 'linear': + scheduler_class = transformers.get_linear_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'cosine': + scheduler_class = transformers.get_cosine_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps + } + elif scheduler.name == 'polynomial_decay': + scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup + scheduler_args = { + 'num_warmup_steps': + int(scheduler.warmup_proportion * scheduler.num_train_steps), + 'num_training_steps': + scheduler.num_train_steps, + 'lr_end': + scheduler.lr_end + } + else: + raise NotImplementedError + + return scheduler_class, scheduler_args diff --git a/modelscope/trainers/multi_modal/team/__init__.py b/modelscope/trainers/multi_modal/team/__init__.py new file mode 100644 index 00000000..b48fcc7e --- /dev/null +++ b/modelscope/trainers/multi_modal/team/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .team_trainer import TEAMImgClsTrainer diff --git a/modelscope/trainers/multi_modal/team/team_trainer.py b/modelscope/trainers/multi_modal/team/team_trainer.py new file mode 100644 index 00000000..7c557416 --- /dev/null +++ b/modelscope/trainers/multi_modal/team/team_trainer.py @@ -0,0 +1,144 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from collections import OrderedDict +from typing import Callable, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from sklearn.metrics import confusion_matrix +from torch.optim import AdamW +from torch.utils.data import DataLoader, Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.multi_modal.team.team_trainer_utils import ( + get_optimizer, train_mapping, val_mapping) +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModeKeys +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@TRAINERS.register_module(module_name=Trainers.image_classification_team) +class TEAMImgClsTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, model: str, device_id: int, + data_collator: Callable, train_dataset: Dataset, + val_dataset: Dataset, *args, **kwargs): + super().__init__(cfg_file) + + self.cfg = Config.from_file(cfg_file) + team_model = Model.from_pretrained(model) + image_model = team_model.model.image_model.vision_transformer + classification_model = nn.Sequential( + OrderedDict([('encoder', image_model), + ('classifier', + nn.Linear(768, self.cfg.dataset.class_num))])) + self.model = classification_model + + for pname, param in self.model.named_parameters(): + if 'encoder' in pname: + param.requires_grad = False + + self.device_id = device_id + self.total_epoch = self.cfg.train.epoch + self.train_batch_size = self.cfg.train.batch_size + self.val_batch_size = self.cfg.evaluation.batch_size + self.ckpt_dir = self.cfg.train.ckpt_dir + + self.collate_fn = data_collator + self.train_dataset = train_dataset + self.val_dataset = val_dataset + + self.criterion = nn.CrossEntropyLoss().to(self.device_id) + + def train(self, *args, **kwargs): + self.model.train() + self.model.to(self.device_id) + + optimizer = get_optimizer(self.model) + + for epoch in range(self.total_epoch): + train_params = { + 'pin_memory': True, + 'collate_fn': self.collate_fn, + 'batch_size': self.train_batch_size, + 'shuffle': True, + 'drop_last': True, + 'num_workers': 8 + } + + train_loader = DataLoader(self.train_dataset, **train_params) + + for batch_idx, data in enumerate(train_loader): + img_tensor, label_tensor = data['pixel_values'], data['labels'] + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + label_tensor = label_tensor.to( + self.device_id, non_blocking=True) + + pred_logits = self.model(img_tensor) + loss = self.criterion(pred_logits, label_tensor) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch_idx % 10 == 0: + logger.info( + 'epoch: {}, train batch {}/{}, loss={:.5f}'.format( + epoch, batch_idx, len(train_loader), loss.item())) + + os.makedirs(self.ckpt_dir, exist_ok=True) + torch.save(self.model.state_dict(), + '{}/epoch{}.pth'.format(self.ckpt_dir, epoch)) + self.evaluate() + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + if checkpoint_path is not None: + checkpoint_params = torch.load(checkpoint_path, 'cpu') + self.model.load_state_dict(checkpoint_params) + self.model.eval() + self.model.to(self.device_id) + + val_params = { + 'collate_fn': self.collate_fn, + 'batch_size': self.val_batch_size, + 'shuffle': False, + 'drop_last': False, + 'num_workers': 8 + } + val_loader = DataLoader(self.val_dataset, **val_params) + + tp_cnt, processed_cnt = 0, 0 + all_pred_labels, all_gt_labels = [], [] + with torch.no_grad(): + for batch_idx, data in enumerate(val_loader): + img_tensor, label_tensor = data['pixel_values'], data['labels'] + img_tensor = img_tensor.to(self.device_id, non_blocking=True) + label_tensor = label_tensor.to( + self.device_id, non_blocking=True) + + pred_logits = self.model(img_tensor) + pred_labels = torch.max(pred_logits, dim=1)[1] + tp_cnt += torch.sum(pred_labels == label_tensor).item() + processed_cnt += img_tensor.shape[0] + logger.info('Accuracy: {:.3f}'.format(tp_cnt / processed_cnt)) + + all_pred_labels.extend(pred_labels.tolist()) + all_gt_labels.extend(label_tensor.tolist()) + conf_mat = confusion_matrix(all_gt_labels, all_pred_labels) + acc_mean_per_class = np.mean(conf_mat.diagonal() + / conf_mat.sum(axis=1)) + logger.info( + 'Accuracy mean per class: {:.3f}'.format(acc_mean_per_class)) diff --git a/modelscope/trainers/multi_modal/team/team_trainer_utils.py b/modelscope/trainers/multi_modal/team/team_trainer_utils.py new file mode 100644 index 00000000..ff1a4fd6 --- /dev/null +++ b/modelscope/trainers/multi_modal/team/team_trainer_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.optim import AdamW + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +train_transforms = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), +]) +val_transforms = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), +]) + + +def train_mapping(examples): + examples['pixel_values'] = [ + train_transforms(Image.open(image).convert('RGB')) + for image in examples['image:FILE'] + ] + examples['labels'] = [label for label in examples['label:LABEL']] + return examples + + +def val_mapping(examples): + examples['pixel_values'] = [ + val_transforms(Image.open(image).convert('RGB')) + for image in examples['image:FILE'] + ] + examples['labels'] = [label for label in examples['label:LABEL']] + return examples + + +def collate_fn(examples): + images = [] + labels = [] + for example in examples: + images.append((example['pixel_values'])) + labels.append(example['labels']) + + pixel_values = torch.stack(images) + labels = torch.tensor(labels) + return {'pixel_values': pixel_values, 'labels': labels} + + +def get_params_groups(ddp_model, lr): + large_lr_params = [] + small_lr_params = [] + for name, param in ddp_model.named_parameters(): + if not param.requires_grad: + continue + + if 'encoder' in name: + small_lr_params.append(param) + elif 'classifier' in name: + large_lr_params.append(param) + else: + logger.info('skip param: {}'.format(name)) + + params_groups = [{ + 'params': small_lr_params, + 'lr': lr / 10.0 + }, { + 'params': large_lr_params, + 'lr': lr + }] + return params_groups + + +def get_optimizer(ddp_model): + lr_init = 1e-3 + betas = [0.9, 0.999] + weight_decay = 0.02 + params_groups = get_params_groups(ddp_model, lr=lr_init) + return AdamW( + params_groups, lr=lr_init, betas=betas, weight_decay=weight_decay) diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py new file mode 100644 index 00000000..e3c39cf2 --- /dev/null +++ b/modelscope/trainers/nlp/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .sequence_classification_trainer import SequenceClassificationTrainer + from .csanmt_translation_trainer import CsanmtTranslationTrainer + from .text_ranking_trainer import TextRankingTrainer + from .text_generation_trainer import TextGenerationTrainer +else: + _import_structure = { + 'sequence_classification_trainer': ['SequenceClassificationTrainer'], + 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], + 'text_ranking_trainer': ['TextRankingTrainer'], + 'text_generation_trainer': ['TextGenerationTrainer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/nlp/csanmt_translation_trainer.py b/modelscope/trainers/nlp/csanmt_translation_trainer.py new file mode 100644 index 00000000..c93599c7 --- /dev/null +++ b/modelscope/trainers/nlp/csanmt_translation_trainer.py @@ -0,0 +1,326 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path as osp +from typing import Dict, Optional + +import tensorflow as tf + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.nlp import CsanmtForTranslation +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + tf.disable_eager_execution() + +logger = get_logger() + + +@TRAINERS.register_module(module_name=r'csanmt-translation') +class CsanmtTranslationTrainer(BaseTrainer): + + def __init__(self, model: str, cfg_file: str = None, *args, **kwargs): + if not osp.exists(model): + model = snapshot_download(model) + tf.reset_default_graph() + + self.model_dir = model + self.model_path = osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER) + if cfg_file is None: + cfg_file = osp.join(model, ModelFile.CONFIGURATION) + + super().__init__(cfg_file) + + self.params = {} + self._override_params_from_file() + + tf_config = tf.ConfigProto(allow_soft_placement=True) + tf_config.gpu_options.allow_growth = True + self._session = tf.Session(config=tf_config) + + self.source_wids = tf.placeholder( + dtype=tf.int64, shape=[None, None], name='source_wids') + self.target_wids = tf.placeholder( + dtype=tf.int64, shape=[None, None], name='target_wids') + self.output = {} + + self.global_step = tf.train.create_global_step() + + self.model = CsanmtForTranslation(self.model_path, **self.params) + output = self.model(input=self.source_wids, label=self.target_wids) + self.output.update(output) + + self.model_saver = tf.train.Saver( + tf.global_variables(), + max_to_keep=self.params['keep_checkpoint_max']) + with self._session.as_default() as sess: + logger.info(f'loading model from {self.model_path}') + + pretrained_variables_map = get_pretrained_variables_map( + self.model_path) + + tf.train.init_from_checkpoint(self.model_path, + pretrained_variables_map) + sess.run(tf.global_variables_initializer()) + + def _override_params_from_file(self): + + self.params['hidden_size'] = self.cfg['model']['hidden_size'] + self.params['filter_size'] = self.cfg['model']['filter_size'] + self.params['num_heads'] = self.cfg['model']['num_heads'] + self.params['num_encoder_layers'] = self.cfg['model'][ + 'num_encoder_layers'] + self.params['num_decoder_layers'] = self.cfg['model'][ + 'num_decoder_layers'] + self.params['layer_preproc'] = self.cfg['model']['layer_preproc'] + self.params['layer_postproc'] = self.cfg['model']['layer_postproc'] + self.params['shared_embedding_and_softmax_weights'] = self.cfg[ + 'model']['shared_embedding_and_softmax_weights'] + self.params['shared_source_target_embedding'] = self.cfg['model'][ + 'shared_source_target_embedding'] + self.params['initializer_scale'] = self.cfg['model'][ + 'initializer_scale'] + self.params['position_info_type'] = self.cfg['model'][ + 'position_info_type'] + self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis'] + self.params['num_semantic_encoder_layers'] = self.cfg['model'][ + 'num_semantic_encoder_layers'] + self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size'] + self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size'] + self.params['attention_dropout'] = 0.0 + self.params['residual_dropout'] = 0.0 + self.params['relu_dropout'] = 0.0 + + self.params['train_src'] = self.cfg['dataset']['train_src'] + self.params['train_trg'] = self.cfg['dataset']['train_trg'] + self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file'] + self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file'] + + self.params['num_gpus'] = self.cfg['train']['num_gpus'] + self.params['warmup_steps'] = self.cfg['train']['warmup_steps'] + self.params['update_cycle'] = self.cfg['train']['update_cycle'] + self.params['keep_checkpoint_max'] = self.cfg['train'][ + 'keep_checkpoint_max'] + self.params['confidence'] = self.cfg['train']['confidence'] + self.params['optimizer'] = self.cfg['train']['optimizer'] + self.params['adam_beta1'] = self.cfg['train']['adam_beta1'] + self.params['adam_beta2'] = self.cfg['train']['adam_beta2'] + self.params['adam_epsilon'] = self.cfg['train']['adam_epsilon'] + self.params['gradient_clip_norm'] = self.cfg['train'][ + 'gradient_clip_norm'] + self.params['learning_rate_decay'] = self.cfg['train'][ + 'learning_rate_decay'] + self.params['initializer'] = self.cfg['train']['initializer'] + self.params['initializer_scale'] = self.cfg['train'][ + 'initializer_scale'] + self.params['learning_rate'] = self.cfg['train']['learning_rate'] + self.params['train_batch_size_words'] = self.cfg['train'][ + 'train_batch_size_words'] + self.params['scale_l1'] = self.cfg['train']['scale_l1'] + self.params['scale_l2'] = self.cfg['train']['scale_l2'] + self.params['train_max_len'] = self.cfg['train']['train_max_len'] + self.params['max_training_steps'] = self.cfg['train'][ + 'max_training_steps'] + self.params['save_checkpoints_steps'] = self.cfg['train'][ + 'save_checkpoints_steps'] + self.params['num_of_samples'] = self.cfg['train']['num_of_samples'] + self.params['eta'] = self.cfg['train']['eta'] + + self.params['beam_size'] = self.cfg['evaluation']['beam_size'] + self.params['lp_rate'] = self.cfg['evaluation']['lp_rate'] + self.params['max_decoded_trg_len'] = self.cfg['evaluation'][ + 'max_decoded_trg_len'] + + self.params['seed'] = self.cfg['model']['seed'] + + def train(self, *args, **kwargs): + logger.info('Begin csanmt training') + + train_src = osp.join(self.model_dir, self.params['train_src']) + train_trg = osp.join(self.model_dir, self.params['train_trg']) + vocab_src = osp.join(self.model_dir, self.params['vocab_src']) + vocab_trg = osp.join(self.model_dir, self.params['vocab_trg']) + + iteration = 0 + + with self._session.as_default() as tf_session: + while True: + iteration += 1 + if iteration >= self.params['max_training_steps']: + break + + train_input_fn = input_fn( + train_src, + train_trg, + vocab_src, + vocab_trg, + batch_size_words=self.params['train_batch_size_words'], + max_len=self.params['train_max_len'], + num_gpus=self.params['num_gpus'] + if self.params['num_gpus'] > 0 else 1, + is_train=True, + session=tf_session, + iteration=iteration) + + features, labels = train_input_fn + + features_batch, labels_batch = tf_session.run( + [features, labels]) + + feed_dict = { + self.source_wids: features_batch, + self.target_wids: labels_batch + } + sess_outputs = self._session.run( + self.output, feed_dict=feed_dict) + loss_step = sess_outputs['loss'] + logger.info('Iteration: {}, step loss: {:.6f}'.format( + iteration, loss_step)) + + if iteration % self.params['save_checkpoints_steps'] == 0: + tf.logging.info('%s: Saving model on step: %d.' % + (__name__, iteration)) + ck_path = self.model_dir + 'model.ckpt' + self.model_saver.save( + tf_session, + ck_path, + global_step=tf.train.get_global_step()) + + tf.logging.info('%s: NMT training completed at time: %s.') + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. + + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + pass + + +def input_fn(src_file, + trg_file, + src_vocab_file, + trg_vocab_file, + num_buckets=20, + max_len=100, + batch_size=200, + batch_size_words=4096, + num_gpus=1, + is_train=True, + session=None, + iteration=None): + src_vocab = tf.lookup.StaticVocabularyTable( + tf.lookup.TextFileInitializer( + src_vocab_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER), + num_oov_buckets=1) # NOTE unk-> vocab_size + trg_vocab = tf.lookup.StaticVocabularyTable( + tf.lookup.TextFileInitializer( + trg_vocab_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER), + num_oov_buckets=1) # NOTE unk-> vocab_size + src_dataset = tf.data.TextLineDataset(src_file) + trg_dataset = tf.data.TextLineDataset(trg_file) + src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset)) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: (tf.string_split([src]), tf.string_split([trg])), + num_parallel_calls=10).prefetch(1000000) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: (src.values, trg.values), + num_parallel_calls=10).prefetch(1000000) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)), + num_parallel_calls=10).prefetch(1000000) + + if is_train: + + def key_func(src_data, trg_data): + bucket_width = (max_len + num_buckets - 1) // num_buckets + bucket_id = tf.maximum( + tf.size(input=src_data) // bucket_width, + tf.size(input=trg_data) // bucket_width) + return tf.cast(tf.minimum(num_buckets, bucket_id), dtype=tf.int64) + + def reduce_func(unused_key, windowed_data): + return windowed_data.padded_batch( + batch_size_words, padded_shapes=([None], [None])) + + def window_size_func(key): + bucket_width = (max_len + num_buckets - 1) // num_buckets + key += 1 + size = (num_gpus * batch_size_words // (key * bucket_width)) + return tf.cast(size, dtype=tf.int64) + + src_trg_dataset = src_trg_dataset.filter( + lambda src, trg: tf.logical_and( + tf.size(input=src) <= max_len, + tf.size(input=trg) <= max_len)) + src_trg_dataset = src_trg_dataset.apply( + tf.data.experimental.group_by_window( + key_func=key_func, + reduce_func=reduce_func, + window_size_func=window_size_func)) + + else: + src_trg_dataset = src_trg_dataset.padded_batch( + batch_size * num_gpus, padded_shapes=([None], [None])) + + iterator = tf.data.make_initializable_iterator(src_trg_dataset) + tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) + features, labels = iterator.get_next() + + if is_train: + session.run(iterator.initializer) + if iteration == 1: + session.run(tf.tables_initializer()) + return features, labels + + +def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None): + reader = tf.train.NewCheckpointReader( + tf.train.latest_checkpoint(checkpoint_file_path)) + saved_shapes = reader.get_variable_to_shape_map() + if ignore_scope is None: + var_names = sorted([(var.name, var.name.split(':')[0]) + for var in tf.global_variables() + if var.name.split(':')[0] in saved_shapes]) + else: + var_names = sorted([(var.name, var.name.split(':')[0]) + for var in tf.global_variables() + if var.name.split(':')[0] in saved_shapes and all( + scope not in var.name + for scope in ignore_scope)]) + restore_vars = [] + name2var = dict( + zip( + map(lambda x: x.name.split(':')[0], tf.global_variables()), + tf.global_variables())) + restore_map = {} + with tf.variable_scope('', reuse=True): + for var_name, saved_var_name in var_names: + curr_var = name2var[saved_var_name] + var_shape = curr_var.get_shape().as_list() + if var_shape == saved_shapes[saved_var_name]: + restore_vars.append(curr_var) + restore_map[saved_var_name] = curr_var + return restore_map diff --git a/modelscope/trainers/nlp/sequence_classification_trainer.py b/modelscope/trainers/nlp/sequence_classification_trainer.py new file mode 100644 index 00000000..ec46e037 --- /dev/null +++ b/modelscope/trainers/nlp/sequence_classification_trainer.py @@ -0,0 +1,226 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import time +from typing import Dict, Optional, Tuple, Union + +import numpy as np + +from modelscope.metainfo import Trainers +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.logger import get_logger + +PATH = None +logger = get_logger(PATH) + + +@TRAINERS.register_module(module_name=Trainers.bert_sentiment_analysis) +class SequenceClassificationTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ A trainer is used for Sequence Classification + + Based on Config file (*.yaml or *.json), the trainer trains or evaluates on a dataset + + Args: + cfg_file (str): the path of config file + Raises: + ValueError: _description_ + """ + super().__init__(cfg_file) + + def train(self, *args, **kwargs): + logger.info('Train') + ... + + def __attr_is_exist(self, attr: str) -> Tuple[Union[str, bool]]: + """get attribute from config, if the attribute does exist, return false + + Example: + >>> self.__attr_is_exist("model path") + out: (model-path, "/workspace/bert-base-sst2") + >>> self.__attr_is_exist("model weights") + out: (model-weights, False) + + Args: + attr (str): attribute str, "model path" -> config["model"][path] + + Returns: + Tuple[Union[str, bool]]:[target attribute name, the target attribute or False] + """ + paths = attr.split(' ') + attr_str: str = '-'.join(paths) + target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None + + for path_ in paths[1:]: + if not hasattr(target, path_): + return attr_str, False + target = target[path_] + + if target and target != '': + return attr_str, target + return attr_str, False + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. + + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + import torch + from easynlp.appzoo import load_dataset + from easynlp.appzoo.dataset import GeneralDataset + from easynlp.appzoo.sequence_classification.model import \ + SequenceClassification + from easynlp.utils import losses + from sklearn.metrics import f1_score + from torch.utils.data import DataLoader + + raise_str = 'Attribute {} is not given in config file!' + + metrics = self.__attr_is_exist('evaluation metrics') + eval_batch_size = self.__attr_is_exist('evaluation batch_size') + test_dataset_path = self.__attr_is_exist('dataset valid file') + + attrs = [metrics, eval_batch_size, test_dataset_path] + for attr_ in attrs: + if not attr_[-1]: + raise AttributeError(raise_str.format(attr_[0])) + + if not checkpoint_path: + checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1] + if not checkpoint_path: + raise ValueError( + 'Argument checkout_path must be passed if the evaluation-model_path is not given in config file!' + ) + + max_sequence_length = kwargs.get( + 'max_sequence_length', + self.__attr_is_exist('evaluation max_sequence_length')[-1]) + if not max_sequence_length: + raise ValueError( + 'Argument max_sequence_length must be passed ' + 'if the evaluation-max_sequence_length does not exist in config file!' + ) + + # get the raw online dataset + raw_dataset = load_dataset(*test_dataset_path[-1].split('/')) + valid_dataset = raw_dataset['validation'] + + # generate a standard dataloader + pre_dataset = GeneralDataset(valid_dataset, checkpoint_path, + max_sequence_length) + valid_dataloader = DataLoader( + pre_dataset, + batch_size=eval_batch_size[-1], + shuffle=False, + collate_fn=pre_dataset.batch_fn) + + # generate a model + model = SequenceClassification.from_pretrained(checkpoint_path) + + # copy from easynlp (start) + model.eval() + total_loss = 0 + total_steps = 0 + total_samples = 0 + hit_num = 0 + total_num = 0 + + logits_list = list() + y_trues = list() + + total_spent_time = 0.0 + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + model.to(device) + for _step, batch in enumerate(valid_dataloader): + try: + batch = { + # key: val.cuda() if isinstance(val, torch.Tensor) else val + # for key, val in batch.items() + key: + val.to(device) if isinstance(val, torch.Tensor) else val + for key, val in batch.items() + } + except RuntimeError: + batch = {key: val for key, val in batch.items()} + + infer_start_time = time.time() + with torch.no_grad(): + label_ids = batch.pop('label_ids') + outputs = model(batch) + infer_end_time = time.time() + total_spent_time += infer_end_time - infer_start_time + + assert 'logits' in outputs + logits = outputs['logits'] + + y_trues.extend(label_ids.tolist()) + logits_list.extend(logits.tolist()) + hit_num += torch.sum( + torch.argmax(logits, dim=-1) == label_ids).item() + total_num += label_ids.shape[0] + + if len(logits.shape) == 1 or logits.shape[-1] == 1: + tmp_loss = losses.mse_loss(logits, label_ids) + elif len(logits.shape) == 2: + tmp_loss = losses.cross_entropy(logits, label_ids) + else: + raise RuntimeError + + total_loss += tmp_loss.mean().item() + total_steps += 1 + total_samples += valid_dataloader.batch_size + if (_step + 1) % 100 == 0: + total_step = len( + valid_dataloader.dataset) // valid_dataloader.batch_size + logger.info('Eval: {}/{} steps finished'.format( + _step + 1, total_step)) + + logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format( + total_spent_time, total_spent_time * 1000 / total_samples)) + + eval_loss = total_loss / total_steps + logger.info('Eval loss: {}'.format(eval_loss)) + + logits_list = np.array(logits_list) + eval_outputs = list() + for metric in metrics[-1]: + if metric.endswith('accuracy'): + acc = hit_num / total_num + logger.info('Accuracy: {}'.format(acc)) + eval_outputs.append(('accuracy', acc)) + elif metric == 'f1': + if model.config.num_labels == 2: + f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1)) + logger.info('F1: {}'.format(f1)) + eval_outputs.append(('f1', f1)) + else: + f1 = f1_score( + y_trues, + np.argmax(logits_list, axis=-1), + average='macro') + logger.info('Macro F1: {}'.format(f1)) + eval_outputs.append(('macro-f1', f1)) + f1 = f1_score( + y_trues, + np.argmax(logits_list, axis=-1), + average='micro') + logger.info('Micro F1: {}'.format(f1)) + eval_outputs.append(('micro-f1', f1)) + else: + raise NotImplementedError('Metric %s not implemented' % metric) + # copy from easynlp (end) + + return dict(eval_outputs) diff --git a/modelscope/trainers/nlp/space/__init__.py b/modelscope/trainers/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/dialog_intent_trainer.py b/modelscope/trainers/nlp/space/dialog_intent_trainer.py new file mode 100644 index 00000000..4baaddfe --- /dev/null +++ b/modelscope/trainers/nlp/space/dialog_intent_trainer.py @@ -0,0 +1,151 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Callable, Dict, Optional + +import numpy as np + +from modelscope.metainfo import Trainers +from modelscope.models.nlp.space.model.generator import SpaceGenerator +from modelscope.models.nlp.space.model.model_base import SpaceModelBase +from modelscope.preprocessors.nlp.space.data_loader import \ + get_sequential_data_loader +from modelscope.preprocessors.nlp.space.fields.intent_field import \ + IntentBPETextField +from modelscope.preprocessors.nlp.space.preprocess import intent_preprocess +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer +from modelscope.utils.config import Config, ModelFile +from modelscope.utils.logger import get_logger + +PATH = None +logger = get_logger(PATH) + + +@TRAINERS.register_module(module_name=Trainers.dialog_intent_trainer) +class DialogIntentTrainer(BaseTrainer): + + def __init__(self, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) + + def setup_seed(seed): + import random + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + self.cfg_modify_fn = cfg_modify_fn + self.cfg = self.rebuild_config(self.cfg) + + setup_seed(self.cfg.Trainer.seed) + + # preprocess data + intent_preprocess(self.cfg.Model.init_checkpoint, self.cfg) + # set reader and evaluator + self.bpe = IntentBPETextField(self.cfg.Model.init_checkpoint, self.cfg) + + self.cfg.Model.num_token_embeddings = self.bpe.vocab_size + self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 + dataset_paths = [ + os.path.join(self.cfg.Dataset.data_dir, + self.cfg.Dataset.trigger_data) + ] + # set data and data status + collate_fn = self.bpe.collate_fn_multi_turn + self.train_label_loader = get_sequential_data_loader( + batch_size=self.cfg.Trainer.batch_size_label, + reader=self.bpe, + hparams=self.cfg, + data_paths=dataset_paths, + collate_fn=collate_fn, + data_type='train') + self.valid_label_loader = get_sequential_data_loader( + batch_size=self.cfg.Trainer.batch_size_label, + reader=self.bpe, + hparams=self.cfg, + data_paths=dataset_paths, + collate_fn=collate_fn, + data_type='valid') + self.test_label_loader = get_sequential_data_loader( + batch_size=self.cfg.Trainer.batch_size_label, + reader=self.bpe, + hparams=self.cfg, + data_paths=dataset_paths, + collate_fn=collate_fn, + data_type='test') + + # set generator + self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) + self._load_model(**kwargs) + + def _load_model(self, **kwargs): + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda() if self.cfg.use_gpu else array + + # construct model + if 'model' in kwargs: + self.model = kwargs['model'] + else: + self.model = SpaceModelBase.create( + kwargs['model_dir'], + self.cfg, + reader=self.bpe, + generator=self.generator) + + import torch + # multi-gpu + if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) + + # construct trainer + self.trainer = IntentTrainer( + self.model, to_tensor, self.cfg, reader=self.bpe) + num_batches = len(self.train_label_loader) + self.trainer.set_optimizers(num_training_steps_per_epoch=num_batches) + # load model, optimizer and lr_scheduler + self.trainer.load() + + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + return self.cfg_modify_fn(cfg) + return cfg + + def train(self, *args, **kwargs): + logger.info('Train') + + self.trainer.train( + train_label_iter=self.train_label_loader, + valid_label_iter=self.valid_label_loader) + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + logger.info('Evaluate') + self.cfg.do_infer = True + + # get best checkpoint path + pos = checkpoint_path.rfind('/') + checkpoint_name = checkpoint_path[pos + 1:] + checkpoint_dir = checkpoint_path[:pos] + + assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE + kwargs['model_dir'] = checkpoint_dir + self._load_model(**kwargs) + self.trainer.infer( + data_iter=self.test_label_loader, + ex_data_iter=self.train_label_loader) diff --git a/modelscope/trainers/nlp/space/dialog_modeling_trainer.py b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py new file mode 100644 index 00000000..aa6bb69d --- /dev/null +++ b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py @@ -0,0 +1,131 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import time +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np + +from modelscope.metainfo import Trainers +from modelscope.models.nlp.space.model.generator import SpaceGenerator +from modelscope.models.nlp.space.model.model_base import SpaceModelBase +from modelscope.preprocessors.nlp import MultiWOZBPETextField +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator +from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer +from modelscope.utils.config import Config, ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def setup_seed(seed: int): + import random + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +@TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer) +class DialogModelingTrainer(BaseTrainer): + + def __init__(self, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + + super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) + + self.cfg_modify_fn = cfg_modify_fn + self.cfg = self.rebuild_config(self.cfg) + + setup_seed(self.cfg.Trainer.seed) + + # set reader and evaluator + self.bpe = MultiWOZBPETextField(self.cfg, **kwargs) + + self.cfg.Model.num_token_embeddings = self.bpe.vocab_size + self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 + + if 'work_dir' in kwargs: + self.cfg.Trainer.save_dir = kwargs['work_dir'] + else: + self.cfg.Trainer.save_dir = './default_save_dir' + + # set data and data status + self.train_data = self.bpe.get_batches('train') + self.dev_data = self.bpe.get_batches('dev') + + self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs) + # set generator + self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) + self._load_model(**kwargs) + + def _load_model(self, **kwargs): + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda( + ) if self.cfg.use_gpu and torch.cuda.is_available() else array + + # construct model + if 'model' in kwargs: + self.model = kwargs['model'] + else: + self.model = SpaceModelBase.create( + kwargs['model_dir'], + self.cfg, + reader=self.bpe, + generator=self.generator) + + import torch + # multi-gpu + if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) + + # construct trainer + self.trainer = MultiWOZTrainer( + self.model, + to_tensor, + self.cfg, + reader=self.bpe, + evaluator=self.evaluator) + self.trainer.set_optimizers() + # load model, optimizer and lr_scheduler + self.trainer.load() + + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + return self.cfg_modify_fn(cfg) + return cfg + + def train(self, *args, **kwargs): + logger.info('Train') + + self.trainer.train(train_data=self.train_data, dev_data=self.dev_data) + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + logger.info('Evaluate') + self.cfg.do_infer = True + + # get best checkpoint path + pos = checkpoint_path.rfind('/') + checkpoint_name = checkpoint_path[pos + 1:] + checkpoint_dir = checkpoint_path[:pos] + + assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE + kwargs['model_dir'] = checkpoint_dir + self._load_model(**kwargs) + self.trainer.infer(data_type='test') diff --git a/modelscope/trainers/nlp/space/eval.py b/modelscope/trainers/nlp/space/eval.py new file mode 100644 index 00000000..f315ff07 --- /dev/null +++ b/modelscope/trainers/nlp/space/eval.py @@ -0,0 +1,952 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright from https://github.com/thu-spmi/LABES +# Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections import Counter + +import json +import numpy as np +from nltk.util import ngrams +from sklearn.metrics import f1_score + +from modelscope.utils.nlp.space import ontology, utils +from modelscope.utils.nlp.space.clean_dataset import clean_slot_values + + +def similar(a, b): + return a == b or a in b or b in a or a.split()[0] == b.split( + )[0] or a.split()[-1] == b.split()[-1] + + +def setsub(a, b): + junks_a = [] + useless_constraint = [ + 'temperature', 'week', 'est ', 'quick', 'reminder', 'near' + ] + for i in a: + flg = False + for j in b: + if similar(i, j): + flg = True + if not flg: + junks_a.append(i) + for junk in junks_a: + flg = False + for item in useless_constraint: + if item in junk: + flg = True + if not flg: + return False + return True + + +def setsim(a, b): + a, b = set(a), set(b) + return setsub(a, b) and setsub(b, a) + + +def DA_evaluate(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + results = {} + + for avg_name in ['micro']: + my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) + results['f1_{}'.format(avg_name)] = my_f1_score + + return results + + +class BLEUScorer(object): + # BLEU score calculator via GentScorer interface + # it calculates the BLEU-4 by taking the entire corpus in + # Calulate based multiple candidates against multiple references + def __init__(self): + pass + + def score(self, parallel_corpus): + + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in parallel_corpus: + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + for hyp in hyps: + + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max( + max_counts.get(ng, 0), refcnts[ng]) + clipcnt = \ + dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: + break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + + # computing bleu score + p0 = 1e-7 + bp = \ + 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = \ + [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] + s = \ + math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu * 100 + + +"""" +For the data preparation and evaluation on MultiWOZ2.0/2.1, +we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ) +""" + + +class MultiWOZEvaluator(object): + + def __init__(self, reader, **kwargs): + self.reader = reader + self.domains = ontology.all_domains + self.all_data = self.reader.data + self.test_data = self.reader.test + + self.bleu_scorer = BLEUScorer() + + self.all_info_slot = [] + for d, s_list in ontology.informable_slots.items(): + for s in s_list: + self.all_info_slot.append(d + '-' + s) + + # only evaluate these slots for dialog success + self.requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + self.db_dir = kwargs['data_dir'] + + def pack_dial(self, data): + dials = {} + for turn in data: + dial_id = turn['dial_id'] + if dial_id not in dials: + dials[dial_id] = [] + dials[dial_id].append(turn) + return dials + + def validation_metric(self, data, fout=None): + bleu = self.bleu_metric(data) + # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data) + success, match, req_offer_counts, dial_num = \ + self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout) + return bleu, success, match + + def bleu_metric(self, data, eval_dial_list=None): + gen, truth = [], [] + for row in data: + if eval_dial_list and row[ + 'dial_id'] + '.json' not in eval_dial_list: + continue + gen.append(row['resp_gen']) + truth.append(row['resp']) + wrap_generated = [[_] for _ in gen] + wrap_truth = [[_] for _ in truth] + if gen and truth: + try: + sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth)) + except Exception: + sc = 0.0 + else: + sc = 0.0 + return sc + + def context_to_response_eval(self, + data, + eval_dial_list=None, + same_eval_as_cambridge=False, + fout=None): + dials = self.pack_dial(data) + counts = {} + for req in self.requestables: + counts[req + '_total'] = 0 + counts[req + '_offer'] = 0 + + dial_num, successes, matches = 0, 0, 0 + + for dial_id in dials: + if eval_dial_list and dial_id + '.json' not in eval_dial_list: + continue + dial = dials[dial_id] + reqs = {} + goal = {} + if '.json' not in dial_id and '.json' in list( + self.all_data.keys())[0]: + dial_id = dial_id + '.json' + for domain in ontology.all_domains: + if self.all_data[dial_id]['goal'].get(domain): + true_goal = self.all_data[dial_id]['goal'] + goal = self._parseGoal(goal, true_goal, domain) + + for domain in goal.keys(): + reqs[domain] = goal[domain]['requestable'] + + success, match, stats, counts = \ + self._evaluateGeneratedDialogue(dial, goal, reqs, counts, + same_eval_as_cambridge=same_eval_as_cambridge, fout=fout) + + successes += success + matches += match + dial_num += 1 + + succ_rate = successes / (float(dial_num) + 1e-10) * 100 + match_rate = matches / (float(dial_num) + 1e-10) * 100 + return succ_rate, match_rate, counts, dial_num + + def _evaluateGeneratedDialogue(self, + dialog, + goal, + real_requestables, + counts, + soft_acc=False, + same_eval_as_cambridge=False, + fout=None): + """Evaluates the dialogue created by the model. + First we load the user goal of the dialogue, then for each turn + generated by the system we look for key-words. + For the Inform rate we look whether the entity was proposed. + For the Success rate we look for requestables slots""" + # for computing corpus success + requestables = self.requestables + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + domains_in_goal = [] + log = [] + bspans = {} + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + domains_in_goal.append(domain) + + for t, turn in enumerate(dialog): + if t == 0: + continue + if fout is not None: + log.append({ + 'turn_num': turn['turn_num'], + 'turn_domain': turn['dspn'], + 'user': turn['user'], + 'aspn': turn['aspn'], + 'aspn_gen': turn['aspn_gen'], + 'resp': turn['resp'], + 'resp_gen': turn['resp_gen'], + 'pointer': turn['pointer'], + }) + + sent_t = turn['resp_gen'] + + for domain in goal.keys(): + # for computing success + if same_eval_as_cambridge: + # [restaurant_name], [hotel_name] instead of [value_name] + if self.reader.use_true_domain_for_ctr_eval: + dom_pred = [d[1:-1] for d in turn['dspn'].split()] + else: + dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()] + + if domain not in dom_pred: # fail + continue + if '[value_name]' in sent_t or '[value_id]' in sent_t: + if domain in [ + 'restaurant', 'hotel', 'attraction', 'train' + ]: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval: + bspn = turn['bspn_gen'] + else: + bspn = turn['bspn'] + + constraint_dict = self.reader.bspan_to_constraint_dict( + bspn) + if constraint_dict.get(domain): + venues = self.reader.db.queryJsons( + domain, + constraint_dict[domain], + return_name=True) + else: + venues = [] + + if len(venue_offered[domain]) == 0 and venues: + + venue_offered[domain] = venues + bspans[domain] = constraint_dict[domain] + else: + flag = False + for ven in venues: + if ven not in venue_offered[domain]: + flag = True + break + if flag and venues: # sometimes there are no results so sample won't work + venue_offered[domain] = venues + bspans[domain] = constraint_dict[domain] + else: # not limited so we can provide one + venue_offered[domain] = '[value_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if '[value_reference]' in sent_t: + if domain in ['restaurant', 'hotel', 'train']: + if 'booked' in turn['pointer'] or 'ok' in turn[ + 'pointer'] or '[value_reference]' in turn[ + 'resp']: + # if pointer was allowing for that? + provided_requestables[domain].append( + 'reference') + else: + provided_requestables[domain].append( + 'reference') + else: + if '[value_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + if 'name' in goal[domain]['informable']: + venue_offered[domain] = '[value_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[value_name]' + + if domain == 'train': + if not venue_offered[domain] and 'id' not in goal[domain][ + 'requestable']: + venue_offered[domain] = '[value_name]' + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = { + 'restaurant': [0, 0, 0], + 'hotel': [0, 0, 0], + 'attraction': [0, 0, 0], + 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], + 'police': [0, 0, 0] + } + + match = 0 + success = 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.reader.db.queryJsons( + domain, goal[domain]['informable'], return_name=True) + if type(venue_offered[domain] + ) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and len( + set(venue_offered[domain]) & set(goal_venues)) > 0: + match += 1 + match_stat = 1 + else: + if '_name]' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match) / len(goal.keys()) + else: + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + for domain in domains_in_goal: + for request in real_requestables[domain]: + counts[request + '_total'] += 1 + if request in provided_requestables[domain]: + counts[request + '_offer'] += 1 + + # SUCCESS + if fout is not None: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in real_requestables[domain]: + if request in provided_requestables[domain]: + domain_success += 1 + + if domain_success == len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success) / len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + else: + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in real_requestables[domain]: + if request in provided_requestables[domain]: + domain_success += 1 + + if domain_success == len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success) / len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + if fout is not None and success == 0: + sample = { + dialog[0]['dial_id']: { + 'log': log, + 'real_requestables': real_requestables, + 'provided_requestables': provided_requestables + } + } + line = json.dumps(sample) + fout.write(line) + fout.write('\n') + + return success, match, stats, counts + + def _parseGoal(self, goal, true_goal, domain): + """Parses user goal into dictionary format.""" + goal[domain] = {} + goal[domain] = {'informable': {}, 'requestable': [], 'booking': []} + if 'info' in true_goal[domain]: + if domain == 'train': + # we consider dialogues only where train had to be booked! + if 'book' in true_goal[domain]: + goal[domain]['requestable'].append('reference') + if 'reqt' in true_goal[domain]: + if 'id' in true_goal[domain]['reqt']: + goal[domain]['requestable'].append('id') + else: + if 'reqt' in true_goal[domain]: + for s in true_goal[domain]['reqt']: # addtional requests: + if s in [ + 'phone', 'address', 'postcode', 'reference', + 'id' + ]: + # ones that can be easily delexicalized + goal[domain]['requestable'].append(s) + if 'book' in true_goal[domain]: + goal[domain]['requestable'].append('reference') + + for s, v in true_goal[domain]['info'].items(): + s_, v_ = clean_slot_values(self.db_dir, domain, s, v) + if len(v_.split()) > 1: + v_ = ' '.join( + [token.text for token in self.reader.nlp(v_)]).strip() + goal[domain]['informable'][s_] = v_ + + if 'book' in true_goal[domain]: + goal[domain]['booking'] = true_goal[domain]['book'] + return goal + + +class GenericEvaluator: + + def __init__(self, reader): + self.reader = reader + self.metric_dict = {} + + def pack_dial(self, data): + dials = {} + for turn in data: + dial_id = turn['dial_id'] + if dial_id not in dials: + dials[dial_id] = [] + dials[dial_id].append(turn) + return dials + + def run_metrics(self, results): + raise ValueError('Please specify the evaluator first') + + def bleu_metric(self, data, type='bleu'): + gen, truth = [], [] + for row in data: + gen.append(self.clean(row['resp_gen'])) + # gen.append(self.clean(row['resp'])) + truth.append(self.clean(row['resp'])) + wrap_generated = [[_] for _ in gen] + wrap_truth = [[_] for _ in truth] + sc = BLEUScorer().score(zip(wrap_generated, wrap_truth)) + return sc + + def _normalize_constraint(self, + constraint, + ignore_dontcare=False, + intersection=True): + """ + Normalize belief span, e.g. delete repeated words + :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'} + :param intersection: if true, only keeps the words that appear in th ontology + we set intersection=True as in previous works + :returns: normalized constraint dict + e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''} + """ + normalized = {} + for s in self.informable_slots: + normalized[s] = '' + for s, v in constraint.items(): + if ignore_dontcare and v == 'dontcare': + continue + if intersection and v != 'dontcare' and v not in self.entities_flat: + continue + + normalized[s] = v + + return normalized + + def _normalize_act(self, aspn, intersection=False): + aspn_list = aspn.split('|') + normalized = {} + for i, v in enumerate(aspn_list): + seq = v.strip() + word_set = set() + for w in seq.split(): + if intersection: + if self.reader.act_order[i] == 'av': + if '[value' in w: + word_set.add(w) + else: + if w in self.requestable_slots: + word_set.add(w) + else: + word_set.add(w) + normalized[self.reader.act_order[i]] = word_set + return normalized + + def tracker_metric(self, data, normalize=True): + # turn level metric + tp, fp, fn, db_correct = 0, 0, 0, 0 + goal_accr, slot_accr, total = 0, {}, 1e-8 + for s in self.informable_slots: + slot_accr[s] = 0 + + for row in data: + if normalize: + gen = self._normalize_constraint(row['bspn_gen']) + truth = self._normalize_constraint(row['bspn']) + else: + gen = self._normalize_constraint( + row['bspn_gen'], intersection=False) + truth = self._normalize_constraint( + row['bspn'], intersection=False) + valid = 'thank' not in row['user'] and 'bye' not in row['user'] + if valid: + for slot, value in gen.items(): + if value in truth[slot]: + tp += 1 + else: + fp += 1 + for slot, value in truth.items(): + if value not in gen[slot]: + fn += 1 + + if truth and valid: + total += 1 + for s in self.informable_slots: + if gen[s] == truth[s]: + slot_accr[s] += 1 + if gen == truth: + goal_accr += 1 + if row.get('db_gen') and row.get('db_match'): + if row['db_gen'] == row['db_match']: + db_correct += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + goal_accr /= total + db_correct /= total + for s in slot_accr: + slot_accr[s] /= total + return precision, recall, f1, goal_accr, slot_accr, db_correct + + def request_metric(self, data): + # dialog level metric + dials = self.pack_dial(data) + tp, fp, fn = 0, 0, 0 + for dial_id in dials: + truth_req, gen_req = set(), set() + dial = dials[dial_id] + for turn_num, turn in enumerate(dial): + resp_gen_token = self.clean(turn['resp_gen']).split() + resp_token = self.clean(turn['resp']).split() + for w in resp_gen_token: + if '[value_' in w and w.endswith( + ']') and w != '[value_name]': + gen_req.add(w[1:-1].split('_')[1]) + for w in resp_token: + if '[value_' in w and w.endswith( + ']') and w != '[value_name]': + truth_req.add(w[1:-1].split('_')[1]) + for req in gen_req: + if req in truth_req: + tp += 1 + else: + fp += 1 + for req in truth_req: + if req not in gen_req: + fn += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1, precision, recall + + def act_metric(self, data): + # turn level metric + tp, fp, fn = { + 'all_s': 0, + 'all_v': 0 + }, { + 'all_s': 0, + 'all_v': 0 + }, { + 'all_s': 0, + 'all_v': 0 + } + for s in self.requestable_slots: + tp[s], fp[s], fn[s] = 0, 0, 0 + tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]' + % s] = 0, 0, 0 + + for row in data: + gen = self._normalize_act(row['aspn_gen']) + truth = self._normalize_act(row['aspn']) + valid = 'thank' not in row['user'] and 'bye' not in row['user'] + if valid: + # how well the act decoder captures user's requests + for value in gen['av']: + if value in truth['av']: + tp['all_v'] += 1 + if tp.get(value): + tp[value] += 1 + else: + fp['all_v'] += 1 + if fp.get(value): + fp[value] += 1 + for value in truth['av']: + if value not in gen['av']: + fn['all_v'] += 1 + if fn.get(value): + fn[value] += 1 + + # how accurately the act decoder predicts system's question + if 'as' not in gen: + continue + for slot in gen['as']: + if slot in truth['as']: + tp['all_s'] += 1 + if tp.get(slot): + tp[slot] += 1 + else: + fp['all_s'] += 1 + if fp.get(slot): + fp[slot] += 1 + for slot in truth['as']: + if slot not in gen['as']: + fn['all_s'] += 1 + if fn.get(slot): + fn[slot] += 1 + + result = {} + for k, v in tp.items(): + precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / ( + tp[k] + fn[k] + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + result[k] = [f1, precision, recall] + return result + + +""" +For the data preparation and evaluation on In-Car Assistant/CamRest, +we refer to the code of LABES (https://github.com/thu-spmi/LABES) +""" + + +class CamRestEvaluator(GenericEvaluator): + + def __init__(self, reader): + super().__init__(reader) + self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( + self.reader.ontology_path) + self.informable_slots = self.reader.otlg.informable_slots + self.requestable_slots = self.reader.otlg.requestable_slots + + def run_metrics(self, results): + metrics = {} + bleu = self.bleu_metric(results) + p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results) + match = self.match_metric(results) + req_f1, req_p, req_r = self.request_metric(results) + + metrics['bleu'] = bleu + metrics['match'] = match + metrics['req_f1'] = req_f1 + metrics['joint_goal'] = goal_acc + metrics['slot_accu'] = slot_acc + metrics['slot-p/r/f1'] = (p, r, f1) + metrics['db_acc'] = db_acc + + return metrics + + def get_entities(self, entity_path): + entities_flat = [] + entitiy_to_slot_dict = {} + raw_entities = json.loads(open(entity_path).read().lower()) + for s in raw_entities['informable']: + entities_flat.extend(raw_entities['informable'][s]) + for v in raw_entities['informable'][s]: + entitiy_to_slot_dict[v] = s + return entities_flat, entitiy_to_slot_dict + + def constraint_same(self, truth_cons, gen_cons): + if not truth_cons and not gen_cons: + return True + if not truth_cons or not gen_cons: + return False + return setsim(gen_cons, truth_cons) + + def match_metric(self, data): + dials = self.pack_dial(data) + match, total = 0, 1e-8 + for dial_id in dials: + dial = dials[dial_id] + truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None + for turn_num, turn in enumerate(dial): + # find the last turn which the system provide an entity + if '[value' in turn['resp_gen']: + gen_cons = self._normalize_constraint( + turn['bspn_gen'], ignore_dontcare=True) + if '[value' in turn['resp']: + truth_cons = self._normalize_constraint( + turn['bspn'], ignore_dontcare=True) + if not gen_cons: + # if no entity is provided, choose the state of the last dialog turn + gen_cons = self._normalize_constraint( + dial[-1]['bspn_gen'], ignore_dontcare=True) + if list(truth_cons.values()) != ['', '', '']: + if gen_cons == truth_cons: + match += 1 + total += 1 + + return match / total + + def clean(self, resp): + # we use the same clean process as in Sequicity, SEDST, FSDM + # to ensure comparable results + resp = resp.replace(f'{self.reader.sos_r_token} ', '') + resp = resp.replace(f' {self.reader.eos_r_token}', '') + resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' + for value, slot in self.entitiy_to_slot_dict.items(): + + resp = utils.clean_replace(resp, value, '[value_%s]' % slot) + return resp + + +class KvretEvaluator(GenericEvaluator): + + def __init__(self, reader): + super().__init__(reader) + self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( + self.reader.ontology_path) + self.informable_slots = self.reader.otlg.informable_slots + self.requestable_slots = self.reader.otlg.requestable_slots + + def run_metrics(self, results): + metrics = {} + bleu = self.bleu_metric(results) + p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric( + results, normalize=True) + match = self.match_metric(results) + req_f1, req_p, req_r = self.request_metric(results) + + metrics['bleu'] = bleu + metrics['match'] = match + metrics['req_f1'] = req_f1 + metrics['joint_goal'] = goal_acc + metrics['slot_accu'] = slot_acc + metrics['slot-p/r/f1'] = (p, r, f1) + metrics['db_acc'] = db_acc + + return metrics + + def _normalize_constraint(self, + constraint, + ignore_dontcare=False, + intersection=True): + """ + Normalize belief span, e.g. delete repeated words + :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'} + :param intersection: if true, only keeps the words that appear in th ontology + we set intersection=True as in previous works + :returns: normalized constraint dict + e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''} + """ + junk = [ + 'good', 'great', 'quickest', 'shortest', 'route', 'week', + 'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity', + 'restaurant', 'appointment' + ] + normalized = {} + for s in self.informable_slots: + normalized[s] = '' + for s, v in constraint.items(): + for j in junk: + v = ' '.join(v.replace(j, '').split()) + if intersection and v not in self.entities_flat: + continue + + if s in self.informable_slots: + normalized[s] = v + else: + # TODO only use slot (not domain) in s for matching !!! + pass + + return normalized + + def get_entities(self, entity_path): + entities_flat = [] + entitiy_to_slot_dict = {} + + entitiy_to_slot_dict = self.reader.entity_dict + for s in entitiy_to_slot_dict: + if s not in entities_flat: + entities_flat.append(s) + return entities_flat, entitiy_to_slot_dict + + def constraint_same(self, truth_cons, gen_cons): + if not truth_cons and not gen_cons: + return True + if not truth_cons or not gen_cons: + return False + return setsim(gen_cons, truth_cons) + + def match_metric(self, data): + dials = self.pack_dial(data) + match, total = 0, 1e-8 + for dial_id in dials: + dial = dials[dial_id] + truth_cons, gen_cons = { + '1': '', + '2': '', + '3': '', + '4': '', + '5': '', + '6': '', + '7': '', + '8': '', + '9': '', + '10': '', + '11': '' + }, None + for turn_num, turn in enumerate(dial): + # find the last turn which the system provide an entity + if '[value' in turn['resp_gen']: + gen_cons = self._normalize_constraint( + turn['bspn_gen'], ignore_dontcare=True) + if '[value' in turn['resp']: + truth_cons = self._normalize_constraint( + turn['bspn'], ignore_dontcare=True) + + if not gen_cons: + # if no entity is provided, choose the state of the last dialog turn + gen_cons = self._normalize_constraint( + dial[-1]['bspn_gen'], ignore_dontcare=True) + + if list(truth_cons.values()) != [''] * 11: + gen_cons = [x for x in gen_cons.values() if x] + truth_cons = [x for x in truth_cons.values() if x] + if self.constraint_same(gen_cons, truth_cons): + match += 1 + total += 1 + + return match / total + + def clean(self, resp): + # we use the same clean process as in Sequicity, SEDST, FSDM + # to ensure comparable results + resp = resp.replace(f'{self.reader.sos_r_token} ', '') + resp = resp.replace(f' {self.reader.eos_r_token}', '') + resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' + for value, slot in self.entitiy_to_slot_dict.items(): + resp = utils.clean_replace(resp, value, '[value_%s]' % slot) + return resp diff --git a/modelscope/trainers/nlp/space/metrics/__init__.py b/modelscope/trainers/nlp/space/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/metrics/metrics_tracker.py b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py new file mode 100644 index 00000000..340077a6 --- /dev/null +++ b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +from collections import defaultdict + + +class MetricsTracker(object): + """ Tracking metrics. """ + + def __init__(self): + self.metrics_val = defaultdict(float) # for one batch + self.metrics_avg = defaultdict(float) # avg batches + self.num_samples = 0 + + def update(self, metrics, num_samples): + for key, val in metrics.items(): + if val is not None: + val = float(val) # [val] -> val + self.metrics_val[key] = val + avg_val = \ + (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \ + (self.num_samples + num_samples) + self.metrics_avg[key] = avg_val + self.num_samples += num_samples + + def clear(self): + self.metrics_val = defaultdict(float) + self.metrics_avg = defaultdict(float) + self.num_samples = 0 + + def items(self): + return self.metrics_avg.items() + + def get(self, name): + if self.num_samples == 0: + raise ValueError('There is no data in Metrics.') + return self.metrics_avg.get(name) + + def state_dict(self): + return { + 'metrics_val': self.metrics_val, + 'metrics_avg': self.metrics_avg, + 'num_samples': self.num_samples, + } + + def load_state_dict(self, state_dict): + self.metrics_val = state_dict['metrics_val'] + self.metrics_avg = state_dict['metrics_avg'] + self.num_samples = state_dict['num_samples'] + + def value(self): + metric_strs = [] + for key, val in self.metrics_val.items(): + metric_str = f'{key.upper()}-{val:.3f}' + metric_strs.append(metric_str) + if 'token_nll' in self.metrics_val: + metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}" + metric_strs.append(metric_str) + metric_strs = ' '.join(metric_strs) + return metric_strs + + def summary(self): + metric_strs = [] + for key, val in self.metrics_avg.items(): + metric_str = f'{key.upper()}-{val:.3f}' + metric_strs.append(metric_str) + if 'token_nll' in self.metrics_avg: + metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}" + metric_strs.append(metric_str) + metric_strs = ' '.join(metric_strs) + return metric_strs diff --git a/modelscope/trainers/nlp/space/trainer/__init__.py b/modelscope/trainers/nlp/space/trainer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/nlp/space/trainer/gen_trainer.py b/modelscope/trainers/nlp/space/trainer/gen_trainer.py new file mode 100644 index 00000000..05efa138 --- /dev/null +++ b/modelscope/trainers/nlp/space/trainer/gen_trainer.py @@ -0,0 +1,734 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import time +from collections import OrderedDict + +import json +import numpy as np +import torch +from tqdm import tqdm +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ + MetricsTracker +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.nlp.space import ontology + + +class Trainer(object): + + def __init__(self, + model, + to_tensor, + config, + logger=None, + lr_scheduler=None, + optimizer=None, + reader=None, + evaluator=None): + self.to_tensor = to_tensor + + self.do_train = config.do_train + self.do_infer = config.do_infer + if self.do_train: + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary + self.lr = config.Model.lr + self.weight_decay = config.Model.weight_decay + self.batch_size = config.Trainer.batch_size + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.warmup_steps = config.Model.warmup_steps + self.gpu = config.Trainer.gpu + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + self.model = model + self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model + self.reader = reader + self.evaluator = evaluator + self.tokenizer = reader.tokenizer + + self.logger = logger or get_logger() + + self.batch_metrics_tracker = MetricsTracker() + self.token_metrics_tracker = MetricsTracker() + + if self.do_train: + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + + def decode_generated_bspn_resp(self, generated): + """ + decode generated + return decoded ('bspn', 'resp') + """ + decoded = {} + eos_r_id = self.reader.eos_r_id + eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. + if eos_r_id in generated: + eos_r_idx = generated.index(eos_r_id) + else: + eos_r_idx = len(generated) - 1 + # self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated)) + + # predicted bspn, resp + eos_b_idx = generated.index(eos_b_id) + decoded['bspn'] = generated[:eos_b_idx + 1] + decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1] + return decoded + + def decode_generated_act_resp(self, generated): + """ + decode generated + return decoded['resp'] ('bspn', 'aspn') + """ + decoded = {} + eos_a_id = self.reader.eos_a_id + eos_r_id = self.reader.eos_r_id + # eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. + if eos_r_id in generated: + eos_r_idx = generated.index(eos_r_id) + else: + eos_r_idx = len(generated) - 1 + msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated) + self.logger.info(msg) + + if self.reader.use_true_curr_aspn: # only predict resp + decoded['resp'] = generated[:eos_r_idx + 1] + else: # predicted aspn, resp + eos_a_idx = generated.index(eos_a_id) + decoded['aspn'] = generated[:eos_a_idx + 1] + decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1] + return decoded + + def decode_generated_bspn(self, generated): + eos_b_id = self.reader.eos_b_id + if eos_b_id in generated: + eos_b_idx = generated.index(eos_b_id) + else: + eos_b_idx = len(generated) - 1 + return generated[:eos_b_idx + 1] + + def set_optimizers(self): + """ + Setup the optimizer and the learning rate scheduler. + + from transformers.Trainer + + parameters from cfg: lr (1e-3); warmup_steps + """ + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'norm.weight'] + optimizer_grouped_parameters = [ + { + 'params': [ + p for n, p in self.model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + 'weight_decay': + self.weight_decay, + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if any(nd in n for nd in no_decay) + ], + 'weight_decay': + 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + + num_training_steps = \ + self.reader.set_stats['train']['num_training_steps_per_epoch'] \ + * self.num_epochs \ + // self.gradient_accumulation_steps + num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( + num_training_steps * 0.1) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def train(self, train_data, dev_data): + # log info + set_stats = self.reader.set_stats['train'] + self.logger.info('***** Running training *****') + self.logger.info( + ' Num Training steps(one turn in a batch of dialogs) per epoch = %d', + set_stats['num_training_steps_per_epoch']) + self.logger.info(' Num Turns = %d', set_stats['num_turns']) + self.logger.info(' Num Dialogs = %d', set_stats['num_dials']) + self.logger.info(' Num Epochs = %d', self.num_epochs) + self.logger.info(' Batch size = %d', self.batch_size) + self.logger.info(' Gradient Accumulation steps = %d', + self.gradient_accumulation_steps) + steps = set_stats[ + 'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps + msg = ' Total optimization steps = %d' % steps + self.logger.info(msg) + + # begin training + num_epochs = self.num_epochs - self.epoch + for epoch in range(num_epochs): + self.train_epoch(train_data=train_data, dev_data=dev_data) + + def train_epoch(self, train_data, dev_data): + """ + Train an epoch. + """ + raise NotImplementedError + + def infer(self, data_type): + """ + Inference interface. + """ + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join( + self.save_dir, + '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}'") + + def _load_train_state(): + train_file = f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda storage, loc: storage) + self.epoch = train_state_dict['epoch'] + self.best_valid_metric = train_state_dict['best_valid_metric'] + if self.optimizer is not None and 'optimizer' in train_state_dict: + self.optimizer.load_state_dict( + train_state_dict['optimizer']) + if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: + self.lr_scheduler.load_state_dict( + train_state_dict['lr_scheduler']) + self.logger.info( + f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " + f'best_valid_metric={self.best_valid_metric:.3f})') + else: + self.logger.info('Loaded no train state') + + if self.func_model.init_checkpoint is None: + self.logger.info('Loaded no model !!!') + return + + if self.do_train: + _load_model_state() + return + + if self.do_infer: + _load_model_state() + _load_train_state() + + +class MultiWOZTrainer(Trainer): + + def __init__(self, + model, + to_tensor, + config, + logger=None, + lr_scheduler=None, + optimizer=None, + reader=None, + evaluator=None): + super(MultiWOZTrainer, + self).__init__(model, to_tensor, config, logger, lr_scheduler, + optimizer, reader, evaluator) + + def train_epoch(self, train_data, dev_data): + """ + Train an epoch. + """ + times = [] + epoch_step = 0 + global_step = 0 + tr_batch_loss = 0.0 + tr_token_loss = 0.0 + self.epoch += 1 + self.batch_metrics_tracker.clear() + self.token_metrics_tracker.clear() + num_training_steps = \ + self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ + self.gradient_accumulation_steps # similar to the original num_batches + + self.model.zero_grad() + data_iterator = self.reader.get_data_iterator(all_batches=train_data) + + for batch_idx, dial_batch in enumerate(data_iterator): + pv_batch = [] + for turn_num, turn_batch in enumerate(dial_batch): + first_turn = (turn_num == 0) + samples, pv_batch = self.reader.convert_batch_turn( + turn_batch, pv_batch, first_turn) + batch, batch_size = self.reader.collate_fn_multi_turn( + samples=samples) + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + + # Do a training iteration + start_time = time.time() + metrics = self.model(batch, is_training=True) + if self.gpu > 1: + for metric in metrics: + if metric is not None: + assert len(metric) == self.gpu + nll, token_nll, token_num = metrics + metrics = {} + + token_num = torch.sum(token_num) + token_nll = \ + torch.sum(nll) * (batch_size / self.gpu) / \ + token_num + nll = torch.mean(nll) + metrics['token_num'] = token_num + metrics['token_nll'] = token_nll + metrics['nll'] = nll + loss = token_nll if self.func_model.token_loss else nll + + metrics['loss'] = loss + else: + loss = metrics['loss'] + self.func_model._optimize( + loss, do_update=False, optimizer=self.optimizer) + metrics = { + k: v.cpu().detach().numpy() + if isinstance(v, torch.Tensor) else v + for k, v in metrics.items() + } + token_num = metrics.pop('token_num', None) + # bow_num = metrics.pop("bow_num", None) + elapsed = time.time() - start_time + times.append(elapsed) + epoch_step += 1 + + tr_batch_loss += metrics['nll'] + tr_token_loss += metrics['token_nll'] + batch_metrics = { + k: v + for k, v in metrics.items() if 'token' not in k + } + token_metrics = { + k: v + for k, v in metrics.items() if 'token' in k + } + self.batch_metrics_tracker.update(batch_metrics, batch_size) + self.token_metrics_tracker.update(token_metrics, token_num) + + if (epoch_step % self.gradient_accumulation_steps == 0) or \ + (epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']): + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + global_step += 1 + + if self.log_steps > 0 and global_step % self.log_steps == 0: + batch_metrics_message = self.batch_metrics_tracker.value( + ) + token_metrics_message = self.token_metrics_tracker.value( + ) + message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]' + avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' + message = ' '.join([ + message_prefix, batch_metrics_message, + token_metrics_message, avg_time + ]) + self.logger.info(message) + + self.logger.info('-' * 150) + avg_batch_loss = tr_batch_loss / epoch_step + avg_token_loss = tr_token_loss / epoch_step + batch_metrics_message = self.batch_metrics_tracker.summary() + token_metrics_message = self.token_metrics_tracker.summary() + message_prefix = f'[Valid][{self.epoch}]' + message = ' '.join([ + message_prefix, batch_metrics_message, token_metrics_message, + str(avg_batch_loss), + str(avg_token_loss) + ]) + self.logger.info(message) + + cur_valid_metric = self.batch_metrics_tracker.get( + self.valid_metric_name) + if self.is_decreased_valid_metric: + is_best = cur_valid_metric < self.best_valid_metric + else: + is_best = cur_valid_metric > self.best_valid_metric + if is_best: + self.best_valid_metric = cur_valid_metric + self.save(is_best) + self.logger.info('-' * 150) + + return + + def infer(self, data_type='test'): + """ + Inference interface. + """ + self.logger.info('Generation starts ...') + infer_save_file = os.path.join(self.save_dir, + f'infer_{self.epoch}.result.json') + infer_samples_save_file = os.path.join( + self.save_dir, f'infer_samples_{self.epoch}.result.json') + + # Inference + result_collection = {} + begin_time = time.time() + + eval_data = self.reader.get_eval_data(data_type) + set_stats = self.reader.set_stats[data_type] + self.logger.info('***** Running Evaluation *****') + self.logger.info(' Num Turns = %d', set_stats['num_turns']) + + with torch.no_grad(): + pbar = tqdm(eval_data) + for dial_idx, dialog in enumerate(pbar): + pv_turn = {} + for turn_idx, turn in enumerate(dialog): + first_turn = (turn_idx == 0) + inputs, prompt_id = self.reader.convert_turn_eval( + turn, pv_turn, first_turn) + batch, batch_size = self.reader.collate_fn_multi_turn( + samples=[inputs]) + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + if self.reader.use_true_curr_bspn: # generate act, response + max_len = 60 + if not self.reader.use_true_curr_aspn: + max_len = 80 + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=max_len) + # resp_gen, need to trim previous context + generated = outputs[0].cpu().numpy().tolist() + try: + decoded = self.decode_generated_act_resp(generated) + except ValueError as exception: + self.logger.info(str(exception)) + self.logger.info(self.tokenizer.decode(generated)) + decoded = {'resp': [], 'bspn': [], 'aspn': []} + else: # predict bspn, access db, then generate act and resp + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_b_id, + max_gen_len=60) + generated_bs = outputs[0].cpu().numpy().tolist() + bspn_gen = self.decode_generated_bspn(generated_bs) + # check DB result + if self.reader.use_true_db_pointer: + db = turn['db'] + else: + db_result = self.reader.bspan_to_DBpointer( + self.tokenizer.decode(bspn_gen), + turn['turn_domain']) + assert len(turn['db']) == 3 + assert isinstance(db_result, str) + db = \ + [self.reader.sos_db_id] + \ + self.tokenizer.convert_tokens_to_ids([db_result]) + \ + [self.reader.eos_db_id] + prompt_id = self.reader.sos_a_id + + prev_input = torch.tensor(bspn_gen + db) + if self.func_model.use_gpu: + prev_input = prev_input.cuda() + outputs_db = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=80, + prev_input=prev_input) + generated_ar = outputs_db[0].cpu().numpy().tolist() + try: + decoded = self.decode_generated_act_resp( + generated_ar) + decoded['bspn'] = bspn_gen + except ValueError as exception: + self.logger.info(str(exception)) + self.logger.info( + self.tokenizer.decode(generated_ar)) + decoded = {'resp': [], 'bspn': [], 'aspn': []} + + turn['resp_gen'] = decoded['resp'] + turn['bspn_gen'] = turn[ + 'bspn'] if self.reader.use_true_curr_bspn else decoded[ + 'bspn'] + turn['aspn_gen'] = turn[ + 'aspn'] if self.reader.use_true_curr_aspn else decoded[ + 'aspn'] + turn['dspn_gen'] = turn['dspn'] + + pv_turn['labels'] = inputs[ + 'labels'] # all true previous context + pv_turn['resp'] = turn[ + 'resp'] if self.reader.use_true_prev_resp else decoded[ + 'resp'] + if not self.reader.use_true_curr_bspn: + pv_turn['bspn'] = turn[ + 'bspn'] if self.reader.use_true_prev_bspn else decoded[ + 'bspn'] + pv_turn['db'] = turn[ + 'db'] if self.reader.use_true_prev_bspn else db + pv_turn['aspn'] = turn[ + 'aspn'] if self.reader.use_true_prev_aspn else decoded[ + 'aspn'] + + tmp_dialog_result = self.reader.inverse_transpose_turn(dialog) + result_collection.update(tmp_dialog_result) + + # compute tmp scores + results, _ = self.reader.wrap_result_lm(tmp_dialog_result) + bleu, success, match = self.evaluator.validation_metric( + results) + score = 0.5 * (success + match) + bleu + pbar.set_description( + 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % + (match, success, bleu, score)) + + # compute scores + results, _ = self.reader.wrap_result_lm(result_collection) + bleu, success, match = self.evaluator.validation_metric(results) + score = 0.5 * (success + match) + bleu + + # log results + metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \ + (match, success, bleu, score) + message_prefix = f'[Infer][{self.epoch}]' + time_cost = f'TIME-{time.time() - begin_time:.3f}' + message = ' '.join([message_prefix, metrics_message, time_cost]) + self.logger.info(message) + + # save results + eval_results = { + 'bleu': bleu, + 'success': success, + 'match': match, + 'score': score, + 'result': message + } + with open(infer_save_file, 'w') as fp: + json.dump(eval_results, fp, indent=2) + self.logger.info(f'Saved inference results to {infer_save_file}') + with open(infer_samples_save_file, 'w') as fp: + for sample in results: + line = json.dumps(sample) + fp.write(line) + fp.write('\n') + self.logger.info( + f'Saved inference samples to {infer_samples_save_file}') + + return + + def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): + + def _get_slots(constraint): + domain_name = '' + slots = {} + for item in constraint: + if item in ontology.placeholder_tokens: + continue + if item in ontology.all_domains_with_bracket: + domain_name = item + slots[domain_name] = set() + else: + assert domain_name in ontology.all_domains_with_bracket + slots[domain_name].add(item) + return slots + + turn_domain = [] + if first_turn and len(bspn_gen_ids) == 0: + turn_domain = ['[general]'] + return turn_domain + + bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) + turn_slots = _get_slots(bspn_token) + if first_turn: + return list(turn_slots.keys()) + + assert 'bspn' in old_pv_turn + pv_bspn_token = self.tokenizer.convert_ids_to_tokens( + old_pv_turn['bspn'].cpu().numpy().tolist()) + pv_turn_slots = _get_slots(pv_bspn_token) + for domain, value in turn_slots.items(): + pv_value = pv_turn_slots[ + domain] if domain in pv_turn_slots else set() + if len(value - pv_value) > 0 or len(pv_value - value): + turn_domain.append(domain) + if len(turn_domain) == 0: + turn_domain = list(turn_slots.keys()) + + return turn_domain + + def forward(self, first_turn, batch, prompt_id, labels, old_pv_turn): + with torch.no_grad(): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) + pv_turn = {} + + outputs = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_b_id, + max_gen_len=60) + generated_bs = outputs[0].cpu().numpy().tolist() + bspn_gen = self.decode_generated_bspn(generated_bs) + + turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, + first_turn) + + db_result = self.reader.bspan_to_DBpointer( + self.tokenizer.decode(bspn_gen), turn_domain) + assert isinstance(db_result, str) + db = \ + [self.reader.sos_db_id] + \ + self.tokenizer.convert_tokens_to_ids([db_result]) + \ + [self.reader.eos_db_id] + prompt_id = self.reader.sos_a_id + prev_input = torch.tensor(bspn_gen + db) + if self.func_model.use_gpu: + prev_input = prev_input.cuda() + outputs_db = self.func_model.infer( + inputs=batch, + start_id=prompt_id, + eos_id=self.reader.eos_r_id, + max_gen_len=80, + prev_input=prev_input) + generated_ar = outputs_db[0].cpu().numpy().tolist() + decoded = self.decode_generated_act_resp(generated_ar) + decoded['bspn'] = bspn_gen + + pv_turn['labels'] = [ + label.cpu().numpy().tolist() for label in labels + ] + pv_turn['resp'] = decoded['resp'] + pv_turn['bspn'] = decoded['bspn'] + pv_turn['db'] = db + pv_turn['aspn'] = decoded['aspn'] + + return pv_turn diff --git a/modelscope/trainers/nlp/space/trainer/intent_trainer.py b/modelscope/trainers/nlp/space/trainer/intent_trainer.py new file mode 100644 index 00000000..dc6b317b --- /dev/null +++ b/modelscope/trainers/nlp/space/trainer/intent_trainer.py @@ -0,0 +1,705 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import time +from collections import OrderedDict + +import json +import numpy as np +import torch +from tqdm import tqdm +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ + MetricsTracker +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + + +class Trainer(object): + + def __init__(self, + model, + to_tensor, + config, + reader=None, + logger=None, + lr_scheduler=None, + optimizer=None): + self.model = model + self.to_tensor = to_tensor + self.do_train = config.do_train + self.do_infer = config.do_infer + + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary + self.learning_method = config.Dataset.learning_method + self.weight_decay = config.Model.weight_decay + self.warmup_steps = config.Model.warmup_steps + self.batch_size_label = config.Trainer.batch_size_label + self.batch_size_nolabel = config.Trainer.batch_size_nolabel + self.gpu = config.Trainer.gpu + self.lr = config.Model.lr + + self.model = model + self.func_model = self.model.module if self.gpu > 1 else self.model + self.reader = reader + self.tokenizer = reader.tokenizer + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + self.logger = logger or get_logger() + + self.batch_metrics_tracker_label = MetricsTracker() + self.token_metrics_tracker_label = MetricsTracker() + self.batch_metrics_tracker_nolabel = MetricsTracker() + self.token_metrics_tracker_nolabel = MetricsTracker() + + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + self.batch_num = 0 + + def set_optimizers(self, num_training_steps_per_epoch): + """ + Setup the optimizer and the learning rate scheduler. + + from transformers.Trainer + + parameters from cfg: lr (1e-3); warmup_steps + """ + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'norm.weight'] + optimizer_grouped_parameters = [ + { + 'params': [ + p for n, p in self.model.named_parameters() + if not any(nd in n for nd in no_decay) + ], + 'weight_decay': + self.weight_decay, + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if any(nd in n for nd in no_decay) + ], + 'weight_decay': + 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + + num_training_steps = num_training_steps_per_epoch * self.num_epochs + num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( + num_training_steps * 0.1) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps) + + # reset optimizer and lr_scheduler + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # log info + self.logger.info( + f'***** Running training: {self.learning_method} *****') + self.logger.info(' Num Epochs = %d', self.num_epochs) + self.logger.info( + ' Num Training steps(one turn in a batch of dialogs) per epoch = %d', + num_training_steps_per_epoch) + self.logger.info(' Batch size for labeled data = %d', + self.batch_size_label) + self.logger.info(' Batch size for unlabeled data = %d', + self.batch_size_nolabel) + self.logger.info(' Total optimization steps = %d', num_training_steps) + self.logger.info(' Total warmup steps = %d', num_warmup_steps) + self.logger.info('************************************') + + def train(self, + train_label_iter, + train_nolabel_iter=None, + valid_label_iter=None, + valid_nolabel_iter=None): + # begin training + num_epochs = self.num_epochs - self.epoch + for epoch in range(num_epochs): + self.train_epoch( + train_label_iter=train_label_iter, + train_nolabel_iter=train_nolabel_iter, + valid_label_iter=valid_label_iter, + valid_nolabel_iter=valid_nolabel_iter) + + def train_epoch(self, train_label_iter, train_nolabel_iter, + valid_label_iter, valid_nolabel_iter): + """ + Train an epoch. + """ + raise NotImplementedError + + def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True): + raise NotImplementedError + + def infer(self, data_iter, num_batches=None): + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'batch_num': self.batch_num, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join( + self.save_dir, + '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda storage, loc: storage) + self.epoch = train_state_dict['epoch'] + self.best_valid_metric = train_state_dict['best_valid_metric'] + if self.optimizer is not None and 'optimizer' in train_state_dict: + self.optimizer.load_state_dict( + train_state_dict['optimizer']) + if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: + self.lr_scheduler.load_state_dict( + train_state_dict['lr_scheduler']) + self.logger.info( + f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " + f'best_valid_metric={self.best_valid_metric:.3f})') + else: + self.logger.info('Loaded no train state') + + if self.func_model.init_checkpoint is None: + self.logger.info('Loaded no model !!!') + return + + if self.do_train: + _load_model_state() + return + + if self.do_infer: + _load_model_state() + _load_train_state() + + +class IntentTrainer(Trainer): + + def __init__(self, model, to_tensor, config, reader=None): + super(IntentTrainer, self).__init__(model, to_tensor, config, reader) + self.example = config.Model.example + self.can_norm = config.Trainer.can_norm + + def can_normalization(self, y_pred, y_true, ex_data_iter): + # compute ACC + acc_original = np.mean([y_pred.argmax(1) == y_true]) + message = 'original acc: %s' % acc_original + + # compute uncertainty + k = 3 + y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] + y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) + y_pred_uncertainty =\ + -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) + + # choose threshold + # print(np.sort(y_pred_uncertainty)[-100:].tolist()) + threshold = 0.7 + y_pred_confident = y_pred[y_pred_uncertainty < threshold] + y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold] + y_true_confident = y_true[y_pred_uncertainty < threshold] + y_true_unconfident = y_true[y_pred_uncertainty >= threshold] + + # compute ACC again for high and low confidence sets + acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ + if len(y_true_confident) else 0. + acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ + if len(y_true_unconfident) else 0. + message += ' (%s) confident acc: %s' % (len(y_true_confident), + acc_confident) + message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), + acc_unconfident) + + # get prior distribution from training set + prior = np.zeros(self.func_model.num_intent) + for _, (batch, batch_size) in ex_data_iter: + for intent_label in batch['intent_label']: + prior[intent_label] += 1. + + prior /= prior.sum() + + # revise each sample from the low confidence set, and compute new ACC + right, alpha, iters = 0, 1, 1 + for i, y in enumerate(y_pred_unconfident): + Y = np.concatenate([y_pred_confident, y[None]], axis=0) + for j in range(iters): + Y = Y**alpha + Y /= Y.mean(axis=0, keepdims=True) + Y *= prior[None] + Y /= Y.sum(axis=1, keepdims=True) + y = Y[-1] + if y.argmax() == y_true_unconfident[i]: + right += 1 + + # get final ACC + acc_final = \ + (acc_confident * len(y_pred_confident) + right) / \ + len(y_pred) + if len(y_pred_unconfident): + message += ' new unconfident acc: %s' % ( + right / len(y_pred_unconfident)) + else: + message += ' no unconfident predictions' + message += ' final acc: %s' % acc_final + return acc_original, acc_final, message + + def train_epoch(self, train_label_iter, train_nolabel_iter, + valid_label_iter, valid_nolabel_iter): + """ + Train an epoch. + """ + times = [] + self.epoch += 1 + self.batch_metrics_tracker_label.clear() + self.token_metrics_tracker_label.clear() + self.batch_metrics_tracker_nolabel.clear() + self.token_metrics_tracker_nolabel.clear() + + num_label_batches = len(train_label_iter) + num_nolabel_batches = len( + train_nolabel_iter) if train_nolabel_iter is not None else 0 + num_batches = max(num_label_batches, num_nolabel_batches) + + train_label_iter_loop = iter(train_label_iter) + train_nolabel_iter_loop = iter( + train_nolabel_iter) if train_nolabel_iter is not None else None + report_for_unlabeled_data = True if train_nolabel_iter is not None else False + + for batch_id in range(1, num_batches + 1): + # Do a training iteration + start_time = time.time() + batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], [] + data_file_list = [] + + # collect batch for labeled data + try: + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + except StopIteration: + train_label_iter_loop = iter(train_label_iter) + data_file_label, ( + batch_label, + batch_size_label) = next(train_label_iter_loop) + batch_list.append(batch_label) + batch_size_list.append(batch_size_label) + with_label_list.append(True) + data_file_list.append(data_file_label) + + # collect batch for unlabeled data + if train_nolabel_iter is not None: + try: + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + except StopIteration: + train_nolabel_iter_loop = iter(train_nolabel_iter) + data_file_nolabel, ( + batch_nolabel, + batch_size_nolabel) = next(train_nolabel_iter_loop) + batch_list.append(batch_nolabel) + batch_size_list.append(batch_size_nolabel) + with_label_list.append(False) + data_file_list.append(data_file_nolabel) + + # forward labeled batch and unlabeled batch and collect outputs, respectively + for (batch, batch_size, with_label, data_file) in \ + zip(batch_list, batch_size_list, with_label_list, data_file_list): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + if self.example and with_label: + current_dataset = train_label_iter.data_file_to_dataset[ + data_file] + example_batch = self.reader.retrieve_examples( + dataset=current_dataset, + labels=batch['intent_label'], + inds=batch['ids'], + task='intent') + example_batch = type(example_batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + example_batch.items())) + for k, v in example_batch.items(): + batch[k] = v + batch['epoch'] = self.epoch + batch['num_steps'] = self.batch_num + metrics = self.model( + batch, + is_training=True, + with_label=with_label, + data_file=data_file) + loss, metrics = self.balance_metrics( + metrics=metrics, batch_size=batch_size) + loss_list.append(loss) + metrics_list.append(metrics) + + # combine loss for labeled data and unlabeled data + # TODO change the computation of combined loss of labeled batch and unlabeled batch + loss = loss_list[0] if len( + loss_list) == 1 else loss_list[0] + loss_list[1] + + # optimization procedure + self.func_model._optimize( + loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler) + elapsed = time.time() - start_time + times.append(elapsed) + self.batch_num += 1 + + # track metrics and log temporary message + for (batch_size, metrics, + with_label) in zip(batch_size_list, metrics_list, + with_label_list): + self.track_and_log_message( + metrics=metrics, + batch_id=batch_id, + batch_size=batch_size, + num_batches=num_batches, + times=times, + with_label=with_label) + + # evaluate + if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \ + and batch_id % self.valid_steps == 0: + self.evaluate( + data_label_iter=valid_label_iter, + data_nolabel_iter=valid_nolabel_iter) + + # compute accuracy for valid dataset + accuracy = self.infer( + data_iter=valid_label_iter, ex_data_iter=train_label_iter) + + # report summary message and save checkpoints + self.save_and_log_message( + report_for_unlabeled_data, cur_valid_metric=-accuracy) + + def forward(self, batch): + pred = [] + + with torch.no_grad(): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + intent_probs = result['intent_probs'] + if self.can_norm: + pred += [intent_probs] + else: + pred += np.argmax(intent_probs, axis=1).tolist() + + return pred + + def infer(self, data_iter, num_batches=None, ex_data_iter=None): + """ + Inference interface. + """ + self.logger.info('Generation starts ...') + infer_save_file = os.path.join(self.save_dir, + f'infer_{self.epoch}.result.json') + + # Inference + batch_cnt = 0 + pred, true = [], [] + outputs, labels = [], [] + begin_time = time.time() + + with torch.no_grad(): + if self.example: + for _, (batch, batch_size) in tqdm( + ex_data_iter, desc='Building train memory.'): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + outputs.append(torch.from_numpy(result['features'])) + labels += batch['intent_label'].tolist() + + mem = torch.cat(outputs, dim=0) + mem = mem.cuda() if self.func_model.use_gpu else mem + labels = torch.LongTensor(labels).unsqueeze(0) + labels = labels.cuda() if self.func_model.use_gpu else labels + self.logger.info(f'Memory size: {mem.size()}') + + for _, (batch, batch_size) in tqdm(data_iter, total=num_batches): + batch = type(batch)( + map(lambda kv: (kv[0], self.to_tensor(kv[1])), + batch.items())) + result = self.model.infer(inputs=batch) + result = { + name: result[name].cpu().detach().numpy() + for name in result + } + + if self.example: + features = torch.from_numpy(result['features']) + features = features.cuda( + ) if self.func_model.use_gpu else features + probs = torch.softmax(features.mm(mem.t()), dim=-1) + intent_probs = torch.zeros( + probs.size(0), self.func_model.num_intent) + intent_probs = intent_probs.cuda( + ) if self.func_model.use_gpu else intent_probs + intent_probs = intent_probs.scatter_add( + -1, labels.repeat(probs.size(0), 1), probs) + intent_probs = intent_probs.cpu().detach().numpy() + else: + intent_probs = result['intent_probs'] + + if self.can_norm: + pred += [intent_probs] + true += batch['intent_label'].cpu().detach().tolist() + else: + pred += np.argmax(intent_probs, axis=1).tolist() + true += batch['intent_label'].cpu().detach().tolist() + + batch_cnt += 1 + if batch_cnt == num_batches: + break + + if self.can_norm: + true = np.array(true) + pred = np.concatenate(pred, axis=0) + acc_original, acc_final, message = self.can_normalization( + y_pred=pred, y_true=true, ex_data_iter=ex_data_iter) + accuracy = max(acc_original, acc_final) + infer_results = { + 'accuracy': accuracy, + 'pred_labels': pred.tolist(), + 'message': message + } + metrics_message = f'Accuracy: {accuracy} {message}' + else: + accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred) + infer_results = {'accuracy': accuracy, 'pred_labels': pred} + metrics_message = f'Accuracy: {accuracy}' + + self.logger.info(f'Saved inference results to {infer_save_file}') + with open(infer_save_file, 'w') as fp: + json.dump(infer_results, fp, indent=2) + message_prefix = f'[Infer][{self.epoch}]' + time_cost = f'TIME-{time.time() - begin_time:.3f}' + message = ' '.join([message_prefix, metrics_message, time_cost]) + self.logger.info(message) + return accuracy + + def track_and_log_message(self, metrics, batch_id, batch_size, num_batches, + times, with_label): + # track metrics + batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel + token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel + + metrics = { + k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v + for k, v in metrics.items() + } + mlm_num = metrics.pop('mlm_num', 0) + + batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k} + token_metrics = {k: v for k, v in metrics.items() if 'token' in k} + batch_metrics_tracker.update(batch_metrics, batch_size) + token_metrics_tracker.update(token_metrics, mlm_num) + + # log message + if self.log_steps > 0 and batch_id % self.log_steps == 0: + batch_metrics_message = batch_metrics_tracker.value() + token_metrics_message = token_metrics_tracker.value() + label_prefix = 'Labeled' if with_label else 'Unlabeled' + message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]' + avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' + message = ' '.join([ + message_prefix, batch_metrics_message, token_metrics_message, + avg_time + ]) + self.logger.info(message) + + def save_and_log_message(self, + report_for_unlabeled_data, + cur_valid_metric=None): + # report message + batch_metrics_message = self.batch_metrics_tracker_label.summary() + token_metrics_message = self.token_metrics_tracker_label.summary() + message_prefix = f'[Valid][{self.epoch}][Labeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + if report_for_unlabeled_data: + batch_metrics_message = self.batch_metrics_tracker_nolabel.summary( + ) + token_metrics_message = self.token_metrics_tracker_nolabel.summary( + ) + message_prefix = f'[Valid][{self.epoch}][Unlabeled]' + message = ' '.join( + [message_prefix, batch_metrics_message, token_metrics_message]) + self.logger.info(message) + + # save checkpoints + assert cur_valid_metric is not None + if self.is_decreased_valid_metric: + is_best = cur_valid_metric < self.best_valid_metric + else: + is_best = cur_valid_metric > self.best_valid_metric + if is_best: + self.best_valid_metric = cur_valid_metric + self.save(is_best) + + def balance_metrics(self, metrics, batch_size): + if self.gpu > 1: + for metric in metrics: + if metric is not None: + assert len(metric) == self.gpu + + intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics + metrics = {} + + intent_loss = torch.mean(intent_loss) + metrics['intent_loss'] = intent_loss + loss = intent_loss + + if mlm is not None: + mlm_num = torch.sum(mlm_num) + token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num + mlm = torch.mean(mlm) + metrics['mlm_num'] = mlm_num + metrics['token_mlm'] = token_mlm + metrics['mlm'] = mlm + loss = loss + (token_mlm if self.func_model.token_loss else + mlm) * self.func_model.mlm_ratio + + if kl is not None: + kl = torch.mean(kl) + metrics['kl'] = kl + loss = loss + kl * self.func_model.kl_ratio + + if con is not None: + con = torch.mean(con) + metrics['con'] = con + loss = loss + con + + metrics['loss'] = loss + + assert 'loss' in metrics + return metrics['loss'], metrics diff --git a/modelscope/trainers/nlp/text_generation_trainer.py b/modelscope/trainers/nlp/text_generation_trainer.py new file mode 100644 index 00000000..f02faf71 --- /dev/null +++ b/modelscope/trainers/nlp/text_generation_trainer.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from collections.abc import Mapping + +import torch + +from modelscope.metainfo import Trainers +from modelscope.trainers import NlpEpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.file_utils import func_receive_dict_inputs + + +@TRAINERS.register_module(module_name=Trainers.text_generation_trainer) +class TextGenerationTrainer(NlpEpochBasedTrainer): + + def _decode(self, tokens): + tokenizer = self.eval_preprocessor.tokenizer + return tokenizer.decode(tokens.tolist(), skip_special_tokens=True) + + def evaluation_step(self, data): + model = self.model.module if self._dist else self.model + model.eval() + + with torch.no_grad(): + if isinstance( + data, + Mapping) and not func_receive_dict_inputs(model.generate): + result = model.generate(**data) + else: + result = model.generate(data) + + result['preds'] = [self._decode(seq) for seq in result['sequences']] + data['tgts'] = [self._decode(seq) for seq in data['labels']] + assert len(result['preds']) == len(data['tgts']) + + return result diff --git a/modelscope/trainers/nlp/text_ranking_trainer.py b/modelscope/trainers/nlp/text_ranking_trainer.py new file mode 100644 index 00000000..610c36b5 --- /dev/null +++ b/modelscope/trainers/nlp/text_ranking_trainer.py @@ -0,0 +1,202 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.models.nlp import BertForTextRanking +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@dataclass +class GroupCollator(): + """ + Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] + and pass batch separately to the actual collator. + Abstract out data detail for the model. + """ + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(features[0], list): + features = sum(features, []) + keys = features[0].keys() + batch = {k: list() for k in keys} + for ele in features: + for k, v in ele.items(): + batch[k].append(v) + batch = {k: torch.cat(v, dim=0) for k, v in batch.items()} + return batch + + +@TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer) +class TextRankingTrainer(NlpEpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Preprocessor] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + **kwargs): + + if data_collator is None: + data_collator = GroupCollator() + + super().__init__( + model=model, + cfg_file=cfg_file, + cfg_modify_fn=cfg_modify_fn, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + preprocessor=preprocessor, + optimizers=optimizers, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model_revision=model_revision, + **kwargs) + + def compute_mrr(self, result, k=10): + mrr = 0 + for res in result.values(): + sorted_res = sorted(res, key=lambda x: x[0], reverse=True) + ar = 0 + for index, ele in enumerate(sorted_res[:k]): + if str(ele[1]) == '1': + ar = 1.0 / (index + 1) + break + mrr += ar + return mrr / len(result) + + def compute_ndcg(self, result, k=10): + ndcg = 0 + from sklearn import ndcg_score + for res in result.values(): + sorted_res = sorted(res, key=lambda x: [0], reverse=True) + labels = np.array([[ele[1] for ele in sorted_res]]) + scores = np.array([[ele[0] for ele in sorted_res]]) + ndcg += float(ndcg_score(labels, scores, k=k)) + ndcg = ndcg / len(result) + return ndcg + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. + + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + # get the raw online dataset + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, + **self.cfg.evaluation.get('dataloader', {}), + collate_fn=self.eval_data_collator) + # generate a standard dataloader + # generate a model + if checkpoint_path is not None: + model = BertForTextRanking.from_pretrained(checkpoint_path) + else: + model = self.model + + # copy from easynlp (start) + model.eval() + total_samples = 0 + + logits_list = list() + label_list = list() + qid_list = list() + + total_spent_time = 0.0 + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + model.to(device) + for _step, batch in enumerate(tqdm(self.eval_dataloader)): + try: + batch = { + key: + val.to(device) if isinstance(val, torch.Tensor) else val + for key, val in batch.items() + } + except RuntimeError: + batch = {key: val for key, val in batch.items()} + + infer_start_time = time.time() + with torch.no_grad(): + label_ids = batch.pop('labels').detach().cpu().numpy() + qids = batch.pop('qid').detach().cpu().numpy() + outputs = model(**batch) + infer_end_time = time.time() + total_spent_time += infer_end_time - infer_start_time + total_samples += self.eval_dataloader.batch_size + + def sigmoid(logits): + return np.exp(logits) / (1 + np.exp(logits)) + + logits = outputs['logits'].squeeze(-1).detach().cpu().numpy() + logits = sigmoid(logits).tolist() + + label_list.extend(label_ids) + logits_list.extend(logits) + qid_list.extend(qids) + + logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format( + total_spent_time, total_spent_time * 1000 / total_samples)) + + rank_result = {} + for qid, score, label in zip(qid_list, logits_list, label_list): + if qid not in rank_result: + rank_result[qid] = [] + rank_result[qid].append((score, label)) + + for qid in rank_result: + rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0]) + + eval_outputs = list() + for metric in self.metrics: + if metric.startswith('mrr'): + k = metric.split('@')[-1] + k = int(k) + mrr = self.compute_mrr(rank_result, k=k) + logger.info('{}: {}'.format(metric, mrr)) + eval_outputs.append((metric, mrr)) + elif metric.startswith('ndcg'): + k = metric.split('@')[-1] + k = int(k) + ndcg = self.compute_ndcg(rank_result, k=k) + logger.info('{}: {}'.format(metric, ndcg)) + eval_outputs.append(('ndcg', ndcg)) + else: + raise NotImplementedError('Metric %s not implemented' % metric) + + return dict(eval_outputs) diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py new file mode 100644 index 00000000..5ff6f62f --- /dev/null +++ b/modelscope/trainers/nlp_trainer.py @@ -0,0 +1,656 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.metrics.builder import build_metric +from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets import MsDataset +from modelscope.preprocessors import Preprocessor +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys, + ModelFile) +from modelscope.utils.hub import parse_label_mapping +from .base import TRAINERS +from .trainer import EpochBasedTrainer + + +@dataclass +class NlpTrainerArguments: + """The arguments for the nlp trainer. + + All the arguments listed here have None default values, which means follow the default value in the input + cfg dict. + """ + + work_dir: Optional[str] = field( + default=None, metadata={'help': 'The work dir(key: train.work_dir)'}) + + task: Optional[str] = field( + default=None, metadata={'help': 'The task type(key: task)'}) + + preprocessor_type: Optional[str] = field( + default=None, + metadata={'help': 'The preprocessor type(key: preprocessor.type)'}) + + train_first_sequence: str = field( + default=None, + metadata={ + 'help': + 'The key of first sentence for the training dataset(key:preprocessor.train.' + 'first_sequence/dataset.train.first_sequence)' + }) + + train_second_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of second sentence for the training dataset(key:preprocessor.train.' + 'second_sequence/dataset.train.second_sequence)' + }) + + train_label: str = field( + default=None, + metadata={ + 'help': + 'The key of label for the training dataset(key:preprocessor.train.' + 'second_sequence/dataset.train.second_sequence)' + }) + + eval_first_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of first sentence for the eval dataset(key:preprocessor.val.' + 'first_sequence/dataset.val.first_sequence), ' + 'if not provided, the trainer will use the train_first_sequence for evaluation' + }) + + eval_second_sequence: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of second sentence for the eval dataset(key:preprocessor.val.' + 'second_sequence/dataset.val.second_sequence),' + 'if not provided, the trainer will use the train_second_sequence for evaluation' + }) + + eval_label: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The key of label for the eval dataset(key:preprocessor.val.' + 'second_sequence/dataset.val.second_sequence),' + 'if not provided, the trainer will use the train_label for evaluation' + }) + + labels: Optional[List] = field( + default=None, + metadata={ + 'help': + 'The labels list of the dataset(key:dataset.train.labels),' + 'This parameter has the same effect with "label2id"' + }) + + max_epochs: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The max_epochs of the training loop(key: train.max_epochs)' + }) + + train_batch_size_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The train batch size per gpu(key: train.dataloader.batch_size_per_gpu)' + }) + + train_workers_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The number of workers per gpu(key: train.dataloader.workers_per_gpu)' + }) + + train_shuffle: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Shuffle the train dataset or not(key: train.dataloader.shuffle)' + }) + + eval_batch_size_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The eval batch size per gpu(key: evaluation.dataloader.batch_size_per_gpu)' + }) + + eval_workers_per_gpu: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The number of workers per gpu(key: evaluation.dataloader.workers_per_gpu)' + }) + + eval_shuffle: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Shuffle the eval dataset or not(key: evaluation.dataloader.shuffle)' + }) + + optimizer_args: Optional[Dict] = field( + default=None, + metadata={'help': 'The optimizer config dict(key: train.optimizer)'}) + + lr_scheduler_args: Optional[Dict] = field( + default=None, + metadata={ + 'help': 'The lr_scheduler config dict(key: train.lr_scheduler)' + }) + + checkpoint_saving_type: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The checkpoint saving type(key: The ckpt hook dict in train.hooks), ' + 'valid options: "BestCkptSaverHook", "CheckpointHook"' + }) + + checkpoint_by_epoch: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Saving checkpoint by epoch or not(key: The by_epoch key in ' + 'ckpt hook dict in train.hooks)' + }) + + checkpoint_interval: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The checkpoint saving interval(key: The interval key in ' + 'ckpt hook dict in train.hooks)' + }) + + metric_key: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The metric key for the BestCkptSaverHook(key: The metric_key key in ' + 'ckpt hook dict in train.hooks), if the checkpoint_saving_type is "CheckpointHook" or ' + '"None", the metric_key key has no effects' + }) + + evaluation_type: Optional[str] = field( + default=None, + metadata={ + 'help': + 'The evaluation type(key: The evaluation hook dict in train.hooks), ' + 'valid options: "EvaluationHook", "None"' + }) + + evaluation_by_epoch: Optional[bool] = field( + default=None, + metadata={ + 'help': + 'Evaluating by epoch or not(key: The by_epoch key in ' + 'evaluation hook dict in train.hooks)' + }) + + evaluation_interval: Optional[int] = field( + default=None, + metadata={ + 'help': + 'The evaluating interval(key: The interval key in ' + 'evaluation hook dict in train.hooks)' + }) + + metrics: Optional[List[str]] = field( + default=None, + metadata={'help': 'The metrics class keys(key: evaluation.metrics)'}) + + default_train_config = ConfigDict({ + 'work_dir': + '/tmp', + 'max_epochs': + 5, + 'dataloader': { + 'batch_size_per_gpu': 32, + 'workers_per_gpu': 0 + }, + 'optimizer': { + 'type': 'AdamW', + 'lr': 2e-5, + 'options': {} + }, + 'lr_scheduler': { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': 10000, + 'options': { + 'by_epoch': False + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'by_epoch': False, + 'interval': 100 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 100 + }] + }) + + def __call__(self, cfg): + """ + + Args: + cfg(`Config`): The cfg to be modified. + + Returns: + The cfg after modification. + """ + + if self.task is not None: + cfg.task = self.task + + if self.preprocessor_type is not None: + if not hasattr(cfg, 'preprocessor'): + cfg.preprocessor = ConfigDict() + cfg.preprocessor.type = self.preprocessor_type + + if self.train_first_sequence is not None or self.train_second_sequence \ + is not None or self.train_label is not None or self.labels is not None: + if not hasattr(cfg, 'dataset'): + cfg.dataset = ConfigDict() + if not hasattr(cfg.dataset, 'train'): + cfg.dataset.train = ConfigDict() + if self.train_first_sequence is not None: + cfg.dataset.train.first_sequence = self.train_first_sequence + if self.train_second_sequence is not None: + cfg.dataset.train.second_sequence = self.train_second_sequence + if self.train_label is not None: + cfg.dataset.train.label = self.train_label + if self.labels is not None: + cfg.dataset.train.labels = self.labels + + if self.eval_first_sequence is not None or self.eval_second_sequence \ + is not None or self.eval_label is not None: + if not hasattr(cfg, 'dataset'): + cfg.dataset = ConfigDict() + if not hasattr(cfg.dataset, 'val'): + cfg.dataset.val = ConfigDict() + if self.eval_first_sequence is not None: + cfg.dataset.val.first_sequence = self.eval_first_sequence + if self.eval_second_sequence is not None: + cfg.dataset.val.second_sequence = self.eval_second_sequence + if self.eval_label is not None: + cfg.dataset.val.label = self.eval_label + + if self.max_epochs is not None or self.train_batch_size_per_gpu is not None \ + or self.train_shuffle is not None or self.optimizer_args is not None \ + or self.work_dir is not None or self.lr_scheduler_args is not None\ + or self.train_workers_per_gpu is not None: + if not hasattr(cfg, 'train'): + cfg.train = deepcopy(self.default_train_config) + if not hasattr(cfg.train, 'dataloader'): + cfg.train.dataloader = deepcopy( + self.default_train_config.dataloader) + if not hasattr(cfg.train, 'optimizer'): + cfg.train.optimizer = deepcopy( + self.default_train_config.optimizer) + if not hasattr(cfg.train, 'lr_scheduler'): + cfg.train.lr_scheduler = deepcopy( + self.default_train_config.lr_scheduler) + if self.work_dir is not None: + cfg.train.work_dir = self.work_dir + if self.max_epochs is not None: + cfg.train.max_epochs = self.max_epochs + if self.train_batch_size_per_gpu is not None: + cfg.train.dataloader.batch_size_per_gpu = self.train_batch_size_per_gpu + if self.train_workers_per_gpu is not None: + cfg.train.dataloader.workers_per_gpu = self.train_workers_per_gpu + if self.train_shuffle is not None: + cfg.train.dataloader.shuffle = self.train_shuffle + if self.optimizer_args is not None: + if cfg.train.optimizer.type != self.optimizer_args.get( + 'type', cfg.train.optimizer.type): + cfg.train.optimizer = ConfigDict( + deepcopy(self.optimizer_args)) + else: + cfg.train.optimizer = Config._merge_a_into_b( + self.optimizer_args, cfg.train.optimizer, force=True) + if self.lr_scheduler_args is not None: + if cfg.train.lr_scheduler.type != self.lr_scheduler_args.get( + 'type', cfg.train.lr_scheduler.type): + cfg.train.lr_scheduler = ConfigDict( + deepcopy(self.lr_scheduler_args)) + else: + cfg.train.lr_scheduler = Config._merge_a_into_b( + self.lr_scheduler_args, + cfg.train.lr_scheduler, + force=True) + + if self.checkpoint_saving_type is not None or self.checkpoint_by_epoch is not None \ + or self.checkpoint_interval is not None or self.metric_key is not None: + if not any([ + self.checkpoint_saving_type == hook['type'] + for hook in cfg.train.hooks + ]): + cfg.train.hooks = list( + filter( + lambda hook: hook['type'] not in + ['CheckpointHook', 'BestCkptSaverHook'], + cfg.train.hooks)) + cfg.train.hooks.append( + deepcopy(self.default_train_config.hooks[0])) + cfg.train.hooks[-1].type = self.checkpoint_saving_type + checkpoint_hook = list( + filter( + lambda hook: hook[ + 'type'] in ['CheckpointHook', 'BestCkptSaverHook'], + cfg.train.hooks))[0] + if self.checkpoint_by_epoch is not None: + checkpoint_hook['by_epoch'] = self.checkpoint_by_epoch + if self.checkpoint_interval is not None: + checkpoint_hook['interval'] = self.checkpoint_interval + if checkpoint_hook['type'] == 'BestCkptSaverHook': + assert self.metric_key is not None, 'The metric_key must be provided ' \ + 'if the ckpt saving hook is "BestCkptSaverHook"' + checkpoint_hook['metric_key'] = self.metric_key + + if self.evaluation_type is not None or self.evaluation_by_epoch is not None \ + or self.evaluation_interval is not None or self.eval_batch_size_per_gpu is not None or \ + self.eval_shuffle is not None or self.metrics is not None: + if self.evaluation_type is not None and not any([ + self.evaluation_type == hook['type'] + for hook in cfg.train.hooks + ]): + cfg.train.hooks = list( + filter(lambda hook: hook['type'] not in ['EvaluationHook'], + cfg.train.hooks)) + if self.evaluation_type != 'None': + cfg.train.hooks.append( + deepcopy(self.default_train_config.hooks[3])) + cfg.train.hooks[-1].type = self.evaluation_type + + evaluation_hook = list( + filter(lambda hook: hook['type'] in ['EvaluationHook'], + cfg.train.hooks)) + evaluation_hook = evaluation_hook[0] if len( + evaluation_hook) > 0 else None + + if evaluation_hook is not None and self.evaluation_by_epoch is not None: + evaluation_hook['by_epoch'] = self.evaluation_by_epoch + if evaluation_hook is not None and self.evaluation_interval is not None: + evaluation_hook['interval'] = self.evaluation_interval + + if not hasattr(cfg, 'evaluation'): + cfg.evaluation = ConfigDict({ + 'dataloader': { + 'batch_size_per_gpu': 32, + 'workers_per_gpu': 0, + 'shuffle': False + } + }) + + if self.metrics is not None: + cfg.evaluation.metrics = self.metrics + if self.eval_batch_size_per_gpu is not None: + cfg.evaluation.dataloader.batch_size_per_gpu = self.eval_batch_size_per_gpu + if self.eval_workers_per_gpu is not None: + cfg.evaluation.dataloader.workers_per_gpu = self.eval_workers_per_gpu + if self.eval_shuffle is not None: + cfg.evaluation.dataloader.shuffle = self.eval_shuffle + + return cfg + + +@TRAINERS.register_module(module_name=Trainers.nlp_base_trainer) +class NlpEpochBasedTrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Preprocessor] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + **kwargs): + """Add code to adapt with nlp models. + + This trainer will accept the information of labels&text keys in the cfg, and then initialize + the nlp models/preprocessors with this information. + + Labels&text key information may be carried in the cfg like this: + + >>> cfg = { + >>> ... + >>> "dataset": { + >>> "train": { + >>> "first_sequence": "text1", + >>> "second_sequence": "text2", + >>> "label": "label", + >>> "labels": [1, 2, 3, 4] + >>> } + >>> } + >>> } + + + Args: + cfg_modify_fn: An input fn which is used to modify the cfg read out of the file. + + Example: + >>> def cfg_modify_fn(cfg): + >>> cfg.preprocessor.first_sequence= 'text1' + >>> cfg.preprocessor.second_sequence='text2' + >>> return cfg + + To view some actual finetune examples, please check the test files listed below: + tests/trainers/test_finetune_sequence_classification.py + tests/trainers/test_finetune_token_classification.py + """ + + if isinstance(model, str): + if os.path.exists(model): + model_dir = model if os.path.isdir(model) else os.path.dirname( + model) + else: + model_dir = snapshot_download(model, revision=model_revision) + if cfg_file is None: + cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' + model_dir = os.path.dirname(cfg_file) + + self.label2id = None + self.id2label = None + self.num_labels = None + self.cfg_modify_fn = cfg_modify_fn + self.cfg = self.rebuild_config(Config.from_file(cfg_file)) + + try: + labels = self.cfg.dataset.train.labels + self.label2id = {label: idx for idx, label in enumerate(labels)} + self.id2label = {idx: label for idx, label in enumerate(labels)} + self.num_labels = len(labels) + except AttributeError: + label2id = parse_label_mapping(model_dir) + if label2id is not None: + self.label2id = label2id + self.id2label = {id: label for label, id in label2id.items()} + self.num_labels = len(label2id) + + def build_dataset_keys(cfg): + if cfg is not None: + input_keys = { + 'first_sequence': getattr(cfg, 'first_sequence', None), + 'second_sequence': getattr(cfg, 'second_sequence', None), + 'label': getattr(cfg, 'label', None), + } + else: + input_keys = {} + + return {k: v for k, v in input_keys.items() if v is not None} + + self.train_keys = build_dataset_keys( + self.cfg.dataset.train if hasattr(self.cfg, 'dataset') + and hasattr(self.cfg.dataset, 'train') else None) + self.eval_keys = build_dataset_keys( + self.cfg.dataset.val if hasattr(self.cfg, 'dataset') + and hasattr(self.cfg.dataset, 'val') else None) + if len(self.eval_keys) == 0: + self.eval_keys = self.train_keys + + super().__init__( + model=model_dir, + cfg_file=cfg_file, + arg_parse_fn=arg_parse_fn, + data_collator=data_collator, + preprocessor=preprocessor, + optimizers=optimizers, + model_revision=model_revision, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + **kwargs) + + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + cfg = self.cfg_modify_fn(cfg) + if not hasattr(cfg.model, 'label2id') and not hasattr( + cfg.model, 'id2label'): + if self.id2label is not None: + cfg.model['id2label'] = self.id2label + if self.label2id is not None: + cfg.model['label2id'] = self.label2id + return cfg + + def build_model(self) -> Union[nn.Module, TorchModel]: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. + + """ + model_args = {} if self.num_labels is None else { + 'num_labels': self.num_labels + } + model = Model.from_pretrained( + self.model_dir, cfg_dict=self.cfg, **model_args) + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]: + """Build the preprocessor. + + User can override this method to implement custom logits. + + Returns: The preprocessor instance. + + """ + model_args = {} if self.label2id is None else { + 'label2id': self.label2id + } + + train_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.TRAIN, + **model_args, + **self.train_keys, + mode=ModeKeys.TRAIN, + use_fast=True) + eval_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.EVAL, + **model_args, + **self.eval_keys, + mode=ModeKeys.EVAL, + use_fast=True) + return train_preprocessor, eval_preprocessor + + +@TRAINERS.register_module(module_name=Trainers.nlp_veco_trainer) +class VecoTrainer(NlpEpochBasedTrainer): + + def evaluate(self, checkpoint_path=None): + """Veco evaluates the datasets one by one. + + """ + from modelscope.msdatasets.task_datasets import VecoDataset + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + CheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + metric_values = {} + + if self.eval_dataset is None: + val_data = self.cfg.dataset.val + self.eval_dataset = self.build_dataset( + val_data, mode=ModeKeys.EVAL) + + idx = 0 + dataset_cnt = 1 + if isinstance(self.eval_dataset, VecoDataset): + self.eval_dataset.switch_dataset(idx) + dataset_cnt = len(self.eval_dataset.datasets) + + while True: + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, **self.cfg.evaluation.get('dataloader', {})) + self.data_loader = self.eval_dataloader + + metric_classes = [build_metric(metric) for metric in self.metrics] + for m in metric_classes: + m.trainer = self + self.evaluation_loop(self.eval_dataloader, metric_classes) + + for m_idx, metric_cls in enumerate(metric_classes): + if f'eval_dataset[{idx}]' not in metric_values: + metric_values[f'eval_dataset[{idx}]'] = {} + metric_values[f'eval_dataset[{idx}]'][ + self.metrics[m_idx]] = metric_cls.evaluate() + + idx += 1 + if idx < dataset_cnt: + self.eval_dataset.switch_dataset(idx) + else: + break + + for metric_name in self.metrics: + all_metrics = [m[metric_name] for m in metric_values.values()] + for key in all_metrics[0].keys(): + metric_values[key] = np.average( + [metric[key] for metric in all_metrics]) + + return metric_values diff --git a/modelscope/trainers/optimizer/__init__.py b/modelscope/trainers/optimizer/__init__.py new file mode 100644 index 00000000..9962c2c2 --- /dev/null +++ b/modelscope/trainers/optimizer/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .builder import OPTIMIZERS, build_optimizer +from .child_tuning_adamw_optimizer import ChildTuningAdamW + +__all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW'] diff --git a/modelscope/trainers/optimizer/builder.py b/modelscope/trainers/optimizer/builder.py new file mode 100644 index 00000000..f43768d6 --- /dev/null +++ b/modelscope/trainers/optimizer/builder.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import inspect + +import torch + +from modelscope.utils.config import ConfigDict +from modelscope.utils.registry import Registry, build_from_cfg, default_group + +OPTIMIZERS = Registry('optimizer') + + +def build_optimizer(model: torch.nn.Module, + cfg: ConfigDict, + default_args: dict = None): + """ build optimizer from optimizer config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for optimizer object. + default_args (dict, optional): Default initialization arguments. + """ + if hasattr(model, 'module'): + model = model.module + + if default_args is None: + default_args = {} + default_args['params'] = model.parameters() + + return build_from_cfg( + cfg, OPTIMIZERS, group_key=default_group, default_args=default_args) + + +def register_torch_optimizers(): + for name, module in inspect.getmembers(torch.optim): + if name.startswith('__'): + continue + if inspect.isclass(module) and issubclass(module, + torch.optim.Optimizer): + OPTIMIZERS.register_module( + default_group, module_name=name, module_cls=module) + + +register_torch_optimizers() diff --git a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py b/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py new file mode 100644 index 00000000..d004071f --- /dev/null +++ b/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py @@ -0,0 +1,188 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import types +from typing import Callable, Iterable, Tuple + +import numpy as np +import torch +from torch.distributions.bernoulli import Bernoulli +from torch.optim import Optimizer + +from modelscope.utils.logger import get_logger +from .builder import OPTIMIZERS, default_group + +logger = get_logger(__name__) + +__all__ = ['calculate_fisher', 'ChildTuningAdamW'] + + +def calculate_fisher(model: torch.nn.Module, + data_loader, + forward_step, + reserve_p, + grad_clip=None): + + gradient_mask = dict() + model.train() + for name, params in model.named_parameters(): + if 'layer' in name: + gradient_mask[params] = params.new_zeros(params.size()) + + iters = len(data_loader) + for inputs in data_loader: + loss = forward_step(model, inputs) + loss.backward() + for name, params in model.named_parameters(): + if 'layer' in name: + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_(params, **grad_clip) + gradient_mask[params] += (params.grad**2) / iters + model.zero_grad() + + logger.info('Calculate Fisher Information...') + + # Numpy + r = None + for k, v in gradient_mask.items(): + v = v.view(-1).cpu().numpy() + if r is None: + r = v + else: + r = np.append(r, v) + polar = np.percentile(r, (1 - reserve_p) * 100) + for k in gradient_mask: + gradient_mask[k] = gradient_mask[k] >= polar + print('Polar => {}'.format(polar)) + + # TODO: pytorch: torch.kthvalue + + return gradient_mask + + +@OPTIMIZERS.register_module( + group_key=default_group, module_name='ChildTuningAdamW') +class ChildTuningAdamW(Optimizer): + + def __init__(self, + params: Iterable[torch.nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + reserve_p=1.0, + mode=None): + if lr < 0.0: + raise ValueError( + 'Invalid learning rate: {} - should be >= 0.0'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + 'Invalid beta parameter: {} - should be in [0.0, 1.0['.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + 'Invalid beta parameter: {} - should be in [0.0, 1.0['.format( + betas[1])) + if not 0.0 <= eps: + raise ValueError( + 'Invalid epsilon value: {} - should be >= 0.0'.format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias) + super().__init__(params, defaults) + + self.gradient_mask = None + self.reserve_p = reserve_p + self.mode = mode + + def set_gradient_mask(self, gradient_mask): + self.gradient_mask = gradient_mask + + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + Arguments: + closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse gradients, please consider SparseAdam instead' + ) + + # ChildTuning code + if self.mode is not None: + if self.mode == 'ChildTuning-D': + if p in self.gradient_mask: + grad *= self.gradient_mask[p] + else: + # ChildTuning-F + grad_mask = Bernoulli( + grad.new_full( + size=grad.size(), fill_value=self.reserve_p)) + grad *= grad_mask.sample() / self.reserve_p + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1**state['step'] + bias_correction2 = 1.0 - beta2**state['step'] + step_size = step_size * math.sqrt( + bias_correction2) / bias_correction1 + + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay']) + + return loss diff --git a/modelscope/trainers/parallel/__init__.py b/modelscope/trainers/parallel/__init__.py new file mode 100644 index 00000000..3d71a75b --- /dev/null +++ b/modelscope/trainers/parallel/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .builder import PARALLEL diff --git a/modelscope/trainers/parallel/builder.py b/modelscope/trainers/parallel/builder.py new file mode 100644 index 00000000..56e05a2b --- /dev/null +++ b/modelscope/trainers/parallel/builder.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from torch.nn.parallel.distributed import DistributedDataParallel + +from modelscope.utils.config import ConfigDict +from modelscope.utils.registry import Registry, build_from_cfg + +PARALLEL = Registry('parallel') +PARALLEL.register_module( + module_name='DistributedDataParallel', module_cls=DistributedDataParallel) + + +def build_parallel(cfg: ConfigDict, default_args: dict = None): + """ build parallel + + Args: + cfg (:obj:`ConfigDict`): config dict for parallel object. + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg(cfg, PARALLEL, default_args=default_args) diff --git a/modelscope/trainers/parallel/utils.py b/modelscope/trainers/parallel/utils.py new file mode 100644 index 00000000..a80b43b7 --- /dev/null +++ b/modelscope/trainers/parallel/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .builder import PARALLEL + + +def is_parallel(module): + """Check if a module is wrapped by parallel object. + + The following modules are regarded as parallel object: + - torch.nn.parallel.DataParallel + - torch.nn.parallel.distributed.DistributedDataParallel + You may add you own parallel object by registering it to `modelscope.parallel.PARALLEL`. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: True if the is wrapped by parallel object. + """ + module_wrappers = [] + for group, module_dict in PARALLEL.modules.items(): + module_wrappers.extend(list(module_dict.values())) + + return isinstance(module, tuple(module_wrappers)) diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py new file mode 100644 index 00000000..12c25f30 --- /dev/null +++ b/modelscope/trainers/trainer.py @@ -0,0 +1,1007 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import time +from collections.abc import Mapping +from distutils.version import LooseVersion +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import json +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.dataloader import default_collate +from torch.utils.data.distributed import DistributedSampler + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.hub.utils.utils import create_library_statistics +from modelscope.metainfo import Trainers +from modelscope.metrics import build_metric, task_default_metrics +from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.msdatasets.task_datasets.builder import build_task_dataset +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.outputs import ModelOutputBase +from modelscope.preprocessors.base import Preprocessor +from modelscope.trainers.hooks.builder import HOOKS +from modelscope.trainers.hooks.priority import Priority, get_priority +from modelscope.trainers.lrscheduler.builder import build_lr_scheduler +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ConfigKeys, ModeKeys, ModelFile, + TrainerStages) +from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device +from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.logger import get_logger +from modelscope.utils.registry import build_from_cfg +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist, set_random_seed) +from .base import BaseTrainer +from .builder import TRAINERS +from .default_config import merge_cfg +from .hooks.hook import Hook +from .parallel.builder import build_parallel +from .parallel.utils import is_parallel + + +@TRAINERS.register_module(module_name=Trainers.default) +class EpochBasedTrainer(BaseTrainer): + """Epoch based Trainer, a training helper for PyTorch. + + Args: + cfg_file(str): The local config file. + model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir + or a model id. If model is None, build_model method will be called. + data_collator (`Callable`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. + train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): + The dataset to use for training. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation. + preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor. + NOTE: If the preprocessor has been called before the dataset fed into this trainer by user's custom code, + this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file. + Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and + this preprocessing action will be executed every time the dataset's __getitem__ is called. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple + containing the optimizer and the scheduler to use. + seed (int): The optional random seed for torch, cuda, numpy and random. + max_epochs: (int, optional): Total training epochs. + """ + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + **kwargs): + + self._seed = seed + set_random_seed(self._seed) + if isinstance(model, str): + if os.path.exists(model): + self.model_dir = model if os.path.isdir( + model) else os.path.dirname(model) + else: + self.model_dir = snapshot_download( + model, revision=model_revision) + if cfg_file is None: + cfg_file = os.path.join(self.model_dir, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' + self.model_dir = os.path.dirname(cfg_file) + + super().__init__(cfg_file, arg_parse_fn) + + # add default config + merge_cfg(self.cfg) + self.cfg = self.rebuild_config(self.cfg) + + if 'cfg_options' in kwargs: + self.cfg.merge_from_dict(kwargs['cfg_options']) + + if isinstance(model, (TorchModel, nn.Module)): + self.model = model + else: + self.model = self.build_model() + + if 'work_dir' in kwargs: + self.work_dir = kwargs['work_dir'] + else: + self.work_dir = self.cfg.train.get('work_dir', './work_dir') + + self.train_preprocessor, self.eval_preprocessor = None, None + if isinstance(preprocessor, Preprocessor): + self.train_preprocessor = preprocessor + self.eval_preprocessor = preprocessor + elif isinstance(preprocessor, Mapping): + if not (ConfigKeys.train in preprocessor + or ConfigKeys.val in preprocessor): + raise ValueError( + f'Preprocessor must split with `{ConfigKeys.train}` and `{ConfigKeys.val}` keys!' + ) + if ConfigKeys.train in preprocessor: + assert isinstance(preprocessor[ConfigKeys.train], Preprocessor) + self.train_preprocessor = preprocessor[ConfigKeys.train] + if ConfigKeys.val in preprocessor: + assert isinstance(preprocessor[ConfigKeys.val], Preprocessor) + self.eval_preprocessor = preprocessor[ConfigKeys.val] + elif hasattr(self.cfg, ConfigFields.preprocessor + ) and self.cfg.preprocessor is not None: + self.train_preprocessor, self.eval_preprocessor = self.build_preprocessor( + ) + + if self.train_preprocessor is not None: + self.train_preprocessor.mode = ModeKeys.TRAIN + if self.eval_preprocessor is not None: + self.eval_preprocessor.mode = ModeKeys.EVAL + + if kwargs.get('launcher', None) is not None: + init_dist(kwargs['launcher']) + + _, world_size = get_dist_info() + self._dist = world_size > 1 + + device_name = kwargs.get('device', 'gpu') + if self._dist: + local_rank = get_local_rank() + device_name = f'cuda:{local_rank}' + + self.device = create_device(device_name) + self.train_dataset = self.to_task_dataset( + train_dataset, + mode=ModeKeys.TRAIN, + task_data_config=self.cfg.dataset.get('train', None) if hasattr( + self.cfg, 'dataset') else None, + preprocessor=self.train_preprocessor, + **kwargs) + self.eval_dataset = self.to_task_dataset( + eval_dataset, + mode=ModeKeys.EVAL, + task_data_config=self.cfg.dataset.get('val', None) if hasattr( + self.cfg, 'dataset') else None, + preprocessor=self.eval_preprocessor, + **kwargs) + + self.train_data_collator, self.eval_data_collator = None, None + if isinstance(data_collator, Mapping): + if not (ConfigKeys.train in data_collator + or ConfigKeys.val in data_collator): + raise ValueError( + f'data_collator must split with `{ConfigKeys.train}` and `{ConfigKeys.val}` keys!' + ) + if ConfigKeys.train in data_collator: + assert isinstance(data_collator[ConfigKeys.train], Callable) + self.train_data_collator = data_collator[ConfigKeys.train] + if ConfigKeys.val in data_collator: + assert isinstance(data_collator[ConfigKeys.val], Callable) + self.eval_data_collator = data_collator[ConfigKeys.val] + else: + collate_fn = default_collate if data_collator is None else data_collator + self.train_data_collator = collate_fn + self.eval_data_collator = collate_fn + + self.metrics = self.get_metrics() + self._metric_values = None + self.optimizers = optimizers + self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO')) + self._mode = ModeKeys.TRAIN + self._hooks: List[Hook] = [] + self._epoch = 0 + self._iter = 0 + self._inner_iter = 0 + if 'max_epochs' not in kwargs: + assert hasattr( + self.cfg.train, + 'max_epochs'), 'max_epochs is missing in configuration file' + self._max_epochs = self.cfg.train.max_epochs + else: + self._max_epochs = kwargs['max_epochs'] + self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) + self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) + if self._train_iters_per_epoch is None and hasattr( + self.cfg.train, 'train_iters_per_epoch'): + self._train_iters_per_epoch = self.cfg.train.train_iters_per_epoch + if self._eval_iters_per_epoch is None and hasattr( + self.cfg, 'evaluation') and hasattr(self.cfg.evaluation, + 'val_iters_per_epoch'): + self._eval_iters_per_epoch = self.cfg.evaluation.val_iters_per_epoch + + self.use_fp16 = kwargs.get('use_fp16', False) + + # model placement + if self.device.type == 'cuda': + self.model.to(self.device) + if not is_parallel(self.model) and self._dist: + self.model = self.to_parallel(self.model) + + def rebuild_config(self, cfg: Config): + """A method used to rebuild the config, any subclass can override this method. + + Returns: The rebuilt config + + """ + return cfg + + @property + def mode(self): + return self._mode + + @property + def hooks(self) -> List[Hook]: + """list[:obj:`Hook`]: A list of registered hooks.""" + return self._hooks + + @property + def epoch(self) -> int: + """int: Current epoch.""" + return self._epoch + + @property + def iter(self) -> int: + """int: Current iteration.""" + return self._iter + + @property + def inner_iter(self) -> int: + """int: Iteration in an epoch.""" + return self._inner_iter + + @property + def max_epochs(self): + """int: Maximum training epochs.""" + return self._max_epochs + + @property + def max_iters(self): + """int: Maximum training iterations.""" + return self._max_epochs * self.iters_per_epoch + + @property + def iters_per_epoch(self): + """int: Total iterations of one epoch""" + + def _get_data_len(data_loader): + try: + return len(data_loader) + except Exception as e: + self.logger.error(e) + raise ValueError( + 'Please implement ``__len__`` method for your dataset, ' + 'or add `train_iters_per_epoch` and `train_iters_per_epoch` ' + 'to your configuration file or kwargs') + + if self.mode == ModeKeys.TRAIN: + if self._train_iters_per_epoch is not None: + return self._train_iters_per_epoch + else: + return _get_data_len(self.train_dataloader) + elif self.mode == ModeKeys.EVAL: + if self._eval_iters_per_epoch is not None: + return self._eval_iters_per_epoch + else: + return _get_data_len(self.eval_dataloader) + + def to_task_dataset(self, + datasets: Union[Dataset, List[Dataset]], + mode: str, + task_data_config: Config = None, + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Build the task specific dataset processor for this trainer. + + Returns: The task dataset processor for the task. If no result for the very model-type and task, + the default TaskDataset will be returned. + """ + try: + to_tensor = kwargs.get('to_tensor', True) + if not datasets: + return datasets + if isinstance(datasets, TorchTaskDataset): + return datasets + elif isinstance(datasets, MsDataset): + if task_data_config is None: + # adapt to some special models + task_data_config = ConfigDict( + type=self.cfg.model.type) if hasattr( + self.cfg, ConfigFields.model) else ConfigDict( + type=None) + task_data_config.update(dict(mode=mode)) + return datasets.to_torch_dataset( + task_data_config=task_data_config, + task_name=self.cfg.task, + preprocessors=preprocessor, + to_tensor=to_tensor) + elif isinstance(datasets, List) and isinstance( + datasets[0], MsDataset): + if task_data_config is None: + # adapt to some special models + task_data_config = ConfigDict( + type=self.cfg.model.type) if hasattr( + self.cfg, ConfigFields.model) else ConfigDict( + type=None) + task_data_config.update(dict(mode=mode)) + datasets = [ + d.to_torch_dataset( + task_data_config=task_data_config, + task_name=self.cfg.task, + preprocessors=preprocessor, + to_tensor=to_tensor) for d in datasets + ] + cfg = ConfigDict( + type=self.cfg.model.type, mode=mode, datasets=datasets) + task_dataset = build_task_dataset(cfg, self.cfg.task) + task_dataset.trainer = self + return task_dataset + else: + if task_data_config is None: + # adapt to some special models + task_data_config = {} + # avoid add no str value datasets, preprocessors in cfg + task_data_build_config = ConfigDict( + type=self.cfg.model.type, + mode=mode, + datasets=datasets, + preprocessor=preprocessor) + task_data_build_config.update(task_data_config) + task_dataset = build_task_dataset(task_data_build_config, + self.cfg.task) + task_dataset.trainer = self + return task_dataset + except Exception: + if isinstance(datasets, (List, Tuple)) or preprocessor is not None: + task_dataset = TorchTaskDataset( + datasets, + mode=mode, + preprocessor=preprocessor, + **(dict(type=self.cfg.model.type) if hasattr( + self.cfg, 'model') else {})) + task_dataset.trainer = self + return task_dataset + else: + return datasets + + def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]: + """Build train and eval preprocessor. + + User can override this method to implement custom logits. + + Returns: The train preprocessor and eval preprocessor instance. + + """ + train_preprocessor = Preprocessor.from_pretrained( + self.model_dir, + cfg_dict=self.cfg, + preprocessor_mode=ModeKeys.TRAIN) + eval_preprocessor = Preprocessor.from_pretrained( + self.model_dir, cfg_dict=self.cfg, preprocessor_mode=ModeKeys.EVAL) + return train_preprocessor, eval_preprocessor + + def get_metrics(self) -> List[Union[str, Dict]]: + """Get the metric class types. + + The first choice will be the metrics configured in the config file, if not found, the default metrics will be + used. + If no metrics is found and the eval dataset exists, the method will raise an error. + + Returns: The metric types. + + """ + metrics = self.cfg.evaluation.metrics if hasattr( + self.cfg, 'evaluation') and hasattr(self.cfg.evaluation, + 'metrics') else None + metrics = metrics if metrics is not None else task_default_metrics.get( + self.cfg.task) + if metrics is None and self.eval_dataset is not None: + raise ValueError( + f'Metrics are needed in evaluation, please try to either ' + f'add metrics in configuration.json or add the default metric for {self.cfg.task}.' + ) + if isinstance(metrics, (str, Mapping)): + metrics = [metrics] + return metrics + + def set_checkpoint_file_to_hook(self, checkpoint_path): + if checkpoint_path is not None: + if os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + checkpoint_hooks = list( + filter(lambda hook: isinstance(hook, CheckpointHook), + self.hooks)) + for hook in checkpoint_hooks: + hook.checkpoint_file = checkpoint_path + else: + self.logger.error( + f'No {checkpoint_path} found in local file system.') + + def train(self, checkpoint_path=None, *args, **kwargs): + self._mode = ModeKeys.TRAIN + if hasattr(self.model, 'name'): + create_library_statistics('train', self.model.name, None) + + if self.train_dataset is None: + self.train_dataloader = self.get_train_dataloader() + else: + self.train_dataloader = self._build_dataloader_with_dataset( + self.train_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.train_data_collator, + **self.cfg.train.get('dataloader', {})) + self.data_loader = self.train_dataloader + + self.register_optimizers_hook() + self.register_hook_from_cfg(self.cfg.train.hooks) + self.set_checkpoint_file_to_hook(checkpoint_path) + self.model.train() + + self.train_loop(self.train_dataloader) + + def evaluate(self, checkpoint_path=None): + if hasattr(self.model, 'name'): + create_library_statistics('evaluate', self.model.name, None) + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import CheckpointHook + CheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + if self.eval_dataset is None: + self.eval_dataloader = self.get_eval_data_loader() + else: + self.eval_dataloader = self._build_dataloader_with_dataset( + self.eval_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.eval_data_collator, + **self.cfg.evaluation.get('dataloader', {})) + self.data_loader = self.eval_dataloader + metric_classes = [build_metric(metric) for metric in self.metrics] + for m in metric_classes: + m.trainer = self + + metric_values = self.evaluation_loop(self.eval_dataloader, + metric_classes) + + self._metric_values = metric_values + return metric_values + + @property + def metric_values(self): + return self._metric_values + + def build_model(self) -> Union[nn.Module, TorchModel]: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. + + """ + model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg) + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def to_parallel(self, model) -> Union[nn.Module, TorchModel]: + # config format to reserve custom ddp + if self.cfg.get('parallel', None) is not None: + self.cfg.parallel.update( + dict(module=model, device_ids=[torch.cuda.current_device()])) + return build_parallel(self.cfg.parallel) + + dp_cfg = dict( + type='DistributedDataParallel', + module=model, + find_unused_parameters=True, + device_ids=[torch.cuda.current_device()]) + + return build_parallel(dp_cfg) + + def train_step(self, model, inputs): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + model.train() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + + if is_parallel(model): + receive_dict_inputs = func_receive_dict_inputs( + model.module.forward) + else: + receive_dict_inputs = func_receive_dict_inputs(model.forward) + + if isinstance(inputs, Mapping) and not receive_dict_inputs: + train_outputs = model.forward(**inputs) + else: + train_outputs = model.forward(inputs) + + if isinstance(train_outputs, ModelOutputBase): + train_outputs = train_outputs.to_dict() + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone().to('cuda') + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs + + def prediction_step(self, model, inputs): + """ Perform forward step by `model` using `inputs`. + + Args: + model (`TorchModel`): The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`Lst[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + raise NotImplementedError + + def get_train_dataloader(self): + """ Builder torch dataloader for training. + + We provide a reasonable default that works well. If you want to use something else, you can change + the config for data.train in configuration file, or subclass and override this method + (or `get_train_dataloader` in a subclass. + """ + if self.train_dataset is None: + train_data = self.cfg.dataset.train + self.train_dataset = self.build_dataset( + train_data, + mode=ModeKeys.TRAIN, + preprocessor=self.train_preprocessor) + + data_loader = self._build_dataloader_with_dataset( + self.train_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.train_data_collator, + **self.cfg.train.get('dataloader', {})) + return data_loader + + def get_eval_data_loader(self): + """ Builder torch dataloader for evaluation. + + We provide a reasonable default that works well. If you want to use something else, you can change + the config for dataset.eval in configuration file, or subclass and override this method in a subclass. + pass + """ + if self.eval_dataset is None: + val_data = self.cfg.dataset.val + self.eval_dataset = self.build_dataset( + val_data, + mode=ModeKeys.EVAL, + preprocessor=self.eval_preprocessor) + + batch_size = self.cfg.evaluation.dataloader.batch_size_per_gpu + workers = self.cfg.evaluation.dataloader.workers_per_gpu + shuffle = self.cfg.evaluation.dataloader.get('shuffle', False) + data_loader = self._build_dataloader_with_dataset( + self.eval_dataset, + batch_size_per_gpu=batch_size, + workers_per_gpu=workers, + shuffle=shuffle, + dist=self._dist, + seed=self._seed, + persistent_workers=True, + collate_fn=self.eval_data_collator, + ) + return data_loader + + def build_dataset(self, data_cfg, mode, preprocessor=None): + """ Build torch dataset object using data config + """ + # TODO: support MsDataset load for cv + if hasattr(data_cfg, 'name'): + dataset_name = data_cfg.pop('name') + dataset = MsDataset.load( + dataset_name=dataset_name, + **data_cfg, + ) + cfg = ConfigDict(type=self.cfg.model.type, mode=mode) + torch_dataset = dataset.to_torch_dataset( + task_data_config=cfg, + task_name=self.cfg.task, + preprocessors=preprocessor) + else: + torch_dataset = build_task_dataset(data_cfg, self.cfg.task) + dataset = self.to_task_dataset(torch_dataset, mode) + return dataset + + def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): + try: + return build_optimizer( + self.model, cfg=cfg, default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build optimizer error, the optimizer {cfg} is a torch native component, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e + + def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None): + try: + return build_lr_scheduler(cfg=cfg, default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e + + def create_optimizer_and_scheduler(self): + """ Create optimizer and lr scheduler + + We provide a default implementation, if you want to customize your own optimizer + and lr scheduler, you can either pass a tuple through trainer init function or + subclass this class and override this method. + """ + optimizer, lr_scheduler = self.optimizers + if optimizer is None: + optimizer_cfg = self.cfg.train.get('optimizer', None) + else: + optimizer_cfg = None + + optim_options = {} + if optimizer_cfg is not None: + optim_options = optimizer_cfg.pop('options', {}) + optimizer = self.build_optimizer(cfg=optimizer_cfg) + + if lr_scheduler is None: + lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) + else: + lr_scheduler_cfg = None + + lr_options = {} + if lr_scheduler_cfg is not None: + assert optimizer is not None + lr_options = lr_scheduler_cfg.pop('options', {}) + lr_scheduler = self.build_lr_scheduler( + cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer}) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + return self.optimizer, self.lr_scheduler, optim_options, lr_options + + def register_optimizers_hook(self): + """ Register optimizer hook and lr scheduler hook. + """ + _, lr_scheduler, optim_options, lr_options = self.create_optimizer_and_scheduler( + ) + + optim_hook = self.cfg.train.get('optimizer_hook', None) + lr_hook = self.cfg.train.get('lr_scheduler_hook', None) + + # adapt to `ReduceLROnPlateau` + from torch.optim.lr_scheduler import ReduceLROnPlateau + if isinstance(lr_scheduler, ReduceLROnPlateau) and lr_hook is None: + plateau_cfg = { + 'train': { + 'lr_scheduler_hook': { + 'type': 'PlateauLrSchedulerHook', + 'metric_key': + 'Metric Key used for PlateauLrSchedulerHook' + } + } + } + plateau_cfg = json.dumps( + plateau_cfg, sort_keys=False, indent=4, separators=(',', ':')) + raise ValueError( + 'Must add `lr_scheduler_hook` to configuration for `ReduceLROnPlateau` lr scheduler as follows:' + + '\n' + plateau_cfg) + + if lr_hook is None: + lr_hook = dict(type='LrSchedulerHook', **lr_options) + if optim_hook is None: + if self.use_fp16: + optim_hook = dict( + type='TorchAMPOptimizerHook', **optim_options) + else: + optim_hook = dict(type='OptimizerHook', **optim_options) + + self.register_hook_from_cfg([lr_hook, optim_hook]) + + def _build_dataloader_with_dataset(self, + dataset: Dataset, + batch_size_per_gpu: int, + workers_per_gpu: int, + dist: bool = False, + shuffle: bool = True, + seed: int = 0, + persistent_workers=False, + **kwargs) -> DataLoader: + """Build dataloader using input dataset and cfg. Used by `EpochBasedTrainer.train()` + and `EpochBasedTrainer.evaluate()`. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + batch_size_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int, Optional): Seed to be used. Default: 0. + runner_type (str): Type of runner. Default: `EpochBasedRunner` + persistent_workers (bool): If True, the data loader will not shutdown + the worker processes after a dataset has been consumed once. + This allows to maintain the workers `Dataset` instances alive. + This argument is only valid when PyTorch>=1.7.0. Default: False. + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + + if dist: + # When model is :obj:`DistributedDataParallel`, + # `batch_size` of :obj:`dataloader` is the + # number of training samples on each GPU. + batch_size = batch_size_per_gpu + num_workers = workers_per_gpu + else: + batch_size = batch_size_per_gpu + num_workers = workers_per_gpu + + if dist and not isinstance(dataset, torch.utils.data.IterableDataset): + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) + else: + sampler = None + + batch_sampler = None + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): + kwargs['persistent_workers'] = persistent_workers + elif persistent_workers is True: + self.logger.warning( + 'persistent_workers is invalid because your pytorch ' + 'version is lower than 1.7.0') + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + batch_sampler=batch_sampler, + pin_memory=kwargs.pop('pin_memory', False), + worker_init_fn=init_fn, + **kwargs) + + return data_loader + + def train_loop(self, data_loader): + """ Training loop used by `EpochBasedTrainer.train()` + """ + self.invoke_hook(TrainerStages.before_run) + kwargs = {} + self.model.train() + for _ in range(self._epoch, self._max_epochs): + self.invoke_hook(TrainerStages.before_train_epoch) + for i, data_batch in enumerate(data_loader): + if i < self.inner_iter: + # inner_iter may be read out from the checkpoint file, so skip the trained iters in the epoch. + continue + data_batch = to_device(data_batch, self.device) + self.data_batch = data_batch + self._inner_iter = i + self.invoke_hook(TrainerStages.before_train_iter) + self.train_step(self.model, data_batch, **kwargs) + self.invoke_hook(TrainerStages.after_train_iter) + # Value changed after the hooks are invoked, do not move them above the invoke_hook code. + del self.data_batch + self._iter += 1 + self._mode = ModeKeys.TRAIN + + if i + 1 >= self.iters_per_epoch: + break + + self.invoke_hook(TrainerStages.after_train_epoch) + # Value changed after the hooks are invoked, do not move them above the invoke_hook code. + self._inner_iter = 0 + self._epoch += 1 + + self.invoke_hook(TrainerStages.after_run) + + def evaluation_step(self, data): + """Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + """ + model = self.model.module if self._dist else self.model + model.eval() + + if is_parallel(model): + receive_dict_inputs = func_receive_dict_inputs( + model.module.forward) + else: + receive_dict_inputs = func_receive_dict_inputs(model.forward) + + with torch.no_grad(): + if isinstance(data, Mapping) and not receive_dict_inputs: + result = model.forward(**data) + else: + result = model.forward(data) + return result + + def evaluation_loop(self, data_loader, metric_classes): + """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. + + """ + if self._dist: + from modelscope.trainers.utils.inference import multi_gpu_test + metric_values = multi_gpu_test( + self, + data_loader, + device=self.device, + tmpdir=None, + gpu_collect=False, + metric_classes=metric_classes, + data_loader_iters_per_gpu=self._eval_iters_per_epoch) + else: + from modelscope.trainers.utils.inference import single_gpu_test + metric_values = single_gpu_test( + self, + data_loader, + device=self.device, + metric_classes=metric_classes, + data_loader_iters=self._eval_iters_per_epoch) + + self._inner_iter = self.iters_per_epoch - 1 # start from index 0 + + return metric_values + + def register_hook(self, hook: Hook) -> None: + """Register a hook into the hook list. + + The hook will be inserted into a priority queue, with the specified + priority (See :class:`Priority` for details of priorities). + For hooks with the same priority, they will be triggered in the same + order as they are registered. + + Args: + hook (:obj:`Hook`): The hook to be registered. + """ + # insert the hook to a sorted list + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + p = hook.PRIORITY if hasattr(hook, 'PRIORITY') else Priority.NORMAL + p_i = self._hooks[i].PRIORITY if hasattr( + self._hooks[i], 'PRIORITY') else Priority.NORMAL + + if get_priority(p) > get_priority(p_i): + self._hooks.insert(i + 1, hook) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook) + + def register_hook_from_cfg(self, hook_cfg: Dict) -> None: + """Register a hook from its cfg. + + Args: + hook_cfg (dict): Hook config. It should have at least keys 'type' + and 'priority' indicating its type and priority. + + Note: + The specific hook class to register should not use 'type' and + 'priority' arguments during initialization. + """ + hook_cfg = hook_cfg.copy() + assert isinstance(hook_cfg, list) + for cfg_i in hook_cfg: + hook = build_from_cfg(cfg_i, HOOKS) + self.register_hook(hook) + + def invoke_hook(self, fn_name: str) -> None: + """Call all hooks. + + Args: + fn_name (str): The function name in each hook to be called, such as + "before_train_epoch". + """ + for hook in self._hooks: + getattr(hook, fn_name)(self) + + def get_hook_info(self) -> str: + # Get hooks info in each stage + stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name # type: ignore + except ValueError: + priority = hook.priority # type: ignore + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + set_random_seed(worker_seed) diff --git a/modelscope/trainers/utils/__init__.py b/modelscope/trainers/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py new file mode 100644 index 00000000..6e4e7a19 --- /dev/null +++ b/modelscope/trainers/utils/inference.py @@ -0,0 +1,301 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +import logging +import os +import pickle +import shutil + +import torch +from torch import distributed as dist +from tqdm import tqdm + +from modelscope.utils.data_utils import to_device +from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master, + make_tmp_dir) + + +def single_gpu_test(trainer, + data_loader, + device, + metric_classes=None, + data_loader_iters=None): + """Test model in EpochBasedTrainer with a single gpu. + + Args: + trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + device (str | torch.device): The target device for the data. + metric_classes (List): List of Metric class that uses to collect metrics + data_loader_iters (int): Used when dataset has no attribute __len__ or only load part of dataset. + + Returns: + list: The prediction results. + """ + dataset = data_loader.dataset + progress_with_iters = False + if data_loader_iters is None: + try: + data_len = len(dataset) + except Exception as e: + logging.error(e) + raise ValueError( + 'Please implement ``__len__`` method for your dataset, or provide ``data_loader_iters``' + ) + desc = 'Total test samples' + else: + progress_with_iters = True + data_len = data_loader_iters + desc = 'Test iterations' + + with tqdm(total=data_len, desc=desc) as pbar: + for i, data in enumerate(data_loader): + data = to_device(data, device) + result = trainer.evaluation_step(data) + if metric_classes is not None: + for metric_cls in metric_classes: + metric_cls.add(result, data) + + if progress_with_iters: + batch_size = 1 # iteration count + else: + if isinstance(data, dict): + if 'nsentences' in data: + batch_size = data['nsentences'] + else: + try: + batch_size = len(next(iter(data.values()))) + except Exception: + batch_size = data_loader.batch_size + else: + batch_size = len(data) + for _ in range(batch_size): + pbar.update() + + if progress_with_iters and (i + 1) >= data_len: + break + + metric_values = {} + for metric_cls in metric_classes: + metric_values.update(metric_cls.evaluate()) + + return metric_values + + +def multi_gpu_test(trainer, + data_loader, + device, + tmpdir=None, + gpu_collect=False, + metric_classes=None, + data_loader_iters_per_gpu=None): + """Test model in EpochBasedTrainer with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting + ``gpu_collect=True``, it encodes results to gpu tensors and use gpu + communication for results collection. On cpu mode it saves the results on + different gpus to ``tmpdir`` and collects them by the rank 0 worker. + + Args: + trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + device: (str | torch.device): The target device for the data. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + metric_classes(List): List of Metric class that uses to collect metrics + data_loader_iters_per_gpu (int): Used when dataset has no attribute __len__ or only load part of dataset. + Returns: + list: The prediction results. + """ + results = [] + data_list = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + + progress_with_iters = False + if data_loader_iters_per_gpu is None: + try: + data_len = len(dataset) + total_samples = data_len + except Exception as e: + logging.error(e) + raise ValueError( + 'Please implement ``__len__`` method for your dataset, or provide ``data_loader_iters_per_gpu``' + ) + desc = 'Total test samples with multi gpus' + else: + total_samples = 0 + progress_with_iters = True + data_len = data_loader_iters_per_gpu * world_size + desc = 'Total test iterations with multi gpus' + + count = 0 + with tqdm(total=data_len, desc=desc) as pbar: + for i, data in enumerate(data_loader): + data = to_device(data, device) + data_list.append(data) + result = trainer.evaluation_step(data) + results.append(result) + + if isinstance(data, dict): + if 'nsentences' in data: + batch_size = data['nsentences'] + else: + batch_size = len(next(iter(data.values()))) + else: + batch_size = len(data) + if i >= (data_len // world_size) - 1: + total_samples = torch.LongTensor([batch_size]).to(model.device) + dist.all_reduce(total_samples, op=dist.reduce_op.SUM) + total_samples = total_samples.item() + else: + total_samples = batch_size * world_size + if progress_with_iters: + iter_cnt_all = world_size + else: + iter_cnt_all = total_samples + count += iter_cnt_all + + if rank == 0: + if count > data_len: + iter_cnt_all = data_len - (count - iter_cnt_all) + for _ in range(iter_cnt_all): + pbar.update() + + if progress_with_iters and (i + 1) >= data_len: + break + + # TODO: allgather data list may cost a lot of memory and needs to be redesigned + # collect results and data from all ranks + if gpu_collect: + results = collect_results_gpu(results, total_samples) + data_list = collect_results_gpu(data_list, total_samples) + else: + if tmpdir is None: + tmpdir = make_tmp_dir() + results = collect_results_cpu(results, total_samples, + os.path.join(tmpdir, 'predict')) + data_list = collect_results_cpu(data_list, total_samples, + os.path.join(tmpdir, 'groundtruth')) + + if is_master(): + assert len(data_list) == len( + results), f'size mismatch {len(data_list)} and {len(results)}' + if metric_classes is not None: + for i in range(len(data_list)): + for metric_cls in metric_classes: + metric_cls.add(results[i], data_list[i]) + + metric_values = {} + if rank == 0: + for metric_cls in metric_classes: + metric_values.update(metric_cls.evaluate()) + if world_size > 1: + metric_values = broadcast(metric_values, 0) + + return metric_values + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + tmpdir (str | None): temporal directory for collected results to + store. If set to None, it will create a random temporal directory + for it. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + if tmpdir is None: + tmpdir = make_tmp_dir() + if not os.path.exists(tmpdir) and is_master(): + os.makedirs(tmpdir) + dist.barrier() + + # dump the part result to the dir + with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: + pickle.dump(result_part, f) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = os.path.join(tmpdir, f'part_{i}.pkl') + with open(part_file, 'rb') as f: + part_result = pickle.load(f) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results under gpu mode. + + On gpu mode, this function will encode results to gpu tensors and use gpu + communication for results collection. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. + + Returns: + list: The collected results. + """ + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + # When data is severely insufficient, an empty part_result + # on a certain gpu could makes the overall outputs empty. + if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/modelscope/trainers/utils/log_buffer.py b/modelscope/trainers/utils/log_buffer.py new file mode 100644 index 00000000..edcc273e --- /dev/null +++ b/modelscope/trainers/utils/log_buffer.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import OrderedDict + +import numpy as np + + +class LogBuffer: + + def __init__(self): + self.val_history = OrderedDict() + self.n_history = OrderedDict() + self.output = OrderedDict() + self.ready = False + + def clear(self) -> None: + self.val_history.clear() + self.n_history.clear() + self.clear_output() + + def clear_output(self) -> None: + self.output.clear() + self.ready = False + + def update(self, vars: dict, count: int = 1) -> None: + assert isinstance(vars, dict) + for key, var in vars.items(): + if key not in self.val_history: + self.val_history[key] = [] + self.n_history[key] = [] + self.val_history[key].append(var) + self.n_history[key].append(count) + + def average(self, n: int = 0) -> None: + """Average latest n values or all values.""" + assert n >= 0 + for key in self.val_history: + values = np.array(self.val_history[key][-n:]) + nums = np.array(self.n_history[key][-n:]) + avg = np.sum(values * nums) / np.sum(nums) + self.output[key] = avg + self.ready = True diff --git a/modelscope/utils/__init__.py b/modelscope/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py new file mode 100644 index 00000000..f59100cb --- /dev/null +++ b/modelscope/utils/ast_utils.py @@ -0,0 +1,681 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import ast +import contextlib +import hashlib +import importlib +import os +import os.path as osp +import time +import traceback +from functools import reduce +from pathlib import Path +from typing import Generator, Union + +import gast +import json + +from modelscope import __version__ +from modelscope.fileio.file import LocalStorage +from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers, + Metrics, Models, Optimizers, Pipelines, + Preprocessors, TaskModels, Trainers) +from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.file_utils import get_default_cache_dir +from modelscope.utils.logger import get_logger +from modelscope.utils.registry import default_group + +logger = get_logger() +storage = LocalStorage() +p = Path(__file__) + +# get the path of package 'modelscope' +MODELSCOPE_PATH = p.resolve().parents[1] +REGISTER_MODULE = 'register_module' +IGNORED_PACKAGES = ['modelscope', '.'] +SCAN_SUB_FOLDERS = [ + 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets' +] +INDEXER_FILE = 'ast_indexer' +DECORATOR_KEY = 'decorators' +EXPRESS_KEY = 'express' +FROM_IMPORT_KEY = 'from_imports' +IMPORT_KEY = 'imports' +FILE_NAME_KEY = 'filepath' +VERSION_KEY = 'version' +MD5_KEY = 'md5' +INDEX_KEY = 'index' +REQUIREMENT_KEY = 'requirements' +MODULE_KEY = 'module' +CLASS_NAME = 'class_name' +GROUP_KEY = 'group_key' +MODULE_NAME = 'module_name' +MODULE_CLS = 'module_cls' + + +class AstScaning(object): + + def __init__(self) -> None: + self.result_import = dict() + self.result_from_import = dict() + self.result_decorator = [] + self.express = [] + + def _is_sub_node(self, node: object) -> bool: + return isinstance(node, + ast.AST) and not isinstance(node, ast.expr_context) + + def _is_leaf(self, node: ast.AST) -> bool: + for field in node._fields: + attr = getattr(node, field) + if self._is_sub_node(attr): + return False + elif isinstance(attr, (list, tuple)): + for val in attr: + if self._is_sub_node(val): + return False + else: + return True + + def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple: + if show_offsets: + return n._attributes + n._fields + else: + return n._fields + + def _leaf(self, node: ast.AST, show_offsets: bool = True) -> str: + output = dict() + local_print = list() + if isinstance(node, ast.AST): + local_dict = dict() + for field in self._fields(node, show_offsets=show_offsets): + field_output, field_prints = self._leaf( + getattr(node, field), show_offsets=show_offsets) + local_dict[field] = field_output + local_print.append('{}={}'.format(field, field_prints)) + + prints = '{}({})'.format( + type(node).__name__, + ', '.join(local_print), + ) + output[type(node).__name__] = local_dict + return output, prints + elif isinstance(node, list): + if '_fields' not in node: + return node, repr(node) + for item in node: + item_output, item_prints = self._leaf( + getattr(node, item), show_offsets=show_offsets) + local_print.append(item_prints) + return node, '[{}]'.format(', '.join(local_print), ) + else: + return node, repr(node) + + def _refresh(self): + self.result_import = dict() + self.result_from_import = dict() + self.result_decorator = [] + self.result_express = [] + + def scan_ast(self, node: Union[ast.AST, None, str]): + self._setup_global() + self.scan_import(node, indent=' ', show_offsets=False) + + def scan_import( + self, + node: Union[ast.AST, None, str], + indent: Union[str, int] = ' ', + show_offsets: bool = True, + _indent: int = 0, + parent_node_name: str = '', + ) -> tuple: + if node is None: + return node, repr(node) + elif self._is_leaf(node): + return self._leaf(node, show_offsets=show_offsets) + else: + if isinstance(indent, int): + indent_s = indent * ' ' + else: + indent_s = indent + + class state: + indent = _indent + + @contextlib.contextmanager + def indented() -> Generator[None, None, None]: + state.indent += 1 + yield + state.indent -= 1 + + def indentstr() -> str: + return state.indent * indent_s + + def _scan_import(el: Union[ast.AST, None, str], + _indent: int = 0, + parent_node_name: str = '') -> str: + return self.scan_import( + el, + indent=indent, + show_offsets=show_offsets, + _indent=_indent, + parent_node_name=parent_node_name) + + out = type(node).__name__ + '(\n' + outputs = dict() + # add relative path expression + if type(node).__name__ == 'ImportFrom': + level = getattr(node, 'level') + if level >= 1: + path_level = ''.join(['.'] * level) + setattr(node, 'level', 0) + module_name = getattr(node, 'module') + if module_name is None: + setattr(node, 'module', path_level) + else: + setattr(node, 'module', path_level + module_name) + with indented(): + for field in self._fields(node, show_offsets=show_offsets): + attr = getattr(node, field) + if attr == []: + representation = '[]' + outputs[field] = [] + elif (isinstance(attr, list) and len(attr) == 1 + and isinstance(attr[0], ast.AST) + and self._is_leaf(attr[0])): + local_out, local_print = _scan_import(attr[0]) + representation = f'[{local_print}]' + outputs[field] = local_out + + elif isinstance(attr, list): + representation = '[\n' + el_dict = dict() + with indented(): + for el in attr: + local_out, local_print = _scan_import( + el, state.indent, + type(el).__name__) + representation += '{}{},\n'.format( + indentstr(), + local_print, + ) + name = type(el).__name__ + if (name == 'Import' or name == 'ImportFrom' + or parent_node_name == 'ImportFrom' + or parent_node_name == 'Import'): + if name not in el_dict: + el_dict[name] = [] + el_dict[name].append(local_out) + representation += indentstr() + ']' + outputs[field] = el_dict + elif isinstance(attr, ast.AST): + output, representation = _scan_import( + attr, state.indent) + outputs[field] = output + else: + representation = repr(attr) + outputs[field] = attr + + if (type(node).__name__ == 'Import' + or type(node).__name__ == 'ImportFrom'): + if type(node).__name__ == 'ImportFrom': + if field == 'module': + self.result_from_import[ + outputs[field]] = dict() + if field == 'names': + if isinstance(outputs[field]['alias'], list): + item_name = [] + for item in outputs[field]['alias']: + local_name = item['alias']['name'] + item_name.append(local_name) + self.result_from_import[ + outputs['module']] = item_name + else: + local_name = outputs[field]['alias'][ + 'name'] + self.result_from_import[ + outputs['module']] = [local_name] + + if type(node).__name__ == 'Import': + final_dict = outputs[field]['alias'] + if isinstance(final_dict, list): + for item in final_dict: + self.result_import[ + item['alias']['name']] = item['alias'] + else: + self.result_import[outputs[field]['alias'] + ['name']] = final_dict + + if 'decorator_list' == field and attr != []: + for item in attr: + setattr(item, CLASS_NAME, node.name) + self.result_decorator.extend(attr) + + if attr != [] and type( + attr + ).__name__ == 'Call' and parent_node_name == 'Expr': + self.result_express.append(attr) + + out += f'{indentstr()}{field}={representation},\n' + + out += indentstr() + ')' + return { + IMPORT_KEY: self.result_import, + FROM_IMPORT_KEY: self.result_from_import, + DECORATOR_KEY: self.result_decorator, + EXPRESS_KEY: self.result_express + }, out + + def _parse_decorator(self, node: ast.AST) -> tuple: + + def _get_attribute_item(node: ast.AST) -> tuple: + value, id, attr = None, None, None + if type(node).__name__ == 'Attribute': + value = getattr(node, 'value') + id = getattr(value, 'id') + attr = getattr(node, 'attr') + if type(node).__name__ == 'Name': + id = getattr(node, 'id') + return id, attr + + def _get_args_name(nodes: list) -> list: + result = [] + for node in nodes: + if type(node).__name__ == 'Str': + result.append((node.s, None)) + else: + result.append(_get_attribute_item(node)) + return result + + def _get_keyword_name(nodes: ast.AST) -> list: + result = [] + for node in nodes: + if type(node).__name__ == 'keyword': + attribute_node = getattr(node, 'value') + if type(attribute_node).__name__ == 'Str': + result.append((getattr(node, + 'arg'), attribute_node.s, None)) + elif type(attribute_node).__name__ == 'Constant': + result.append( + (getattr(node, 'arg'), attribute_node.value, None)) + else: + result.append((getattr(node, 'arg'), ) + + _get_attribute_item(attribute_node)) + return result + + functions = _get_attribute_item(node.func) + args_list = _get_args_name(node.args) + keyword_list = _get_keyword_name(node.keywords) + return functions, args_list, keyword_list + + def _get_registry_value(self, key_item): + if key_item is None: + return None + if key_item == 'default_group': + return default_group + split_list = key_item.split('.') + # in the case, the key_item is raw data, not registred + if len(split_list) == 1: + return key_item + else: + return getattr(eval(split_list[0]), split_list[1]) + + def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple: + """format registry information to a tuple indexer + + Return: + tuple: (MODELS, Tasks.text-classification, Models.structbert) + """ + functions, args_list, keyword_list = parsed_input + + # ignore decocators other than register_module + if REGISTER_MODULE != functions[1]: + return None + output = [functions[0]] + + if len(args_list) == 0 and len(keyword_list) == 0: + args_list.append(default_group) + if len(keyword_list) == 0 and len(args_list) == 1: + args_list.append(class_name) + + if len(keyword_list) > 0 and len(args_list) == 0: + remove_group_item = None + for item in keyword_list: + key, name, attr = item + if key == GROUP_KEY: + args_list.append((name, attr)) + remove_group_item = item + if remove_group_item is not None: + keyword_list.remove(remove_group_item) + + if len(args_list) == 0: + args_list.append(default_group) + + for item in keyword_list: + key, name, attr = item + if key == MODULE_CLS: + class_name = name + else: + args_list.append((name, attr)) + + for item in args_list: + # the case empty input + if item is None: + output.append(None) + # the case (default_group) + elif item[1] is None: + output.append(item[0]) + elif isinstance(item, str): + output.append(item) + else: + output.append('.'.join(item)) + return (output[0], self._get_registry_value(output[1]), + self._get_registry_value(output[2])) + + def parse_decorators(self, nodes: list) -> list: + """parse the AST nodes of decorators object to registry indexer + + Args: + nodes (list): list of AST decorator nodes + + Returns: + list: list of registry indexer + """ + results = [] + for node in nodes: + if type(node).__name__ != 'Call': + continue + class_name = getattr(node, CLASS_NAME, None) + func = getattr(node, 'func') + + if getattr(func, 'attr', None) != REGISTER_MODULE: + continue + + parse_output = self._parse_decorator(node) + index = self._registry_indexer(parse_output, class_name) + if None is not index: + results.append(index) + return results + + def generate_ast(self, file): + self._refresh() + with open(file, 'r', encoding='utf8') as code: + data = code.readlines() + data = ''.join(data) + + node = gast.parse(data) + output, _ = self.scan_import(node, indent=' ', show_offsets=False) + output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) + output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY]) + output[DECORATOR_KEY].extend(output[EXPRESS_KEY]) + return output + + +class FilesAstScaning(object): + + def __init__(self) -> None: + self.astScaner = AstScaning() + self.file_dirs = [] + + def _parse_import_path(self, + import_package: str, + current_path: str = None) -> str: + """ + Args: + import_package (str): relative import or abs import + current_path (str): path/to/current/file + """ + if import_package.startswith(IGNORED_PACKAGES[0]): + return MODELSCOPE_PATH + '/' + '/'.join( + import_package.split('.')[1:]) + '.py' + elif import_package.startswith(IGNORED_PACKAGES[1]): + current_path_list = current_path.split('/') + import_package_list = import_package.split('.') + level = 0 + for index, item in enumerate(import_package_list): + if item != '': + level = index + break + + abs_path_list = current_path_list[0:-level] + abs_path_list.extend(import_package_list[index:]) + return '/' + '/'.join(abs_path_list) + '.py' + else: + return current_path + + def _traversal_import( + self, + import_abs_path, + ): + pass + + def parse_import(self, scan_result: dict) -> list: + """parse import and from import dicts to a third party package list + + Args: + scan_result (dict): including the import and from import result + + Returns: + list: a list of package ignored 'modelscope' and relative path import + """ + output = [] + output.extend(list(scan_result[IMPORT_KEY].keys())) + output.extend(list(scan_result[FROM_IMPORT_KEY].keys())) + + # get the package name + for index, item in enumerate(output): + if '' == item.split('.')[0]: + output[index] = '.' + else: + output[index] = item.split('.')[0] + + ignored = set() + for item in output: + for ignored_package in IGNORED_PACKAGES: + if item.startswith(ignored_package): + ignored.add(item) + return list(set(output) - set(ignored)) + + def traversal_files(self, path, check_sub_dir): + self.file_dirs = [] + if check_sub_dir is None or len(check_sub_dir) == 0: + self._traversal_files(path) + + for item in check_sub_dir: + sub_dir = os.path.join(path, item) + if os.path.isdir(sub_dir): + self._traversal_files(sub_dir) + + def _traversal_files(self, path): + dir_list = os.scandir(path) + for item in dir_list: + if item.name.startswith('__'): + continue + if item.is_dir(): + self._traversal_files(item.path) + elif item.is_file() and item.name.endswith('.py'): + self.file_dirs.append(item.path) + + def _get_single_file_scan_result(self, file): + try: + output = self.astScaner.generate_ast(file) + except Exception as e: + detail = traceback.extract_tb(e.__traceback__) + raise Exception( + f'During ast indexing, error is in the file {detail[-1].filename}' + f' line: {detail[-1].lineno}: "{detail[-1].line}" with error msg: ' + f'"{type(e).__name__}: {e}"') + + import_list = self.parse_import(output) + return output[DECORATOR_KEY], import_list + + def _inverted_index(self, forward_index): + inverted_index = dict() + for index in forward_index: + for item in forward_index[index][DECORATOR_KEY]: + inverted_index[item] = { + FILE_NAME_KEY: index, + IMPORT_KEY: forward_index[index][IMPORT_KEY], + MODULE_KEY: forward_index[index][MODULE_KEY], + } + return inverted_index + + def _module_import(self, forward_index): + module_import = dict() + for index, value_dict in forward_index.items(): + module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY] + return module_import + + def _ignore_useless_keys(self, inverted_index): + if ('OPTIMIZERS', 'default', 'name') in inverted_index: + del inverted_index[('OPTIMIZERS', 'default', 'name')] + if ('LR_SCHEDULER', 'default', 'name') in inverted_index: + del inverted_index[('LR_SCHEDULER', 'default', 'name')] + return inverted_index + + def get_files_scan_results(self, + target_dir=MODELSCOPE_PATH, + target_folders=SCAN_SUB_FOLDERS): + """the entry method of the ast scan method + + Args: + target_dir (str, optional): the absolute path of the target directory to be scaned. Defaults to None. + target_folder (list, optional): the list of + sub-folders to be scaned in the target folder. + Defaults to SCAN_SUB_FOLDERS. + + Returns: + dict: indexer of registry + """ + + self.traversal_files(target_dir, target_folders) + start = time.time() + logger.info( + f'AST-Scaning the path "{target_dir}" with the following sub folders {target_folders}' + ) + + result = dict() + for file in self.file_dirs: + filepath = file[file.rfind('modelscope'):] + module_name = filepath.replace(osp.sep, '.').replace('.py', '') + decorator_list, import_list = self._get_single_file_scan_result( + file) + result[file] = { + DECORATOR_KEY: decorator_list, + IMPORT_KEY: import_list, + MODULE_KEY: module_name + } + inverted_index_with_results = self._inverted_index(result) + inverted_index_with_results = self._ignore_useless_keys( + inverted_index_with_results) + module_import = self._module_import(result) + index = { + INDEX_KEY: inverted_index_with_results, + REQUIREMENT_KEY: module_import + } + logger.info( + f'Scaning done! A number of {len(inverted_index_with_results)}' + f' files indexed! Time consumed {time.time()-start}s') + return index + + def files_mtime_md5(self, + target_path=MODELSCOPE_PATH, + target_subfolder=SCAN_SUB_FOLDERS): + self.file_dirs = [] + self.traversal_files(target_path, target_subfolder) + files_mtime = [] + for item in self.file_dirs: + files_mtime.append(os.path.getmtime(item)) + result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '') + md5 = hashlib.md5(result_str.encode()) + return md5.hexdigest() + + +file_scanner = FilesAstScaning() + + +def _save_index(index, file_path): + # convert tuple key to str key + index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()} + index[VERSION_KEY] = __version__ + index[MD5_KEY] = file_scanner.files_mtime_md5() + json_index = json.dumps(index) + storage.write(json_index.encode(), file_path) + index[INDEX_KEY] = { + ast.literal_eval(k): v + for k, v in index[INDEX_KEY].items() + } + + +def _load_index(file_path): + bytes_index = storage.read(file_path) + wrapped_index = json.loads(bytes_index) + # convert str key to tuple key + wrapped_index[INDEX_KEY] = { + ast.literal_eval(k): v + for k, v in wrapped_index[INDEX_KEY].items() + } + return wrapped_index + + +def load_index(force_rebuild=False): + """get the index from scan results or cache + + Args: + force_rebuild: If set true, rebuild and load index + Returns: + dict: the index information for all registred modules, including key: + index, requirments, version and md5, the detail is shown below example: + { + 'index': { + ('MODELS', 'nlp', 'bert'):{ + 'filepath' : 'path/to/the/registered/model', 'imports': + ['os', 'torch', 'typeing'] 'module': + 'modelscope.models.nlp.bert' + }, + ... + }, 'requirments': { + 'modelscope.models.nlp.bert': ['os', 'torch', 'typeing'], + 'modelscope.models.nlp.structbert': ['os', 'torch', 'typeing'], + ... + }, 'version': '0.2.3', 'md5': '8616924970fe6bc119d1562832625612', + } + """ + cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir()) + file_path = os.path.join(cache_dir, INDEXER_FILE) + logger.info(f'Loading ast index from {file_path}') + index = None + if not force_rebuild and os.path.exists(file_path): + wrapped_index = _load_index(file_path) + md5 = file_scanner.files_mtime_md5() + if (wrapped_index[VERSION_KEY] == __version__ + and wrapped_index[MD5_KEY] == md5): + index = wrapped_index + + if index is None: + if force_rebuild: + logger.info('Force rebuilding ast index') + else: + logger.info( + f'No valid ast index found from {file_path}, rebuilding ast index!' + ) + index = file_scanner.get_files_scan_results() + _save_index(index, file_path) + logger.info( + f'Loading done! Current index file version is {index[VERSION_KEY]}, ' + f'with md5 {index[MD5_KEY]}') + return index + + +def check_import_module_avaliable(module_dicts: dict) -> list: + missed_module = [] + for module in module_dicts.keys(): + loader = importlib.find_loader(module) + if loader is None: + missed_module.append(module) + return missed_module + + +if __name__ == '__main__': + index = load_index() + print(index) diff --git a/modelscope/utils/audio/__init__.py b/modelscope/utils/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py new file mode 100644 index 00000000..32e2fa54 --- /dev/null +++ b/modelscope/utils/audio/audio_utils.py @@ -0,0 +1,103 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +import struct +from typing import Union +from urllib.parse import urlparse + +from modelscope.fileio.file import HTTPStorage + +SEGMENT_LENGTH_TRAIN = 16000 + + +def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN): + """ + Dataset mapping function to split one audio into segments. + It only works in batch mode. + """ + noisy_arrays = [] + clean_arrays = [] + for x, y in zip(batch['noisy'], batch['clean']): + length = min(len(x['array']), len(y['array'])) + noisy = x['array'] + clean = y['array'] + for offset in range(segment_length, length + 1, segment_length): + noisy_arrays.append(noisy[offset - segment_length:offset]) + clean_arrays.append(clean[offset - segment_length:offset]) + return {'noisy': noisy_arrays, 'clean': clean_arrays} + + +def audio_norm(x): + rms = (x**2).mean()**0.5 + scalar = 10**(-25 / 20) / rms + x = x * scalar + pow_x = x**2 + avg_pow_x = pow_x.mean() + rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5 + scalarx = 10**(-25 / 20) / rmsx + x = x * scalarx + return x + + +def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): + + def repl(matched): + key = matched.group(1) + if key in conf_item: + return conf_item[key] + else: + return None + + with open(origin_config_file) as f: + lines = f.readlines() + with open(new_config_file, 'w') as f: + for line in lines: + line = re.sub(r'\$\{(.*)\}', repl, line) + f.write(line) + + +def extract_pcm_from_wav(wav: bytes) -> bytes: + data = wav + sample_rate = None + if len(data) > 44: + frame_len = 44 + file_len = len(data) + try: + header_fields = {} + header_fields['ChunkID'] = str(data[0:4], 'UTF-8') + header_fields['Format'] = str(data[8:12], 'UTF-8') + header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') + if header_fields['ChunkID'] == 'RIFF' and header_fields[ + 'Format'] == 'WAVE' and header_fields[ + 'Subchunk1ID'] == 'fmt ': + header_fields['SubChunk1Size'] = struct.unpack( + ' Union[bytes, str]: + sample_rate = None + result = urlparse(url) + if result.scheme is not None and len(result.scheme) > 0: + storage = HTTPStorage() + data = storage.read(url) + data, sample_rate = extract_pcm_from_wav(data) + else: + data = url + + return data, sample_rate diff --git a/modelscope/utils/audio/tts_exceptions.py b/modelscope/utils/audio/tts_exceptions.py new file mode 100644 index 00000000..43ec994b --- /dev/null +++ b/modelscope/utils/audio/tts_exceptions.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Define TTS exceptions +""" + + +class TtsException(Exception): + """ + TTS exception class. + """ + pass + + +class TtsModelConfigurationException(TtsException): + """ + TTS model configuration exceptions. + """ + pass + + +class TtsVoiceNotExistsException(TtsException): + """ + TTS voice not exists exception. + """ + pass + + +class TtsFrontendException(TtsException): + """ + TTS frontend module level exceptions. + """ + pass + + +class TtsFrontendInitializeFailedException(TtsFrontendException): + """ + If tts frontend resource is invalid or not exist, this exception will be raised. + """ + pass + + +class TtsFrontendLanguageTypeInvalidException(TtsFrontendException): + """ + If language type is invalid, this exception will be raised. + """ + + +class TtsVocoderException(TtsException): + """ + Vocoder exception + """ + + +class TtsVocoderMelspecShapeMismatchException(TtsVocoderException): + """ + If vocoder's input melspec shape mismatch, this exception will be raised. + """ diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py new file mode 100644 index 00000000..5acaa411 --- /dev/null +++ b/modelscope/utils/checkpoint.py @@ -0,0 +1,210 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +import os +import time +from collections import OrderedDict +from shutil import copytree, ignore_patterns, rmtree +from typing import Callable, List, Optional, Union + +import json +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from modelscope import __version__ +from modelscope.fileio import File, LocalStorage +from modelscope.utils.config import JSONIteratorEncoder +from modelscope.utils.constant import ConfigFields, ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +storage = LocalStorage() + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + # Keep metadata in state_dict + state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict()) + return state_dict_cpu + + +def save_checkpoint(model: torch.nn.Module, + filename: str, + optimizer: Optional[Optimizer] = None, + lr_scheduler: Optional[_LRScheduler] = None, + meta: Optional[dict] = None, + with_meta: bool = True) -> None: + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default, ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + with_meta (bool, optional): + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(modelscope=__version__, time=time.asctime()) + + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + if with_meta: + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(model.state_dict()) + } + + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + # save lr_scheduler state dict in the checkpoint + if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'): + checkpoint['lr_scheduler'] = lr_scheduler.state_dict() + else: + checkpoint = weights_to_cpu(model.state_dict()) + + with io.BytesIO() as f: + torch.save(checkpoint, f) + File.write(f.getvalue(), filename) + + +def load_checkpoint(filename, + model, + optimizer: Optimizer = None, + lr_scheduler: _LRScheduler = None): + if not os.path.exists(filename): + raise ValueError(f'Checkpoint file {filename} does not exist!') + checkpoint = torch.load(filename, map_location='cpu') + + if optimizer is not None: + if 'optimizer' in checkpoint: + if isinstance(optimizer, Optimizer): + optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(optimizer, dict): + optimizer_dict = checkpoint['optimizer'] + for key, optimizer_ins in optimizer.items(): + if key in optimizer_dict: + optimizer_ins.load_state_dict(optimizer_dict[key]) + else: + logger.warn( + f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}' + ) + else: + logger.warn( + f'The state dict of optimizer cannot be found in checkpoint file: {filename}' + ) + + if lr_scheduler is not None: + if 'lr_scheduler' in checkpoint: + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + else: + logger.warn( + f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}' + ) + + state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ + 'state_dict'] + model.load_state_dict(state_dict) + return checkpoint.get('meta', {}) + + +def save_pretrained(model, + target_folder: Union[str, os.PathLike], + save_checkpoint_name: str = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded + + Args: + model (Model): Model whose params are to be saved. + + target_folder (Union[str, os.PathLike]): + Directory to which to save. Will be created if it doesn't exist. + + save_checkpoint_name (str): + The checkpoint name to be saved in the target_folder + + save_function (Callable, optional): + The function to use to save the state dictionary. + + config (Optional[dict], optional): + The config for the configuration.json, might not be identical with model.config + """ + + if save_function is None or not isinstance(save_function, Callable): + raise Exception('A valid save function must be passed in') + + if target_folder is None or os.path.isfile(target_folder): + raise ValueError( + f'Provided path ({target_folder}) should be a directory, not a file' + ) + + if save_checkpoint_name is None: + raise Exception( + 'At least pass in one checkpoint name for saving method') + + if config is None: + raise ValueError('Configuration is not valid') + + # Clean the folder from a previous save + if os.path.exists(target_folder): + rmtree(target_folder) + + # Single ckpt path, sharded ckpt logic will be added later + output_ckpt_path = os.path.join(target_folder, save_checkpoint_name) + + # Save the files to be copied to the save directory, ignore the original ckpts and configuration + origin_file_to_be_ignored = [save_checkpoint_name] + ignore_file_set = set(origin_file_to_be_ignored) + ignore_file_set.add(ModelFile.CONFIGURATION) + ignore_file_set.add('.*') + if hasattr(model, 'model_dir') and model.model_dir is not None: + copytree( + model.model_dir, + target_folder, + ignore=ignore_patterns(*ignore_file_set)) + + # Save the ckpt to the save directory + try: + save_function(model, output_ckpt_path, **kwargs) + except Exception as e: + raise Exception( + f'During saving checkpoints, the error of "{type(e).__name__} ' + f'with msg {e} throwed') + + # Dump the config to the configuration.json + if ConfigFields.pipeline not in config: + config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]} + cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder) + config_file = os.path.join(target_folder, ModelFile.CONFIGURATION) + storage.write(cfg_str.encode(), config_file) diff --git a/modelscope/utils/chinese_utils.py b/modelscope/utils/chinese_utils.py new file mode 100644 index 00000000..e5fe7aa8 --- /dev/null +++ b/modelscope/utils/chinese_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + +def is_chinese_char(word: str): + chinese_punctuations = { + ',', '。', ';', ':' + '!', '?', '《', '》', '‘', '’', '“', '”', '(', ')', '【', '】' + } + return len(word) == 1 \ + and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations) + + +def remove_space_between_chinese_chars(decoded_str: str): + old_word_list = decoded_str.split(' ') + new_word_list = [] + start = -1 + for i, word in enumerate(old_word_list): + if is_chinese_char(word): + if start == -1: + start = i + else: + if start != -1: + new_word_list.append(''.join(old_word_list[start:i])) + start = -1 + new_word_list.append(word) + if start != -1: + new_word_list.append(''.join(old_word_list[start:])) + return ' '.join(new_word_list).strip() + + +# add space for each chinese char +def rebuild_chinese_str(string: str): + return ' '.join(''.join([ + f' {char} ' if is_chinese_char(char) else char for char in string + ]).split()) diff --git a/modelscope/utils/config.py b/modelscope/utils/config.py new file mode 100644 index 00000000..e46da7df --- /dev/null +++ b/modelscope/utils/config.py @@ -0,0 +1,664 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Major implementation is borrowed and modified from +# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py + +import copy +import os +import os.path as osp +import platform +import shutil +import sys +import tempfile +import types +from pathlib import Path +from types import FunctionType +from typing import Dict, Union + +import addict +import json +from yapf.yapflib.yapf_api import FormatCode + +from modelscope.utils.constant import ConfigFields, ModelFile +from modelscope.utils.import_utils import import_modules_from_file +from modelscope.utils.logger import get_logger + +logger = get_logger() + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = ['filename', 'text', 'pretty_text'] + + +class ConfigDict(addict.Dict): + """ Dict which support get value through getattr + + Examples: + >>> cdict = ConfigDict({'a':1232}) + >>> print(cdict.a) + 1232 + """ + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super(ConfigDict, self).__getattr__(name) + except KeyError: + ex = AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + ex = e + else: + return value + raise ex + + +class Config: + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. + + Example: + >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd'))) + >>> cfg.a + 1 + >>> cfg.b + {'c': [1, 2, 3], 'd': 'dd'} + >>> cfg.b.d + 'dd' + >>> cfg = Config.from_file('configs/examples/configuration.json') + >>> cfg.filename + 'configs/examples/configuration.json' + >>> cfg.b + {'c': [1, 2, 3], 'd': 'dd'} + >>> cfg = Config.from_file('configs/examples/configuration.py') + >>> cfg.filename + "configs/examples/configuration.py" + >>> cfg = Config.from_file('configs/examples/configuration.yaml') + >>> cfg.filename + "configs/examples/configuration.yaml" + """ + + @staticmethod + def _file2dict(filename): + filename = osp.abspath(osp.expanduser(filename)) + if not osp.exists(filename): + raise ValueError(f'File does not exists {filename}') + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + + with tempfile.TemporaryDirectory() as tmp_cfg_dir: + tmp_cfg_file = tempfile.NamedTemporaryFile( + dir=tmp_cfg_dir, suffix=fileExtname) + if platform.system() == 'Windows': + tmp_cfg_file.close() + tmp_cfg_name = osp.basename(tmp_cfg_file.name) + shutil.copyfile(filename, tmp_cfg_file.name) + + if filename.endswith('.py'): + module_nanme, mod = import_modules_from_file( + osp.join(tmp_cfg_dir, tmp_cfg_name)) + cfg_dict = {} + for name, value in mod.__dict__.items(): + if not name.startswith('__') and \ + not isinstance(value, types.ModuleType) and \ + not isinstance(value, types.FunctionType): + cfg_dict[name] = value + + # delete imported module + del sys.modules[module_nanme] + elif filename.endswith(('.yml', '.yaml', '.json')): + from modelscope.fileio import load + cfg_dict = load(tmp_cfg_file.name) + # close temp file + tmp_cfg_file.close() + + cfg_text = filename + '\n' + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + return cfg_dict, cfg_text + + @staticmethod + def from_file(filename): + if isinstance(filename, Path): + filename = str(filename) + cfg_dict, cfg_text = Config._file2dict(filename) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def from_string(cfg_str, file_format): + """Generate config from config str. + + Args: + cfg_str (str): Config str. + file_format (str): Config file format corresponding to the + config str. Only py/yml/yaml/json type are supported now! + + Returns: + :obj:`Config`: Config obj. + """ + if file_format not in ['.py', '.json', '.yaml', '.yml']: + raise IOError('Only py/yml/yaml/json type are supported now!') + if file_format != '.py' and 'dict(' in cfg_str: + # check if users specify a wrong suffix for python + logger.warning( + 'Please check "file_format", the file format may be .py') + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix=file_format, + delete=False) as temp_file: + temp_file.write(cfg_str) + # on windows, previous implementation cause error + # see PR 1077 for details + cfg = Config.from_file(temp_file.name) + os.remove(temp_file.name) + return cfg + + def __init__(self, cfg_dict=None, cfg_text=None, filename=None): + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + if isinstance(filename, Path): + filename = str(filename) + + super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + super(Config, self).__setattr__('_filename', filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, 'r') as f: + text = f.read() + else: + text = '' + super(Config, self).__setattr__('_text', text) + + @property + def filename(self): + return self._filename + + @property + def text(self): + return self._text + + @property + def pretty_text(self): + + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = f"'{v}'" + else: + v_str = str(v) + + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + + return attr_str + + def _format_list(k, v, use_mapping=False): + # check if all items in the list are dict + if all(isinstance(_, dict) for _ in v): + v_str = '[\n' + v_str += '\n'.join( + f'dict({_indent(_format_dict(v_), indent)}),' + for v_ in v).rstrip(',') + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + ']' + else: + attr_str = _format_basic_types(k, v, use_mapping) + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate(input_dict.items()): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, list): + attr_str = _format_list(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self._cfg_dict.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + + return text + + def __repr__(self): + return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name): + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__(self): + return (self._cfg_dict, self._filename, self._text) + + def __copy__(self): + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + + return other + + def __deepcopy__(self, memo): + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) + + return other + + def __setstate__(self, state): + _cfg_dict, _filename, _text = state + super(Config, self).__setattr__('_cfg_dict', _cfg_dict) + super(Config, self).__setattr__('_filename', _filename) + super(Config, self).__setattr__('_text', _text) + + def dump(self, file: str = None): + """Dumps config into a file or returns a string representation of the + config. + + If a file argument is given, saves the config to that file using the + format defined by the file argument extension. + + Otherwise, returns a string representing the config. The formatting of + this returned string is defined by the extension of `self.filename`. If + `self.filename` is not defined, returns a string representation of a + dict (lowercased and using ' for strings). + + Examples: + >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), + ... item3=True, item4='test') + >>> cfg = Config(cfg_dict=cfg_dict) + >>> dump_file = "a.py" + >>> cfg.dump(dump_file) + + Args: + file (str, optional): Path of the output file where the config + will be dumped. Defaults to None. + """ + from modelscope.fileio import dump + cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() + if file is None: + if self.filename is None or self.filename.endswith('.py'): + return self.pretty_text + else: + file_format = self.filename.split('.')[-1] + return dump(cfg_dict, file_format=file_format) + elif file.endswith('.py'): + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + file_format = file.split('.')[-1] + return dump(cfg_dict, file=file, file_format=file_format) + + def merge_from_dict(self, options, allow_list_keys=True, force=True): + """Merge dict into cfg_dict. + + Merge the dict parsed by MultipleKVAction into this cfg. + + Examples: + >>> options = {'model.backbone.depth': 50, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + >>> cfg.merge_from_dict(options) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True))) + + >>> # Merge list element for replace target index + >>> cfg = Config(dict(pipeline=[ + ... dict(type='Resize'), dict(type='RandomDistortion')])) + >>> options = dict(pipeline={'0': dict(type='MyResize')}) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='MyResize'), dict(type='RandomDistortion')]) + + >>> # Merge list element for replace args and add to list, only support list of type dict with key ``type``, + >>> # if you add new list element, the list does not guarantee the order, + >>> # it is only suitable for the case where the order of the list is not concerned. + >>> cfg = Config(dict(pipeline=[ + ... dict(type='Resize', size=224), dict(type='RandomDistortion')])) + >>> options = dict(pipeline=[dict(type='Resize', size=256), dict(type='RandomFlip')]) + >>> cfg.merge_from_dict(options, allow_list_keys=True) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict(pipeline=[ + ... dict(type='Resize', size=256), dict(type='RandomDistortion'), dict(type='RandomFlip')]) + + >>> # force usage + >>> options = {'model.backbone.depth': 18, + ... 'model.backbone.with_cp':True} + >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50)))) + >>> cfg.merge_from_dict(options, force=False) + >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + >>> assert cfg_dict == dict( + ... model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True))) + + Args: + options (dict): dict of configs to merge from. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in ``options`` and will replace the element of the + corresponding index in the config if the config is a list. + Or you can directly replace args for list or add new list element, + only support list of type dict with key ``type``, + but if you add new list element, the list does not guarantee the order, + It is only suitable for the case where the order of the list is not concerned. + Default: True. + force (bool): If True, existing key-value will be replaced by new given. + If False, existing key-value will not be updated. + """ + option_cfg_dict = {} + for full_key, v in options.items(): + d = option_cfg_dict + key_list = full_key.split('.') + for subkey in key_list[:-1]: + d.setdefault(subkey, ConfigDict()) + d = d[subkey] + subkey = key_list[-1] + d[subkey] = v + + cfg_dict = super(Config, self).__getattribute__('_cfg_dict') + super(Config, self).__setattr__( + '_cfg_dict', + Config._merge_a_into_b( + option_cfg_dict, + cfg_dict, + allow_list_keys=allow_list_keys, + force=force)) + + @staticmethod + def _merge_a_into_b(a, b, allow_list_keys=False, force=True): + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Default: False. + force (bool): If True, existing key-value will be replaced by new given. + If False, existing key-value will not be updated. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + + # value of a and b are both list, only support list of type dict with key ``type``, + # You can directly replace args for list or add new list element, + # but if you add new list element, the list does not guarantee the order, + # it is only suitable for the case where the order of the list is not concerned. + >>> Config._merge_a_into_b( + ... {'k': [dict(a=2), dict(c=3)]}, {'k': [dict(a=1), dict(b=2)]}, True) + {'k': [dict(a=2), dict(b=2), dict(c=3)]} + + # force is False + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2, b=2)), dict(obj=dict(a=1))), True, force=False) + {'obj': {'a': 1, b=2}} + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f'Index {k} exceeds the length of list {b}') + b[k] = Config._merge_a_into_b( + v, b[k], allow_list_keys, force=force) + elif allow_list_keys and isinstance(v, list) and k in b: + if not isinstance(b[k], list): + raise ValueError( + f'type mismatch {type(v)} and {type(b[k])} between a and b for key {k}' + ) + _is_dict_with_type = True + for list_i in b[k] + v: + if not isinstance(list_i, dict) or 'type' not in list_i: + if k not in b or force: + b[k] = v + _is_dict_with_type = False + if _is_dict_with_type: + res_list = [] + added_index_bk, added_index_v = [], [] + for i, b_li in enumerate(b[k]): + for j, a_lj in enumerate(v): + if a_lj['type'] == b_li['type']: + res_list.append( + Config._merge_a_into_b( + a_lj, + b_li, + allow_list_keys, + force=force)) + added_index_v.append(j) + added_index_bk.append(i) + break + rest_bk = [ + b[k][i] for i in range(len(b[k])) + if i not in added_index_bk + ] + rest_v = [ + v[i] for i in range(len(v)) if i not in added_index_v + ] + rest = rest_bk + rest_v + res_list += [ + Config._merge_a_into_b( + rest[i], {}, allow_list_keys, force=force) + for i in range(len(rest)) + ] + b[k] = res_list + elif isinstance(v, + dict) and k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f'{k}={v} in child config cannot inherit from base ' + f'because {k} is a dict in the child config but is of ' + f'type {type(b[k])} in base config. You may set ' + f'`{DELETE_KEY}=True` to ignore the base config') + b[k] = Config._merge_a_into_b( + v, b[k], allow_list_keys, force=force) + else: + if k not in b or force: + b[k] = v + return b + + def to_dict(self) -> Dict: + """ Convert Config object to python dict + """ + return self._cfg_dict.to_dict() + + def to_args(self, parse_fn, use_hyphen=True): + """ Convert config obj to args using parse_fn + + Args: + parse_fn: a function object, which takes args as input, + such as ['--foo', 'FOO'] and return parsed args, an + example is given as follows + including literal blocks:: + def parse_fn(args): + parser = argparse.ArgumentParser(prog='PROG') + parser.add_argument('-x') + parser.add_argument('--foo') + return parser.parse_args(args) + use_hyphen (bool, optional): if set true, hyphen in keyname + will be converted to underscore + Return: + args: arg object parsed by argparse.ArgumentParser + """ + args = [] + for k, v in self._cfg_dict.items(): + arg_name = f'--{k}' + if use_hyphen: + arg_name = arg_name.replace('_', '-') + if isinstance(v, bool) and v: + args.append(arg_name) + elif isinstance(v, (int, str, float)): + args.append(arg_name) + args.append(str(v)) + elif isinstance(v, list): + args.append(arg_name) + assert isinstance(v, (int, str, float, bool)), 'Element type in list ' \ + f'is expected to be either int,str,float, but got type {v[0]}' + args.append(str(v)) + else: + raise ValueError( + 'type in config file which supported to be ' + 'converted to args should be either bool, ' + f'int, str, float or list of them but got type {v}') + + return parse_fn(args) + + +def check_config(cfg: Union[str, ConfigDict], is_training=False): + """ Check whether configuration file is valid, If anything wrong, exception will be raised. + + Args: + cfg (str or ConfigDict): Config file path or config object. + is_training: indicate if checking training related elements + """ + + if isinstance(cfg, str): + cfg = Config.from_file(cfg) + + def check_attr(attr_name, msg=''): + assert hasattr(cfg, attr_name), f'Attribute {attr_name} is missing from ' \ + f'{ModelFile.CONFIGURATION}. {msg}' + + check_attr(ConfigFields.framework) + check_attr(ConfigFields.task) + check_attr(ConfigFields.pipeline) + + if is_training: + check_attr(ConfigFields.model) + check_attr(ConfigFields.train) + check_attr(ConfigFields.preprocessor) + check_attr(ConfigFields.evaluation) + + +def use_task_specific_params(model, task): + """Update config with summarization specific params.""" + task_specific_params = model.config.task_specific_params + + if task_specific_params is not None: + pars = task_specific_params.get(task, {}) + logger.info(f'using task specific params for {task}: {pars}') + model.config.update(pars) + + +class JSONIteratorEncoder(json.JSONEncoder): + """Implement this method in order that supporting arbitrary iterators, it returns + a serializable object for ``obj``, or calls the base implementation + (to raise a ``TypeError``). + + """ + + def default(self, obj): + if isinstance(obj, FunctionType): + return None + try: + iterable = iter(obj) + except TypeError: + pass + else: + return list(iterable) + return json.JSONEncoder.default(self, obj) diff --git a/modelscope/utils/config_ds.py b/modelscope/utils/config_ds.py new file mode 100644 index 00000000..fce823c4 --- /dev/null +++ b/modelscope/utils/config_ds.py @@ -0,0 +1,26 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from pathlib import Path + +# Cache location +from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT + +DEFAULT_CACHE_HOME = Path.home().joinpath('.cache') +CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME) +DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope', 'hub') +MS_CACHE_HOME = os.path.expanduser( + os.getenv('MS_CACHE_HOME', DEFAULT_MS_CACHE_HOME)) + +DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'datasets') +MS_DATASETS_CACHE = Path( + os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE)) + +DOWNLOADED_DATASETS_DIR = 'downloads' +DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE, + DOWNLOADED_DATASETS_DIR) +DOWNLOADED_DATASETS_PATH = Path( + os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) + +HUB_DATASET_ENDPOINT = os.environ.get('HUB_DATASET_ENDPOINT', + DEFAULT_MODELSCOPE_DATA_ENDPOINT) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py new file mode 100644 index 00000000..f0a97dbd --- /dev/null +++ b/modelscope/utils/constant.py @@ -0,0 +1,380 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import enum + + +class Fields(object): + """ Names for different application fields + """ + cv = 'cv' + nlp = 'nlp' + audio = 'audio' + multi_modal = 'multi-modal' + science = 'science' + + +class CVTasks(object): + # ocr + ocr_detection = 'ocr-detection' + ocr_recognition = 'ocr-recognition' + + # human face body related + animal_recognition = 'animal-recognition' + face_detection = 'face-detection' + card_detection = 'card-detection' + face_recognition = 'face-recognition' + facial_expression_recognition = 'facial-expression-recognition' + face_2d_keypoints = 'face-2d-keypoints' + human_detection = 'human-detection' + human_object_interaction = 'human-object-interaction' + face_image_generation = 'face-image-generation' + body_2d_keypoints = 'body-2d-keypoints' + body_3d_keypoints = 'body-3d-keypoints' + hand_2d_keypoints = 'hand-2d-keypoints' + general_recognition = 'general-recognition' + human_wholebody_keypoint = 'human-wholebody-keypoint' + + image_classification = 'image-classification' + image_multilabel_classification = 'image-multilabel-classification' + image_classification_imagenet = 'image-classification-imagenet' + image_classification_dailylife = 'image-classification-dailylife' + + image_object_detection = 'image-object-detection' + video_object_detection = 'video-object-detection' + + image_segmentation = 'image-segmentation' + semantic_segmentation = 'semantic-segmentation' + portrait_matting = 'portrait-matting' + text_driven_segmentation = 'text-driven-segmentation' + shop_segmentation = 'shop-segmentation' + hand_static = 'hand-static' + face_human_hand_detection = 'face-human-hand-detection' + face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' + + crowd_counting = 'crowd-counting' + + # image editing + skin_retouching = 'skin-retouching' + image_super_resolution = 'image-super-resolution' + image_colorization = 'image-colorization' + image_color_enhancement = 'image-color-enhancement' + image_denoising = 'image-denoising' + image_portrait_enhancement = 'image-portrait-enhancement' + image_inpainting = 'image-inpainting' + + # image generation + image_to_image_translation = 'image-to-image-translation' + image_to_image_generation = 'image-to-image-generation' + image_style_transfer = 'image-style-transfer' + image_portrait_stylization = 'image-portrait-stylization' + image_body_reshaping = 'image-body-reshaping' + image_embedding = 'image-embedding' + + product_retrieval_embedding = 'product-retrieval-embedding' + + # video recognition + live_category = 'live-category' + action_recognition = 'action-recognition' + action_detection = 'action-detection' + video_category = 'video-category' + video_embedding = 'video-embedding' + virtual_try_on = 'virtual-try-on' + movie_scene_segmentation = 'movie-scene-segmentation' + + # video segmentation + referring_video_object_segmentation = 'referring-video-object-segmentation' + + # video editing + video_inpainting = 'video-inpainting' + + # reid and tracking + video_single_object_tracking = 'video-single-object-tracking' + video_summarization = 'video-summarization' + image_reid_person = 'image-reid-person' + + +class NLPTasks(object): + # nlp tasks + word_segmentation = 'word-segmentation' + part_of_speech = 'part-of-speech' + named_entity_recognition = 'named-entity-recognition' + nli = 'nli' + sentiment_classification = 'sentiment-classification' + sentiment_analysis = 'sentiment-analysis' + sentence_similarity = 'sentence-similarity' + text_classification = 'text-classification' + sentence_embedding = 'sentence-embedding' + text_ranking = 'text-ranking' + relation_extraction = 'relation-extraction' + zero_shot = 'zero-shot' + translation = 'translation' + token_classification = 'token-classification' + conversational = 'conversational' + text_generation = 'text-generation' + text2text_generation = 'text2text-generation' + task_oriented_conversation = 'task-oriented-conversation' + dialog_intent_prediction = 'dialog-intent-prediction' + dialog_state_tracking = 'dialog-state-tracking' + table_question_answering = 'table-question-answering' + fill_mask = 'fill-mask' + text_summarization = 'text-summarization' + question_answering = 'question-answering' + zero_shot_classification = 'zero-shot-classification' + backbone = 'backbone' + text_error_correction = 'text-error-correction' + faq_question_answering = 'faq-question-answering' + information_extraction = 'information-extraction' + document_segmentation = 'document-segmentation' + feature_extraction = 'feature-extraction' + + +class AudioTasks(object): + # audio tasks + auto_speech_recognition = 'auto-speech-recognition' + text_to_speech = 'text-to-speech' + speech_signal_process = 'speech-signal-process' + acoustic_echo_cancellation = 'acoustic-echo-cancellation' + acoustic_noise_suppression = 'acoustic-noise-suppression' + keyword_spotting = 'keyword-spotting' + + +class MultiModalTasks(object): + # multi-modal tasks + image_captioning = 'image-captioning' + visual_grounding = 'visual-grounding' + text_to_image_synthesis = 'text-to-image-synthesis' + multi_modal_embedding = 'multi-modal-embedding' + generative_multi_modal_embedding = 'generative-multi-modal-embedding' + multi_modal_similarity = 'multi-modal-similarity' + visual_question_answering = 'visual-question-answering' + visual_entailment = 'visual-entailment' + video_multi_modal_embedding = 'video-multi-modal-embedding' + image_text_retrieval = 'image-text-retrieval' + + +class ScienceTasks(object): + protein_structure = 'protein-structure' + + +class TasksIODescriptions(object): + image_to_image = 'image_to_image', + images_to_image = 'images_to_image', + image_to_text = 'image_to_text', + seed_to_image = 'seed_to_image', + text_to_speech = 'text_to_speech', + text_to_text = 'text_to_text', + speech_to_text = 'speech_to_text', + speech_to_speech = 'speech_to_speech' + speeches_to_speech = 'speeches_to_speech', + visual_grounding = 'visual_grounding', + visual_question_answering = 'visual_question_answering', + visual_entailment = 'visual_entailment', + generative_multi_modal_embedding = 'generative_multi_modal_embedding' + + +class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks): + """ Names for tasks supported by modelscope. + + Holds the standard task name to use for identifying different tasks. + This should be used to register models, pipelines, trainers. + """ + reverse_field_index = {} + + @staticmethod + def find_field_by_task(task_name): + if len(Tasks.reverse_field_index) == 0: + # Lazy init, not thread safe + field_dict = { + Fields.cv: [ + getattr(Tasks, attr) for attr in dir(CVTasks) + if not attr.startswith('__') + ], + Fields.nlp: [ + getattr(Tasks, attr) for attr in dir(NLPTasks) + if not attr.startswith('__') + ], + Fields.audio: [ + getattr(Tasks, attr) for attr in dir(AudioTasks) + if not attr.startswith('__') + ], + Fields.multi_modal: [ + getattr(Tasks, attr) for attr in dir(MultiModalTasks) + if not attr.startswith('__') + ], + Fields.science: [ + getattr(Tasks, attr) for attr in dir(ScienceTasks) + if not attr.startswith('__') + ], + } + + for field, tasks in field_dict.items(): + for task in tasks: + if task in Tasks.reverse_field_index: + raise ValueError(f'Duplicate task: {task}') + Tasks.reverse_field_index[task] = field + + return Tasks.reverse_field_index.get(task_name) + + +class InputFields(object): + """ Names for input data fields in the input data for pipelines + """ + img = 'img' + text = 'text' + audio = 'audio' + + +class Hubs(enum.Enum): + """ Source from which an entity (such as a Dataset or Model) is stored + """ + modelscope = 'modelscope' + huggingface = 'huggingface' + + +class DownloadMode(enum.Enum): + """ How to treat existing datasets + """ + REUSE_DATASET_IF_EXISTS = 'reuse_dataset_if_exists' + FORCE_REDOWNLOAD = 'force_redownload' + + +class DownloadChannel(enum.Enum): + """ Channels of datasets downloading for uv/pv counting. + """ + LOCAL = 'local' + DSW = 'dsw' + EAIS = 'eais' + + +class UploadMode(enum.Enum): + """ How to upload object to remote. + """ + # Upload all objects from local, existing remote objects may be overwritten. (Default) + OVERWRITE = 'overwrite' + # Upload local objects in append mode, skipping all existing remote objects. + APPEND = 'append' + + +class DatasetFormations(enum.Enum): + """ How a dataset is organized and interpreted + """ + # formation that is compatible with official huggingface dataset, which + # organizes whole dataset into one single (zip) file. + hf_compatible = 1 + # native modelscope formation that supports, among other things, + # multiple files in a dataset + native = 2 + + +DatasetMetaFormats = { + DatasetFormations.native: ['.json'], + DatasetFormations.hf_compatible: ['.py'], +} + + +class ModelFile(object): + CONFIGURATION = 'configuration.json' + README = 'README.md' + TF_SAVED_MODEL_FILE = 'saved_model.pb' + TF_GRAPH_FILE = 'tf_graph.pb' + TF_CHECKPOINT_FOLDER = 'tf_ckpts' + TF_CKPT_PREFIX = 'ckpt-' + TORCH_MODEL_FILE = 'pytorch_model.pt' + TORCH_MODEL_BIN_FILE = 'pytorch_model.bin' + VOCAB_FILE = 'vocab.txt' + ONNX_MODEL_FILE = 'model.onnx' + LABEL_MAPPING = 'label_mapping.json' + TRAIN_OUTPUT_DIR = 'output' + TS_MODEL_FILE = 'model.ts' + + +class ConfigFields(object): + """ First level keyword in configuration file + """ + framework = 'framework' + task = 'task' + pipeline = 'pipeline' + model = 'model' + dataset = 'dataset' + preprocessor = 'preprocessor' + train = 'train' + evaluation = 'evaluation' + postprocessor = 'postprocessor' + + +class ConfigKeys(object): + """Fixed keywords in configuration file""" + train = 'train' + val = 'val' + test = 'test' + + +class Requirements(object): + """Requirement names for each module + """ + protobuf = 'protobuf' + sentencepiece = 'sentencepiece' + sklearn = 'sklearn' + scipy = 'scipy' + timm = 'timm' + tokenizers = 'tokenizers' + tf = 'tf' + torch = 'torch' + + +class Frameworks(object): + tf = 'tensorflow' + torch = 'pytorch' + kaldi = 'kaldi' + + +DEFAULT_MODEL_REVISION = None +MASTER_MODEL_BRANCH = 'master' +DEFAULT_REPOSITORY_REVISION = 'master' +DEFAULT_DATASET_REVISION = 'master' +DEFAULT_DATASET_NAMESPACE = 'modelscope' + + +class ModeKeys: + TRAIN = 'train' + EVAL = 'eval' + INFERENCE = 'inference' + + +class LogKeys: + ITER = 'iter' + ITER_TIME = 'iter_time' + EPOCH = 'epoch' + LR = 'lr' # learning rate + MODE = 'mode' + DATA_LOAD_TIME = 'data_load_time' + ETA = 'eta' # estimated time of arrival + MEMORY = 'memory' + LOSS = 'loss' + + +class TrainerStages: + before_run = 'before_run' + before_train_epoch = 'before_train_epoch' + before_train_iter = 'before_train_iter' + after_train_iter = 'after_train_iter' + after_train_epoch = 'after_train_epoch' + before_val_epoch = 'before_val_epoch' + before_val_iter = 'before_val_iter' + after_val_iter = 'after_val_iter' + after_val_epoch = 'after_val_epoch' + after_run = 'after_run' + + +class ColorCodes: + MAGENTA = '\033[95m' + YELLOW = '\033[93m' + GREEN = '\033[92m' + RED = '\033[91m' + END = '\033[0m' + + +class Devices: + """device used for training and inference""" + cpu = 'cpu' + gpu = 'gpu' diff --git a/modelscope/utils/cv/__init__.py b/modelscope/utils/cv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py new file mode 100644 index 00000000..095c36ec --- /dev/null +++ b/modelscope/utils/cv/image_utils.py @@ -0,0 +1,441 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.image import load_image + + +def numpy_to_cv2img(img_array): + """to convert a np.array with shape(h, w) to cv2 img + + Args: + img_array (np.array): input data + + Returns: + cv2 img + """ + img_array = (img_array - img_array.min()) / ( + img_array.max() - img_array.min() + 1e-5) + img_array = (img_array * 255).astype(np.uint8) + img_array = cv2.applyColorMap(img_array, cv2.COLORMAP_JET) + return img_array + + +def draw_joints(image, np_kps, score, threshold=0.2): + lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14] + lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15] + lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16] + + lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1] + lst_left_ids_15 = [2, 3, 4, 8, 9, 10] + lst_right_ids_15 = [5, 6, 7, 11, 12, 13] + + if np_kps.shape[0] == 17: + lst_parent_ids = lst_parent_ids_17 + lst_left_ids = lst_left_ids_17 + lst_right_ids = lst_right_ids_17 + + elif np_kps.shape[0] == 15: + lst_parent_ids = lst_parent_ids_15 + lst_left_ids = lst_left_ids_15 + lst_right_ids = lst_right_ids_15 + + for i in range(len(lst_parent_ids)): + pid = lst_parent_ids[i] + if i == pid: + continue + + if (score[i] < threshold or score[1] < threshold): + continue + + if i in lst_left_ids and pid in lst_left_ids: + color = (0, 255, 0) + elif i in lst_right_ids and pid in lst_right_ids: + color = (255, 0, 0) + else: + color = (0, 255, 255) + + cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), + (int(np_kps[pid][0]), int(np_kps[pid, 1])), color, 3) + + for i in range(np_kps.shape[0]): + if score[i] < threshold: + continue + cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5, + (0, 0, 255), -1) + + +def draw_box(image, box): + cv2.rectangle(image, (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), (0, 0, 255), 2) + + +def realtime_object_detection_bbox_vis(image, bboxes): + for bbox in bboxes: + cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + (255, 0, 0), 2) + return image + + +def draw_keypoints(output, original_image): + poses = np.array(output[OutputKeys.KEYPOINTS]) + scores = np.array(output[OutputKeys.SCORES]) + boxes = np.array(output[OutputKeys.BOXES]) + assert len(poses) == len(scores) and len(poses) == len(boxes) + image = cv2.imread(original_image, -1) + for i in range(len(poses)): + draw_box(image, np.array(boxes[i])) + draw_joints(image, np.array(poses[i]), np.array(scores[i])) + return image + + +def draw_106face_keypoints(in_path, + keypoints, + boxes, + scale=4.0, + save_path=None): + face_contour_point_index = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 + ] + left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33] + right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42] + left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66] + right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75] + nose_bridge_point_index = [51, 52, 53, 54] + nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] + mouth_outer_point_index = [ + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84 + ] + mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96] + + img = cv2.imread(in_path) + + for i in range(len(boxes)): + draw_box(img, np.array(boxes[i])) + + image = cv2.resize(img, dsize=None, fx=scale, fy=scale) + + def draw_line(point_index, image, point): + for i in range(len(point_index) - 1): + cur_index = point_index[i] + next_index = point_index[i + 1] + cur_pt = (int(point[cur_index][0] * scale), + int(point[cur_index][1] * scale)) + next_pt = (int(point[next_index][0] * scale), + int(point[next_index][1] * scale)) + cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2) + + for i in range(len(keypoints)): + points = keypoints[i] + + draw_line(face_contour_point_index, image, points) + draw_line(left_eye_brow_point_index, image, points) + draw_line(right_eye_brow_point_index, image, points) + draw_line(left_eye_point_index, image, points) + draw_line(right_eye_point_index, image, points) + draw_line(nose_bridge_point_index, image, points) + draw_line(nose_contour_point_index, image, points) + draw_line(mouth_outer_point_index, image, points) + draw_line(mouth_inter_point_index, image, points) + + size = len(points) + for i in range(size): + x = int(points[i][0]) + y = int(points[i][1]) + cv2.putText(image, str(i), (int(x * scale), int(y * scale)), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0), + cv2.FILLED) + + if save_path is not None: + cv2.imwrite(save_path, image) + + return image + + +def draw_face_detection_no_lm_result(img_path, detection_result): + bboxes = np.array(detection_result[OutputKeys.BOXES]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + x1, y1, x2, y2 = bbox + score = scores[i] + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + print(f'Found {len(scores)} faces') + return img + + +def draw_facial_expression_result(img_path, facial_expression_result): + scores = facial_expression_result[OutputKeys.SCORES] + labels = facial_expression_result[OutputKeys.LABELS] + label = labels[np.argmax(scores)] + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + cv2.putText( + img, + 'facial expression: {}'.format(label), (10, 10), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + print('facial expression: {}'.format(label)) + return img + + +def draw_face_detection_result(img_path, detection_result): + bboxes = np.array(detection_result[OutputKeys.BOXES]) + kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + kps = kpss[i].reshape(-1, 2).astype(np.int32) + score = scores[i] + x1, y1, x2, y2 = bbox + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2) + for kp in kps: + cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + print(f'Found {len(scores)} faces') + return img + + +def draw_card_detection_result(img_path, detection_result): + + def warp_img(src_img, kps, ratio): + short_size = 500 + if ratio > 1: + obj_h = short_size + obj_w = int(obj_h * ratio) + else: + obj_w = short_size + obj_h = int(obj_w / ratio) + input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]]) + output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0], + [obj_w - 1, obj_h - 1]]) + M = cv2.getPerspectiveTransform(input_pts, output_pts) + obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h)) + return obj_img + + bboxes = np.array(detection_result[OutputKeys.BOXES]) + kpss = np.array(detection_result[OutputKeys.KEYPOINTS]) + scores = np.array(detection_result[OutputKeys.SCORES]) + img_list = [] + ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] + img = cv2.imread(img_path) + img_list += [img] + assert img is not None, f"Can't read img: {img_path}" + for i in range(len(scores)): + bbox = bboxes[i].astype(np.int32) + kps = kpss[i].reshape(-1, 2).astype(np.int32) + _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2 + _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2 + ratio = 1.59 if _w >= _h else 1 / 1.59 + card_img = warp_img(img, kps, ratio) + img_list += [card_img] + score = scores[i] + x1, y1, x2, y2 = bbox + cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4) + for k, kp in enumerate(kps): + cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10) + cv2.putText( + img, + f'{score:.2f}', (x1, y2), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + return img_list + + +def created_boxed_image(image_in, box): + image = load_image(image_in) + img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), + (0, 255, 0), 3) + return img + + +def show_video_tracking_result(video_in_path, bboxes, video_save_path): + cap = cv2.VideoCapture(video_in_path) + for i in range(len(bboxes)): + box = bboxes[i] + success, frame = cap.read() + if success is False: + raise Exception(video_in_path, + ' can not be correctly decoded by OpenCV.') + if i == 0: + size = (frame.shape[1], frame.shape[0]) + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + video_writer = cv2.VideoWriter(video_save_path, fourcc, + cap.get(cv2.CAP_PROP_FPS), size, + True) + cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), + 5) + video_writer.write(frame) + video_writer.release + cap.release() + + +def show_video_object_detection_result(video_in_path, bboxes_list, labels_list, + video_save_path): + + PALETTE = { + 'person': [128, 0, 0], + 'bicycle': [128, 128, 0], + 'car': [64, 0, 0], + 'motorcycle': [0, 128, 128], + 'bus': [64, 128, 0], + 'truck': [192, 128, 0], + 'traffic light': [64, 0, 128], + 'stop sign': [192, 0, 128], + } + from tqdm import tqdm + import math + cap = cv2.VideoCapture(video_in_path) + with tqdm(total=len(bboxes_list)) as pbar: + pbar.set_description( + 'Writing results to video: {}'.format(video_save_path)) + for i in range(len(bboxes_list)): + bboxes = bboxes_list[i].astype(int) + labels = labels_list[i] + success, frame = cap.read() + if success is False: + raise Exception(video_in_path, + ' can not be correctly decoded by OpenCV.') + if i == 0: + size = (frame.shape[1], frame.shape[0]) + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + video_writer = cv2.VideoWriter(video_save_path, fourcc, + cap.get(cv2.CAP_PROP_FPS), size, + True) + + FONT_SCALE = 1e-3 # Adjust for larger font size in all images + THICKNESS_SCALE = 1e-3 # Adjust for larger thickness in all images + TEXT_Y_OFFSET_SCALE = 1e-2 # Adjust for larger Y-offset of text and bounding box + H, W, _ = frame.shape + zeros_mask = np.zeros((frame.shape)).astype(np.uint8) + for bbox, l in zip(bboxes, labels): + cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + PALETTE[l], 1) + cv2.putText( + frame, + l, (bbox[0], bbox[1] - int(TEXT_Y_OFFSET_SCALE * H)), + fontFace=cv2.FONT_HERSHEY_TRIPLEX, + fontScale=min(H, W) * FONT_SCALE, + thickness=math.ceil(min(H, W) * THICKNESS_SCALE), + color=PALETTE[l]) + zeros_mask = cv2.rectangle( + zeros_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), + color=PALETTE[l], + thickness=-1) + + frame = cv2.addWeighted(frame, 1., zeros_mask, .65, 0) + video_writer.write(frame) + pbar.update(1) + video_writer.release + cap.release() + + +def panoptic_seg_masks_to_image(masks): + draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) + from mmdet.core.visualization.palette import get_palette + mask_palette = get_palette('coco', 133) + + from mmdet.core.visualization.image import _get_bias_color + taken_colors = set([0, 0, 0]) + for i, mask in enumerate(masks): + color_mask = mask_palette[i] + while tuple(color_mask) in taken_colors: + color_mask = _get_bias_color(color_mask) + taken_colors.add(tuple(color_mask)) + + mask = mask.astype(bool) + draw_img[mask] = color_mask + + return draw_img + + +def semantic_seg_masks_to_image(masks): + from mmdet.core.visualization.palette import get_palette + mask_palette = get_palette('coco', 133) + + draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) + + for i, mask in enumerate(masks): + color_mask = mask_palette[i] + mask = mask.astype(bool) + draw_img[mask] = color_mask + return draw_img + + +def show_video_summarization_result(video_in_path, result, video_save_path): + frame_indexes = result[OutputKeys.OUTPUT] + cap = cv2.VideoCapture(video_in_path) + for i in range(len(frame_indexes)): + idx = frame_indexes[i] + success, frame = cap.read() + if success is False: + raise Exception(video_in_path, + ' can not be correctly decoded by OpenCV.') + if i == 0: + size = (frame.shape[1], frame.shape[0]) + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + video_writer = cv2.VideoWriter(video_save_path, fourcc, + cap.get(cv2.CAP_PROP_FPS), size, + True) + if idx == 1: + video_writer.write(frame) + video_writer.release() + cap.release() + + +def show_image_object_detection_auto_result(img_path, + detection_result, + save_path=None): + scores = detection_result[OutputKeys.SCORES] + labels = detection_result[OutputKeys.LABELS] + bboxes = detection_result[OutputKeys.BOXES] + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + + for (score, label, box) in zip(scores, labels, bboxes): + cv2.rectangle(img, (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), (0, 0, 255), 2) + cv2.putText( + img, + f'{score:.2f}', (int(box[0]), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + cv2.putText( + img, + label, (int((box[0] + box[2]) * 0.5), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + + if save_path is not None: + cv2.imwrite(save_path, img) + return img diff --git a/modelscope/utils/data_utils.py b/modelscope/utils/data_utils.py new file mode 100644 index 00000000..2bc88e19 --- /dev/null +++ b/modelscope/utils/data_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections.abc import Mapping + +import torch + + +def to_device(batch, device, non_blocking=False): + """Put the data to the target cuda device just before the forward function. + Args: + batch: The batch data out of the dataloader. + device: (str | torch.device): The target device for the data. + + Returns: The data to the target device. + + """ + if isinstance(batch, dict) or isinstance(batch, Mapping): + return type(batch)({k: to_device(v, device) for k, v in batch.items()}) + elif isinstance(batch, (tuple, list)): + return type(batch)(to_device(v, device) for v in batch) + elif isinstance(batch, torch.Tensor): + return batch.to(device, non_blocking=non_blocking) + else: + return batch diff --git a/modelscope/utils/demo_utils.py b/modelscope/utils/demo_utils.py new file mode 100644 index 00000000..e57b3348 --- /dev/null +++ b/modelscope/utils/demo_utils.py @@ -0,0 +1,274 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io + +import cv2 +import json + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks, TasksIODescriptions +from modelscope.utils.service_utils import NumpyEncoder + +TASKS_INPUT_TEMPLATES = { + # vision tasks + Tasks.image_portrait_stylization: TasksIODescriptions.image_to_image, + Tasks.portrait_matting: TasksIODescriptions.image_to_image, + Tasks.skin_retouching: TasksIODescriptions.image_to_image, + Tasks.image_captioning: TasksIODescriptions.image_to_text, + Tasks.image_denoising: TasksIODescriptions.image_to_image, + Tasks.image_portrait_enhancement: TasksIODescriptions.image_to_image, + Tasks.image_super_resolution: TasksIODescriptions.image_to_image, + Tasks.image_colorization: TasksIODescriptions.image_to_image, + Tasks.image_color_enhancement: TasksIODescriptions.image_to_image, + Tasks.face_image_generation: TasksIODescriptions.seed_to_image, + Tasks.image_style_transfer: TasksIODescriptions.images_to_image, + Tasks.image_segmentation: TasksIODescriptions.image_to_text, + Tasks.image_object_detection: TasksIODescriptions.image_to_text, + + # not tested + Tasks.image_classification: TasksIODescriptions.image_to_text, + Tasks.ocr_detection: TasksIODescriptions.image_to_text, + Tasks.ocr_recognition: TasksIODescriptions.image_to_text, + Tasks.body_2d_keypoints: TasksIODescriptions.image_to_text, + + # nlp tasks + Tasks.text_classification: TasksIODescriptions.text_to_text, + Tasks.text_generation: TasksIODescriptions.text_to_text, + Tasks.word_segmentation: TasksIODescriptions.text_to_text, + Tasks.text_error_correction: TasksIODescriptions.text_to_text, + Tasks.named_entity_recognition: TasksIODescriptions.text_to_text, + Tasks.sentiment_classification: TasksIODescriptions.text_to_text, + + # audio tasks + Tasks.text_to_speech: TasksIODescriptions.text_to_speech, + Tasks.auto_speech_recognition: TasksIODescriptions.speech_to_text, + Tasks.keyword_spotting: TasksIODescriptions.speech_to_text, + Tasks.acoustic_noise_suppression: TasksIODescriptions.speech_to_speech, + Tasks.acoustic_echo_cancellation: TasksIODescriptions.speeches_to_speech, + + # multi-modal + Tasks.visual_grounding: TasksIODescriptions.visual_grounding, + Tasks.visual_question_answering: + TasksIODescriptions.visual_question_answering, + Tasks.visual_entailment: TasksIODescriptions.visual_entailment, + Tasks.generative_multi_modal_embedding: + TasksIODescriptions.generative_multi_modal_embedding, + + # new tasks + Tasks.virtual_try_on: TasksIODescriptions.images_to_image, + + # TODO(lingcai.wl): support more tasks and implement corresponding example +} + +INPUT_EXAMPLES = { + # Must align with task schema defined in the Widget section of model card= + # cv + TasksIODescriptions.image_to_image: { + 'inputs': [ + 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + ], + 'urlPaths': { + 'outUrls': [{ + 'outputKey': OutputKeys.OUTPUT_IMG, + 'fileType': 'png' + }] + } + }, + TasksIODescriptions.images_to_image: { + 'inputs': [ + 'https://modelscope.oss-cn-beijing.aliyuncs.com/demo/image-style-transfer/style_transfer_content.jpg', + 'https://modelscope.oss-cn-beijing.aliyuncs.com/demo/image-style-transfer/style_transfer_style.jpg' + ], + 'urlPaths': { + 'outUrls': [{ + 'outputKey': OutputKeys.OUTPUT_IMG, + 'fileType': 'png' + }] + } + }, + TasksIODescriptions.image_to_text: { + 'inputs': [ + 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + ], + 'urlPaths': {} + }, + # nlp + TasksIODescriptions.text_to_text: { + 'inputs': ['test'], + 'urlPaths': {} + }, + + # audio + TasksIODescriptions.speech_to_text: { + 'inputs': [ + 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' + ], + 'urlPaths': {} + }, + TasksIODescriptions.text_to_speech: { + 'inputs': ['北京今天天气怎么样'], + 'urlPaths': { + 'outUrls': [{ + 'outputKey': OutputKeys.OUTPUT_PCM, + 'fileType': 'pcm' + }] + } + }, + TasksIODescriptions.speeches_to_speech: { + 'inputs': [ + 'http://225252-file.oss-cn-hangzhou-zmf.aliyuncs.com/maas_demo/nearend_mic.wav', + 'http://225252-file.oss-cn-hangzhou-zmf.aliyuncs.com/maas_demo/nearend_speech.wav' + ], + 'urlPaths': { + 'outUrls': [{ + 'outputKey': OutputKeys.OUTPUT_PCM, + 'fileType': 'pcm' + }] + } + }, + TasksIODescriptions.speech_to_speech: { + 'inputs': [ + 'http://225252-file.oss-cn-hangzhou-zmf.aliyuncs.com/maas_demo/speech_with_noise.wav' + ], + 'urlPaths': { + 'outUrls': [{ + 'outputKey': OutputKeys.OUTPUT_PCM, + 'fileType': 'pcm' + }] + } + }, + + # multi modal + TasksIODescriptions.visual_grounding: { + 'task': + Tasks.visual_grounding, + 'inputs': [ + 'http://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/maas/visual-grounding/visual_grounding.png', + 'a blue turtle-like pokemon with round head' + ], + 'urlPaths': { + 'inUrls': [{ + 'name': 'image' + }, { + 'name': 'text' + }] + } + }, + TasksIODescriptions.visual_question_answering: { + 'task': + Tasks.visual_question_answering, + 'inputs': [ + 'http://225252-file.oss-cn-hangzhou-zmf.aliyuncs.com/maas_demo/visual_question_answering.png', + 'what is grown on the plant?' + ], + 'urlPaths': { + 'inUrls': [{ + 'name': 'image' + }, { + 'name': 'text' + }], + 'outUrls': [{ + 'outputKey': 'text' + }] + } + }, + TasksIODescriptions.visual_entailment: { + 'task': + Tasks.visual_entailment, + 'inputs': [ + 'http://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/maas/visual-entailment/visual_entailment.jpg', + 'there are two birds.', 'test' + ], + 'urlPaths': { + 'inUrls': [{ + 'name': 'image' + }, { + 'name': 'text' + }], + 'outUrls': [{}] + } + }, + TasksIODescriptions.generative_multi_modal_embedding: { + 'task': + Tasks.generative_multi_modal_embedding, + 'inputs': [ + 'http://clip-multimodal.oss-cn-beijing.aliyuncs.com/lingchen/demo/dogs.jpg', + 'dogs playing in the grass' + ], + 'urlPaths': { + 'inUrls': [{ + 'name': 'image' + }, { + 'name': 'text' + }], + 'outUrls': [{}] + } + }, +} + + +class DemoCompatibilityCheck(object): + + def compatibility_check(self): + if self.task not in TASKS_INPUT_TEMPLATES: + print('task is not supported in demo service so far') + return False + if TASKS_INPUT_TEMPLATES[self.task] not in INPUT_EXAMPLES: + print('no example input for this task') + return False + + print('testing demo: ', self.task, self.model_id) + test_pipline = pipeline(self.task, self.model_id) + req = INPUT_EXAMPLES[TASKS_INPUT_TEMPLATES[self.task]] + inputs = preprocess(req) + params = req.get('parameters', {}) + # modelscope inference + if params != {}: + output = test_pipline(inputs, **params) + else: + output = test_pipline(inputs) + json.dumps(output, cls=NumpyEncoder) + result = postprocess(req, output) + print(result) + return True + + +def preprocess(req): + in_urls = req.get('urlPaths').get('inUrls') + if len(req['inputs']) == 1: + inputs = req['inputs'][0] + else: + inputs = tuple(req['inputs']) + if in_urls is None or len(in_urls) == 0: + return inputs + + inputs_dict = {} + for i, in_url in enumerate(in_urls): + input_name = in_url.get('name') + if input_name is None or input_name == '': + return inputs + inputs_dict[input_name] = req['inputs'][i] + return inputs_dict + + +def postprocess(req, resp): + out_urls = req.get('urlPaths').get('outUrls') + if out_urls is None or len(out_urls) == 0: + return resp + new_resp = resp + if isinstance(resp, str): + new_resp = json.loads(resp) + for out_url in out_urls: + output_key = out_url['outputKey'] + file_type = out_url['fileType'] + new_resp.get(output_key) + if file_type == 'png' or file_type == 'jpg': + content = new_resp.get(output_key) + _, img_encode = cv2.imencode('.' + file_type, content) + img_bytes = img_encode.tobytes() + return type(img_bytes) + else: + out_mem_file = io.BytesIO() + out_mem_file.write(new_resp.get(output_key)) + return type(out_mem_file) diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py new file mode 100644 index 00000000..83faa261 --- /dev/null +++ b/modelscope/utils/device.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from contextlib import contextmanager + +from modelscope.utils.constant import Devices, Frameworks +from modelscope.utils.import_utils import is_tf_available, is_torch_available +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def verify_device(device_name): + """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X. + + Args: + device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X + where X is the ordinal for gpu device. + + Return: + device info (tuple): device_type and device_id, if device_id is not set, will use 0 as default. + """ + err_msg = 'device should be either cpu, cuda, gpu, gpu:X or cuda:X where X is the ordinal for gpu device.' + assert device_name is not None and device_name != '', err_msg + device_name = device_name.lower() + eles = device_name.split(':') + assert len(eles) <= 2, err_msg + assert device_name is not None + assert eles[0] in ['cpu', 'cuda', 'gpu'], err_msg + device_type = eles[0] + device_id = None + if len(eles) > 1: + device_id = int(eles[1]) + if device_type == 'cuda': + device_type = Devices.gpu + if device_type == Devices.gpu and device_id is None: + device_id = 0 + return device_type, device_id + + +@contextmanager +def device_placement(framework, device_name='gpu:0'): + """ Device placement function, allow user to specify which device to place model or tensor + Args: + framework (str): tensorflow or pytorch. + device (str): gpu or cpu to use, if you want to specify certain gpu, + use gpu:$gpu_id or cuda:$gpu_id. + + Returns: + Context manager + + Examples: + + ```python + # Requests for using model on cuda:0 for gpu + with device_placement('pytorch', device='gpu:0'): + model = Model.from_pretrained(...) + ``` + """ + device_type, device_id = verify_device(device_name) + + if framework == Frameworks.tf: + import tensorflow as tf + if device_type == Devices.gpu and not tf.test.is_gpu_available(): + logger.debug( + 'tensorflow: cuda is not available, using cpu instead.') + device_type = Devices.cpu + if device_type == Devices.cpu: + with tf.device('/CPU:0'): + yield + else: + if device_type == Devices.gpu: + with tf.device(f'/device:gpu:{device_id}'): + yield + + elif framework == Frameworks.torch: + import torch + if device_type == Devices.gpu: + if torch.cuda.is_available(): + torch.cuda.set_device(f'cuda:{device_id}') + else: + logger.debug( + 'pytorch: cuda is not available, using cpu instead.') + yield + else: + yield + + +def create_device(device_name): + """ create torch device + + Args: + device_name (str): cpu, gpu, gpu:0, cuda:0 etc. + """ + import torch + device_type, device_id = verify_device(device_name) + use_cuda = False + if device_type == Devices.gpu: + use_cuda = True + if not torch.cuda.is_available(): + logger.info('cuda is not available, using cpu instead.') + use_cuda = False + + if use_cuda: + device = torch.device(f'cuda:{device_id}') + else: + device = torch.device('cpu') + + return device + + +def get_device(): + import torch + from torch import distributed as dist + if torch.cuda.is_available(): + if dist.is_available() and dist.is_initialized( + ) and 'LOCAL_RANK' in os.environ: + device_id = f"cuda:{os.environ['LOCAL_RANK']}" + else: + device_id = 'cuda:0' + else: + device_id = 'cpu' + return torch.device(device_id) diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py new file mode 100644 index 00000000..a894063c --- /dev/null +++ b/modelscope/utils/error.py @@ -0,0 +1,122 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +# docstyle-ignore +AUDIO_IMPORT_ERROR = """ +Audio model import failed: {0}, if you want to use audio releated function, please execute +`pip install modelscope[audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html` +""" + +# docstyle-ignore +PROTOBUF_IMPORT_ERROR = """ +{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and +follow the ones that match your environment. +""" + +# docstyle-ignore +SENTENCEPIECE_IMPORT_ERROR = """ +{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones +that match your environment. +""" + +# docstyle-ignore +SKLEARN_IMPORT_ERROR = """ +{0} requires the scikit-learn library but it was not found in your environment. You can install it with: +``` +pip install -U scikit-learn +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install -U scikit-learn +``` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_WARNING = """ +{0} requires the TensorFlow library but it was not found in your environment. +If you don't want to use them, please ignore this message +If you want to use them, please Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TIMM_IMPORT_ERROR = """ +{0} requires the timm library but it was not found in your environment. You can install it with pip: +`pip install timm` +""" + +# docstyle-ignore +TOKENIZERS_IMPORT_ERROR = """ +{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with: +``` +pip install tokenizers +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install tokenizers +``` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: +`pip install scipy` +""" + +# docstyle-ignore +OPENCV_IMPORT_ERROR = """ +{0} requires the opencv library but it was not found in your environment. You can install it with pip: +`pip install opencv-python` +""" + +PILLOW_IMPORT_ERROR = """ +{0} requires the Pillow library but it was not found in your environment. You can install it with pip: +`pip install Pillow` +""" + +GENERAL_IMPORT_ERROR = """ +{0} requires the REQ library but it was not found in your environment. You can install it with pip: +`pip install REQ` +""" + +DECORD_IMPORT_ERROR = """ +{0} requires the decord library but it was not found in your environment. You can install it with pip: +`pip install decord>=0.6.0` +""" + +# docstyle-ignore +DEEPSPEED_IMPORT_ERROR = """ +{0} requires the Deepspeed library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.deepspeed.ai/tutorials/advanced-install/ and follow the ones that match your environment. +""" + +# docstyle-ignore +FAIRSEQ_IMPORT_ERROR = """ +{0} requires the fairseq library but it was not found in your environment. +You can install it with pip on linux: +`pip install fairseq` +On windows, please checkout the instructions on the +installation page: https://github.com/facebookresearch/fairseq and follow the ones that match your environment. +""" + +# docstyle-ignore +FASTTEXT_IMPORT_ERROR = """ +{0} requires the fasttext library but it was not found in your environment. +You can install it with pip on linux or mac: +`pip install fasttext` +Or you can checkout the instructions on the +installation page: https://github.com/facebookresearch/fastText and follow the ones that match your environment. +""" diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py new file mode 100644 index 00000000..cf59dc57 --- /dev/null +++ b/modelscope/utils/file_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import inspect +import os +from pathlib import Path + + +# TODO: remove this api, unify to flattened args +def func_receive_dict_inputs(func): + """to decide if a func could recieve dict inputs or not + + Args: + func (class): the target function to be inspected + + Returns: + bool: if func only has one arg ``input`` or ``inputs``, return True, else return False + """ + full_args_spec = inspect.getfullargspec(func) + varargs = full_args_spec.varargs + varkw = full_args_spec.varkw + if not (varargs is None and varkw is None): + return False + + args = [] if not full_args_spec.args else full_args_spec.args + args.pop(0) if (args and args[0] in ['self', 'cls']) else args + + if len(args) == 1 and args[0] in ['input', 'inputs']: + return True + + return False + + +def get_default_cache_dir(): + """ + default base dir: '~/.cache/modelscope' + """ + default_cache_dir = Path.home().joinpath('.cache', 'modelscope') + return default_cache_dir + + +def read_file(path): + + with open(path, 'r') as f: + text = f.read() + return text diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py new file mode 100644 index 00000000..105b3ffa --- /dev/null +++ b/modelscope/utils/hub.py @@ -0,0 +1,157 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +from typing import List, Optional, Union + +from requests import HTTPError + +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.file_download import model_file_download +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ModelFile) +from .logger import get_logger + +logger = get_logger(__name__) + + +def create_model_if_not_exist( + api, + model_id: str, + chinese_name: str, + visibility: Optional[int] = ModelVisibility.PUBLIC, + license: Optional[str] = Licenses.APACHE_V2, + revision: Optional[str] = DEFAULT_MODEL_REVISION): + exists = True + try: + api.get_model(model_id=model_id, revision=revision) + except HTTPError: + exists = False + if exists: + print(f'model {model_id} already exists, skip creation.') + return False + else: + api.create_model( + model_id=model_id, + visibility=visibility, + license=license, + chinese_name=chinese_name, + ) + print(f'model {model_id} successfully created.') + return True + + +def read_config(model_id_or_path: str, + revision: Optional[str] = DEFAULT_MODEL_REVISION): + """ Read config from hub or local path + + Args: + model_id_or_path (str): Model repo name or local directory path. + revision: revision of the model when getting from the hub + Return: + config (:obj:`Config`): config object + """ + if not os.path.exists(model_id_or_path): + local_path = model_file_download( + model_id_or_path, ModelFile.CONFIGURATION, revision=revision) + else: + local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) + + return Config.from_file(local_path) + + +def auto_load(model: Union[str, List[str]]): + if isinstance(model, str): + if not osp.exists(model): + model = snapshot_download(model) + else: + model = [ + snapshot_download(m) if not osp.exists(m) else m for m in model + ] + + return model + + +def get_model_type(model_dir): + """Get the model type from the configuration. + + This method will try to get the model type from 'model.backbone.type', + 'model.type' or 'model.model_type' field in the configuration.json file. If + this file does not exist, the method will try to get the 'model_type' field + from the config.json. + + Args: + model_dir: The local model dir to use. @return: The model type + string, returns None if nothing is found. + """ + try: + configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION) + config_file = osp.join(model_dir, 'config.json') + if osp.isfile(configuration_file): + cfg = Config.from_file(configuration_file) + if hasattr(cfg.model, 'backbone'): + return cfg.model.backbone.type + elif hasattr(cfg.model, + 'model_type') and not hasattr(cfg.model, 'type'): + return cfg.model.model_type + else: + return cfg.model.type + elif osp.isfile(config_file): + cfg = Config.from_file(config_file) + return cfg.model_type if hasattr(cfg, 'model_type') else None + except Exception as e: + logger.error(f'parse config file failed with error: {e}') + + +def parse_label_mapping(model_dir): + """Get the label mapping from the model dir. + + This method will do: + 1. Try to read label-id mapping from the label_mapping.json + 2. Try to read label-id mapping from the configuration.json + 3. Try to read label-id mapping from the config.json + + Args: + model_dir: The local model dir to use. + + Returns: + The label2id mapping if found. + """ + import json + import os + label2id = None + label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING) + if os.path.exists(label_path): + with open(label_path) as f: + label_mapping = json.load(f) + label2id = {name: idx for name, idx in label_mapping.items()} + + if label2id is None: + config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) + config = Config.from_file(config_path) + if hasattr(config, ConfigFields.model) and hasattr( + config[ConfigFields.model], 'label2id'): + label2id = config[ConfigFields.model].label2id + elif hasattr(config, ConfigFields.model) and hasattr( + config[ConfigFields.model], 'id2label'): + id2label = config[ConfigFields.model].id2label + label2id = {label: id for id, label in id2label.items()} + elif hasattr(config, ConfigFields.preprocessor) and hasattr( + config[ConfigFields.preprocessor], 'label2id'): + label2id = config[ConfigFields.preprocessor].label2id + elif hasattr(config, ConfigFields.preprocessor) and hasattr( + config[ConfigFields.preprocessor], 'id2label'): + id2label = config[ConfigFields.preprocessor].id2label + label2id = {label: id for id, label in id2label.items()} + + config_path = os.path.join(model_dir, 'config.json') + if label2id is None and os.path.exists(config_path): + config = Config.from_file(config_path) + if hasattr(config, 'label2id'): + label2id = config.label2id + elif hasattr(config, 'id2label'): + id2label = config.id2label + label2id = {label: id for id, label in id2label.items()} + return label2id diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py new file mode 100644 index 00000000..5db5ea98 --- /dev/null +++ b/modelscope/utils/import_utils.py @@ -0,0 +1,447 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/transformers. +import ast +import functools +import importlib +import os +import os.path as osp +import sys +from collections import OrderedDict +from functools import wraps +from importlib import import_module +from itertools import chain +from pathlib import Path +from types import ModuleType +from typing import Any + +from packaging import version + +from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY, + load_index) +from modelscope.utils.error import * # noqa +from modelscope.utils.logger import get_logger + +logger = get_logger(__name__) + +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +logger = get_logger() + +AST_INDEX = None + + +def import_modules_from_file(py_file: str): + """ Import module from a certrain file + + Args: + py_file: path to a python file to be imported + + Return: + + """ + dirname, basefile = os.path.split(py_file) + if dirname == '': + dirname = Path.cwd() + module_name = osp.splitext(basefile)[0] + sys.path.insert(0, dirname) + validate_py_syntax(py_file) + mod = import_module(module_name) + sys.path.pop(0) + return module_name, mod + + +def is_method_overridden(method, base_class, derived_class): + """Check if a method of base class is overridden in derived class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + derived_class (type | Any): the class or instance of the derived class. + """ + assert isinstance(base_class, type), \ + "base_class doesn't accept instance, Please pass class instead." + + if not isinstance(derived_class, type): + derived_class = derived_class.__class__ + + base_method = getattr(base_class, method) + derived_method = getattr(derived_class, method) + return derived_method != base_method + + +def has_method(obj: object, method: str) -> bool: + """Check whether the object has a method. + + Args: + method (str): The method name to check. + obj (object): The object to check. + + Returns: + bool: True if the object has the method else False. + """ + return hasattr(obj, method) and callable(getattr(obj, method)) + + +def import_modules(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. + allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError( + f'custom_imports must be a list but got type {type(imports)}') + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError( + f'{imp} is of type {type(imp)} and cannot be imported.') + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + logger.warning(f'{imp} failed to import and is ignored.') + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def validate_py_syntax(filename): + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + +# following code borrows implementation from huggingface/transformers +ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'}) +USE_TF = os.environ.get('USE_TF', 'AUTO').upper() +USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper() + +_torch_version = 'N/A' +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec('torch') is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version('torch') + logger.info(f'PyTorch version {_torch_version} Found.') + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info('Disabling PyTorch because USE_TF is set') + _torch_available = False + +_timm_available = importlib.util.find_spec('timm') is not None +try: + _timm_version = importlib_metadata.version('timm') + logger.debug(f'Successfully imported timm version {_timm_version}') +except importlib_metadata.PackageNotFoundError: + _timm_available = False + +_tf_version = 'N/A' +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec('tensorflow') is not None + if _tf_available: + candidates = ( + 'tensorflow', + 'tensorflow-cpu', + 'tensorflow-gpu', + 'tf-nightly', + 'tf-nightly-cpu', + 'tf-nightly-gpu', + 'intel-tensorflow', + 'intel-tensorflow-avx512', + 'tensorflow-rocm', + 'tensorflow-macos', + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse('2'): + pass + else: + logger.info(f'TensorFlow version {_tf_version} Found.') +else: + logger.info('Disabling Tensorflow because USE_TORCH is set') + _tf_available = False + + +def is_scipy_available(): + return importlib.util.find_spec('scipy') is not None + + +def is_sklearn_available(): + if importlib.util.find_spec('sklearn') is None: + return False + return is_scipy_available() and importlib.util.find_spec('sklearn.metrics') + + +def is_sentencepiece_available(): + return importlib.util.find_spec('sentencepiece') is not None + + +def is_protobuf_available(): + if importlib.util.find_spec('google') is None: + return False + return importlib.util.find_spec('google.protobuf') is not None + + +def is_tokenizers_available(): + return importlib.util.find_spec('tokenizers') is not None + + +def is_timm_available(): + return _timm_available + + +def is_torch_available(): + return _torch_available + + +def is_torch_cuda_available(): + if is_torch_available(): + import torch + + return torch.cuda.is_available() + else: + return False + + +def is_tf_available(): + return _tf_available + + +def is_opencv_available(): + return importlib.util.find_spec('cv2') is not None + + +def is_pillow_available(): + return importlib.util.find_spec('PIL.Image') is not None + + +def _is_package_available_fn(pkg_name): + return importlib.util.find_spec(pkg_name) is not None + + +def is_package_available(pkg_name): + return functools.partial(_is_package_available_fn, pkg_name) + + +def is_espnet_available(pkg_name): + return importlib.util.find_spec('espnet2') is not None \ + and importlib.util.find_spec('espnet') + + +REQUIREMENTS_MAAPING = OrderedDict([ + ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), + ('sentencepiece', (is_sentencepiece_available, + SENTENCEPIECE_IMPORT_ERROR)), + ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)), + ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), + ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), + ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), + ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), + ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)), + ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)), + ('espnet2', (is_espnet_available, + GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))), + ('espnet', (is_espnet_available, + GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))), + ('easyasr', (is_package_available('easyasr'), AUDIO_IMPORT_ERROR)), + ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)), + ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)), + ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)), + ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)), + ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)), +]) + +SYSTEM_PACKAGE = set(['os', 'sys', 'typing']) + + +def requires(obj, requirements): + if not isinstance(requirements, (list, tuple)): + requirements = [requirements] + if isinstance(obj, str): + name = obj + else: + name = obj.__name__ if hasattr(obj, + '__name__') else obj.__class__.__name__ + checks = [] + for req in requirements: + if req == '' or req in SYSTEM_PACKAGE: + continue + if req in REQUIREMENTS_MAAPING: + check = REQUIREMENTS_MAAPING[req] + else: + check_fn = is_package_available(req) + err_msg = GENERAL_IMPORT_ERROR.replace('REQ', req) + check = (check_fn, err_msg) + checks.append(check) + + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError(''.join(failed)) + + +def torch_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_torch_available(): + return func(*args, **kwargs) + else: + raise ImportError(f'Method `{func.__name__}` requires PyTorch.') + + return wrapper + + +def tf_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_tf_available(): + return func(*args, **kwargs) + else: + raise ImportError(f'Method `{func.__name__}` requires TF.') + + return wrapper + + +class LazyImportModule(ModuleType): + AST_INDEX = None + if AST_INDEX is None: + AST_INDEX = load_index() + + def __init__(self, + name, + module_file, + import_structure, + module_spec=None, + extra_objects=None, + try_to_pre_import=False): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list( + chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + if try_to_pre_import: + self._try_to_import() + + def _try_to_import(self): + for sub_module in self._class_to_module.keys(): + try: + getattr(self, sub_module) + except Exception as e: + logger.warn( + f'pre load module {sub_module} error, please check {e}') + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError( + f'module {self.__name__} has no attribute {name}') + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + # check requirements before module import + module_name_full = self.__name__ + '.' + module_name + if module_name_full in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]: + requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][ + module_name_full] + requires(module_name_full, requirements) + return importlib.import_module('.' + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f'Failed to import {self.__name__}.{module_name} because of the following error ' + f'(look up to see its traceback):\n{e}') from e + + def __reduce__(self): + return self.__class__, (self._name, self.__file__, + self._import_structure) + + @staticmethod + def import_module(signature): + """ import a lazy import module using signature + + Args: + signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name) + """ + if signature in LazyImportModule.AST_INDEX[INDEX_KEY]: + mod_index = LazyImportModule.AST_INDEX[INDEX_KEY][signature] + module_name = mod_index[MODULE_KEY] + if module_name in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]: + requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][ + module_name] + requires(module_name, requirements) + importlib.import_module(module_name) + else: + logger.warning(f'{signature} not found in ast index file') diff --git a/modelscope/utils/json_utils.py b/modelscope/utils/json_utils.py new file mode 100644 index 00000000..c5bece23 --- /dev/null +++ b/modelscope/utils/json_utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import json +import numpy as np + + +class EnhancedEncoder(json.JSONEncoder): + """ Enhanced json encoder for not supported types """ + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) diff --git a/modelscope/utils/logger.py b/modelscope/utils/logger.py new file mode 100644 index 00000000..994bd719 --- /dev/null +++ b/modelscope/utils/logger.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +from typing import Optional + +init_loggers = {} + + +def get_logger(log_file: Optional[str] = None, + log_level: int = logging.INFO, + file_mode: str = 'w'): + """ Get logging logger + + Args: + log_file: Log filename, if specified, file handler will be added to + logger + log_level: Logging level. + file_mode: Specifies the mode to open the file, if filename is + specified (if filemode is unspecified, it defaults to 'w'). + """ + logger_name = __name__.split('.')[0] + logger = logging.getLogger(logger_name) + + if logger_name in init_loggers: + return logger + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + # TODO @wenmeng.zwm add logger setting for distributed environment + if log_file is not None: + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + logger.setLevel(log_level) + init_loggers[logger_name] = True + + return logger diff --git a/modelscope/utils/model_tag.py b/modelscope/utils/model_tag.py new file mode 100644 index 00000000..7065e8f3 --- /dev/null +++ b/modelscope/utils/model_tag.py @@ -0,0 +1,184 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +import os + +import json +import requests + +from modelscope.version import __version__ + + +# 打标 +class ModelTag(object): + _URL = os.environ.get('MODEL_TAG_URL', None) + + # 模型测试结果 + BATCH_COMMIT_RESULT_URL = f'{_URL}/batchCommitResult' + # 测试阶段完成 + BATCH_REFRESH_STAGE_URL = f'{_URL}/batchRefreshStage' + # query_model_stage + QUERY_MODEL_STAGE_URL = f'{_URL}/queryModelStage' + + HEADER = {'Content-Type': 'application/json'} + + # 检测结果 + MODEL_SKIP = 0 + MODEL_FAIL = 1 + MODEL_PASS = 2 + + class ItemResult(object): + + def __init__(self): + self.result = 0 + self.name = '' + self.info = '' + + def to_json(self): + return { + 'name': self.name, + 'result': self.result, + 'info': self.info + } + + def __init__(self): + self.job_name = '' + self.job_id = '' + self.model = '' + self.sdk_version = '' + self.image_version = '' + self.domain = '' + self.task = '' + self.source = '' + self.stage = '' + # ItemResult list + self.item_result = [] + + # 发送请求 + def _post_request(self, url, param): + try: + logging.info(url + ' query: ' + + str(json.dumps(param, ensure_ascii=False))) + res = requests.post( + url=url, + headers=self.HEADER, + data=json.dumps(param, ensure_ascii=False).encode('utf8')) + if res.status_code == 200: + logging.info(f'{url} post结果: ' + res.text) + res_json = json.loads(res.text) + if int(res_json['errorCode']) == 200: + return res_json['content'] + else: + logging.error(res.text) + else: + logging.error(res.text) + except Exception as e: + logging.error(e) + + return None + + # 提交模型测试结果 + def batch_commit_result(self): + try: + param = { + 'sdkVersion': + self.sdk_version, + 'imageVersion': + self.image_version, + 'source': + self.source, + 'jobName': + self.job_name, + 'jobId': + self.job_id, + 'modelList': [{ + 'model': self.model, + 'domain': self.domain, + 'task': self.task, + 'itemResult': self.item_result + }] + } + return self._post_request(self.BATCH_COMMIT_RESULT_URL, param) + + except Exception as e: + logging.error(e) + + return + + # 测试阶段完成 + def batch_refresh_stage(self): + try: + param = { + 'sdkVersion': + self.sdk_version, + 'imageVersion': + self.image_version, + 'source': + self.source, + 'stage': + self.stage, + 'modelList': [{ + 'model': self.model, + 'domain': self.domain, + 'task': self.task + }] + } + return self._post_request(self.BATCH_REFRESH_STAGE_URL, param) + + except Exception as e: + logging.error(e) + + return + + # 查询模型某个阶段的最新测试结果(只返回单个结果 + def query_model_stage(self): + try: + param = { + 'sdkVersion': self.sdk_version, + 'model': self.model, + 'stage': self.stage, + 'imageVersion': self.image_version + } + return self._post_request(self.QUERY_MODEL_STAGE_URL, param) + + except Exception as e: + logging.error(e) + + return None + + # 提交模型UT测试结果 + """ + model_tag = ModelTag() + model_tag.model = "XXX" + model_tag.sdk_version = "0.3.7" + model_tag.domain = "nlp" + model_tag.task = "word-segmentation" + item = model_tag.ItemResult() + item.result = model_tag.MODEL_PASS + item.name = "ALL" + item.info = "" + model_tag.item_result.append(item.to_json()) + """ + + def commit_ut_result(self): + if self._URL is not None and self._URL != '': + self.job_name = 'UT' + self.source = 'dev' + self.stage = 'integration' + + self.batch_commit_result() + self.batch_refresh_stage() + + +def commit_model_ut_result(model_name, ut_result): + model_tag = ModelTag() + model_tag.model = model_name.replace('damo/', '') + model_tag.sdk_version = __version__ + # model_tag.domain = "" + # model_tag.task = "" + item = model_tag.ItemResult() + item.result = ut_result + item.name = 'ALL' + item.info = '' + model_tag.item_result.append(item.to_json()) + model_tag.commit_ut_result() diff --git a/modelscope/utils/multi_modal/fp16/__init__.py b/modelscope/utils/multi_modal/fp16/__init__.py new file mode 100644 index 00000000..81250858 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fp16 import FP16_Module, FP16_Optimizer diff --git a/modelscope/utils/multi_modal/fp16/fp16.py b/modelscope/utils/multi_modal/fp16/fp16.py new file mode 100755 index 00000000..37a80e65 --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16.py @@ -0,0 +1,655 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stable version of apex FP16 Optimizer""" +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from .fp16util import (master_params_to_model_params, + model_grads_to_master_grads) +from .loss_scaler import DynamicLossScaler, LossScaler + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + + return conversion_helper(val, half_conversion) + + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +class FP16_Module(nn.Module): + + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage static or dynamic loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + # optional arg to control dynamic loss scaling behavior + # dynamic_loss_args={'scale_window' : 500}) + # Usually, dynamic_loss_args is not necessary. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. # noqa + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. # noqa + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. # noqa + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. # noqa + verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. # noqa + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` + should still work as intended. + """ + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=False): + if not torch.cuda.is_available: + raise SystemError('Cannot use fp16 without CUDA.') + + self.verbose = verbose + + self.optimizer = init_optimizer + # init_state_dict sets up an alternative way to cast per-param state tensors. + # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. + # init_state_dict = init_optimizer.state_dict() + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + self.maybe_print( + 'FP16_Optimizer processing param group {}:'.format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_fp16_params_this_group = [] + for i, param in enumerate(param_group['params']): + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.HalfTensor with {}' + .format(param.size())) + fp16_params_this_group.append(param) + master_param = param.detach().clone().float() + master_param.requires_grad = True + # Copythe model parallel flag. + master_param.model_parallel = param.model_parallel + param_group['params'][i] = master_param + fp32_from_fp16_params_this_group.append(master_param) + # Reset existing state dict key to the new master param. + # We still need to recast per-param state tensors, if any, to FP32. + if param in self.optimizer.state: + self.optimizer.state[ + master_param] = self.optimizer.state.pop(param) + elif param.type() == 'torch.cuda.FloatTensor': + self.maybe_print( + 'FP16_Optimizer received torch.cuda.FloatTensor with {}' + .format(param.size())) + fp32_params_this_group.append(param) + param_group['params'][i] = param + else: + raise TypeError( + 'Wrapped parameters must be either ' + 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' + 'Received {}'.format(param.type())) + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + # alternative way to cast per-param state tensors: + # self.optimizer.load_state_dict(init_state_dict) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + self.clip_grad_norm = nn.utils.clip_grad.clip_grad_norm_ + + def maybe_print(self, msg): + if self.verbose: + print(msg) + + def __getstate__(self): + raise RuntimeError( + 'FP16_Optimizer should be serialized using state_dict().') + + def __setstate__(self, state): + raise RuntimeError( + 'FP16_Optimizer should be deserialized using load_state_dict().') + + def zero_grad(self, set_grads_to_None=False): + """ + Zero fp32 and fp16 parameter grads. + """ + # In principle, only the .grad attributes of the model params need to be zeroed, + # because gradients are copied into the FP32 master params. However, we zero + # all gradients owned by the optimizer, just to be safe: + for group in self.optimizer.param_groups: + for p in group['params']: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + # Zero fp16 gradients owned by the model: + for fp16_group in self.fp16_groups: + for param in fp16_group: + if set_grads_to_None: + param.grad = None + else: + if param.grad is not None: + param.grad.detach_( + ) # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def _model_params_to_master_params(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp32_from_fp16_group, fp16_group) + + # To consider: Integrate distributed with this wrapper by registering a hook on each variable + # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip( + self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + for group in self.optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.mul_(1. / self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). # noqa + """ + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return self.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict[ + 'first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict[ + 'first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, + state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, :attr:`step` should be called after + ``fp16_optimizer_obj.backward(loss)``. + :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. + However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. # noqa + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + # loss.backward() becomes: + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self.maybe_print( + 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' + .format(scale, self.loss_scale)) + return + + if closure is not None: + retval = self._step_with_closure(closure) + else: + retval = self.optimizer.step() + + self._master_params_to_model_params() + + return retval + + def _step_with_closure(self, closure): + + def wrapped_closure(): + # helpful for debugging + # print("Calling wrapped_closure, first_closure_call_this_step = {}" + # .format(self.first_closure_call_this_step)) + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, so all wrapped_closure needs to do is call + # closure() and return the loss. + temp_loss = closure() + while (self.overflow): + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + self.maybe_print( + 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' + 'reducing to {}'.format(scale, self.loss_scale)) + temp_loss = closure() + return temp_loss + + retval = self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + return retval + + def backward(self, loss, update_master_grads=True, retain_graph=False): + """ + :attr:`backward` performs the following conceptual steps: + + 1. fp32_loss = loss.float() (see first Note below) + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). # noqa + 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. # noqa + 5. Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. warning:: + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. # noqa + retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). # noqa + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ + # To consider: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. :attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return # noqa + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + if self.overflow: + print( + 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' + 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' + ) + return None + else: + # The optimizer owns only references to master params. + master_grads_data = [] + for param_group in self.optimizer.param_groups: + master_grads_this_group = [] + for param in param_group['params']: + if param.grad is not None: + master_grads_this_group.append(param.grad.data) + else: + master_grads_this_group.append(None) + master_grads_data.append(master_grads_this_group) + return master_grads_data + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/modelscope/utils/multi_modal/fp16/fp16util.py b/modelscope/utils/multi_modal/fp16/fp16util.py new file mode 100644 index 00000000..29595a6c --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/fp16util.py @@ -0,0 +1,216 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable + + +class tofp16(nn.Module): + """ + Utility module that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + """ + Utility function for network_to_half(). + + Retained for legacy purposes. + """ + if isinstance( + module, + torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + + Retained for legacy purposes. It is recommended to use FP16Model. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def convert_module(module, dtype): + """ + Converts a module's immediate parameters and buffers to dtype. + """ + for param in module.parameters(recurse=False): + if param is not None: + if param.data.dtype.is_floating_point: + param.data = param.data.to(dtype=dtype) + if param._grad is not None and param._grad.data.dtype.is_floating_point: + param._grad.data = param._grad.data.to(dtype=dtype) + + for buf in module.buffers(recurse=False): + if buf is not None and buf.data.dtype.is_floating_point: + buf.data = buf.data.to(dtype=dtype) + + +def convert_network(network, dtype): + """ + Converts a network's parameters and buffers to dtype. + """ + for module in network.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm + ) and module.affine is True: + continue + convert_module(module, dtype) + return network + + +class FP16Model(nn.Module): + """ + Convert model to half precision in a batchnorm-safe way. + """ + + def __init__(self, network): + super(FP16Model, self).__init__() + self.network = convert_network(network, dtype=torch.half) + + def forward(self, *inputs): + inputs = tuple(t.half() for t in inputs) + return self.network(*inputs) + + +def backwards_debug_hook(grad): + raise RuntimeError( + 'master_params recieved a gradient in the backward pass!') + + +def prep_param_lists(model, flat_master=False): + """ + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. # noqa + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. # noqa + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. # noqa + + .. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [ + param for param in model.parameters() if param.requires_grad + ] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors( + [param.data for param in model_params]).float() + except: # noqa + print( + 'Error in prep_param_lists: model may contain a mixture of parameters ' + 'of different types. Use flat_master=False, or use F16_Optimizer.' + ) + raise + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [ + param.clone().float().detach() for param in model_params + ] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, + master_params, + flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. # noqa + """ + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable( + master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, + master_params, + flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. # noqa + """ + if flat_master: + for model, master in zip( + model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) + + +# Backward compatibility fixes + + +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/modelscope/utils/multi_modal/fp16/loss_scaler.py b/modelscope/utils/multi_modal/fp16/loss_scaler.py new file mode 100755 index 00000000..fc55a4ed --- /dev/null +++ b/modelscope/utils/multi_modal/fp16/loss_scaler.py @@ -0,0 +1,237 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + + Args: + scale (float, optional, default=1.0): The loss scale. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. # noqa + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. # noqa + """ + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan( + p.grad.data): + return True + + return False + + def has_overflow(self, params): + overflow = self.has_overflow_serial(params) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if 'value cannot be converted' not in instance.args[0]: + raise + return True + else: + if cpu_sum == float( + 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + + if not hasattr(self, 'min_scale'): + self.min_scale = 1 + if not hasattr(self, 'delayed_shift'): + self.delayed_shift = 1 + if not hasattr(self, 'cur_hysteresis'): + self.cur_hysteresis = 1 + if not hasattr(self, 'consecutive_hysteresis'): + self.consecutive_hysteresis = True + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + self.cur_scale = max(self.cur_scale / self.scale_factor, + self.min_scale) + else: + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter + - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss * self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + +""" diff --git a/modelscope/utils/nlp/__init__.py b/modelscope/utils/nlp/__init__.py new file mode 100644 index 00000000..62c0b888 --- /dev/null +++ b/modelscope/utils/nlp/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .utils import import_external_nltk_data + +else: + _import_structure = { + 'utils': ['import_external_nltk_data'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/utils/nlp/distributed.py b/modelscope/utils/nlp/distributed.py new file mode 100755 index 00000000..53332c0f --- /dev/null +++ b/modelscope/utils/nlp/distributed.py @@ -0,0 +1,133 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.distributed as dist +from megatron import mpu +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Variable +from torch.nn.modules import Module + +from modelscope.utils.torch_utils import init_dist + + +def initialize_distributed(rank, mpu, world_size, model_parallel_size, + master_ip, master_port): + """Initialize torch.distributed.""" + # Manually set the device ids. + device = rank % torch.cuda.device_count() + torch.cuda.set_device(device) + # Call the init process + init_method = 'tcp://' + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend='nccl', + world_size=world_size, + rank=rank, + init_method=init_method) + # Set the model-parallel communicators. + mpu.initialize_model_parallel(model_parallel_size) + + +def normal_init_method(mean, std): + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +def scaled_init_method(mean, std, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = std / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + + +class DistributedDataParallel(Module): + + def __init__(self, module): + super(DistributedDataParallel, self).__init__() + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.module = module + self.data_parallel_group = mpu.get_data_parallel_group() + src_rank = mpu.get_model_parallel_rank() + for p in self.module.parameters(): + if torch.is_tensor(p): + dist.broadcast(p, src_rank, group=self.data_parallel_group) + + def allreduce_params(reduce_after=True, + no_scale=False, + fp32_allreduce=False): + if (self.needs_reduction): + self.needs_reduction = False + buckets = {} + for name, param in self.module.named_parameters(): + if param.requires_grad and param.grad is not None: + tp = (param.data.type()) + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if self.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print( + 'WARNING: gloo dist backend for half parameters may be extremely slow.', + 'It is recommended to use the NCCL backend in this case.' + ) + self.warn_on_half = False + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + if fp32_allreduce: + coalesced = coalesced.float() + if not no_scale and not reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + dist.all_reduce(coalesced, group=self.data_parallel_group) + torch.cuda.synchronize() + if not no_scale and reduce_after: + coalesced /= dist.get_world_size( + group=self.data_parallel_group) + for buf, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + self.hook_handles = [] + self.hooks = [] + for param in list(self.module.parameters()): + + def allreduce_hook(*unused): + Variable._execution_engine.queue_callback(allreduce_params) + + self.allreduce_params = allreduce_params + + def forward(self, *inputs, **kwargs): + self.needs_reduction = True + return self.module(*inputs, **kwargs) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + sd = self.module.state_dict(destination, prefix, keep_vars) + + return sd + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/modelscope/utils/nlp/load_checkpoint.py b/modelscope/utils/nlp/load_checkpoint.py new file mode 100755 index 00000000..6534e18d --- /dev/null +++ b/modelscope/utils/nlp/load_checkpoint.py @@ -0,0 +1,117 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + + +def load_checkpoint(model, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True): + r"""Load training checkpoint + + Arguments: + load_dir: Required. Directory to load the checkpoint from + tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step. + load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and + checkpoint match. + load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. + Ex. ADAM's momentum and variance + load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. + Return: + load_path: Path of the loaded checkpoint. None if loading the checkpoint failed + client_state: State dictionary used for loading required training states in the client code. + """ + + load_path, client_states = _load_checkpoint( + model, + load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states) + + if load_optimizer_states: + if model.zero_optimization() and load_path is not None: + model._load_zero_checkpoint( + load_dir, tag, load_optimizer_states=load_optimizer_states) + + return load_path, client_states + + +def _get_ckpt_name(mpu, checkpoints_path, tag): + mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() + ckpt_name = os.path.join( + checkpoints_path, str(tag), + 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt') + return ckpt_name + + +def pre_load(mpu, load_dir, tag=''): + load_path = _get_ckpt_name(mpu, load_dir, tag) + checkpoint = torch.load( + load_path, map_location=lambda storage, loc: storage) + return checkpoint['module'] + + +def _load_checkpoint(model, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True): + + load_path = model._get_ckpt_name(load_dir, tag) + + if not os.path.exists(load_path): + return None, None + + checkpoint = torch.load( + load_path, map_location=lambda storage, loc: storage) + + model.load_module_state_dict( + state_dict=checkpoint['module'], strict=load_module_strict) + if not model.zero_optimization() and load_optimizer_states: + if model.fp16_enabled(): + model.optimizer.load_state_dict( + checkpoint['optimizer'], + load_optimizer_states=load_optimizer_states) + elif load_optimizer_states: + model.optimizer.load_state_dict(checkpoint['optimizer']) + + if load_lr_scheduler_states and model.lr_scheduler is not None: + model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + model.csr_tensor_module_names = checkpoint['csr_tensor_module_names'] + model.global_steps = checkpoint['global_steps'] + model.global_samples = checkpoint.get( + 'global_samples', model.global_steps * model.train_batch_size()) + model.skipped_steps = checkpoint['skipped_steps'] + model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] + model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] + deepspeed_states = [ + 'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names', + 'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size' + ] + client_state = { + key: value + for key, value in checkpoint.items() if key not in deepspeed_states + } + + return load_path, client_state diff --git a/modelscope/utils/nlp/space/__init__.py b/modelscope/utils/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/nlp/space/args.py b/modelscope/utils/nlp/space/args.py new file mode 100644 index 00000000..c92401c5 --- /dev/null +++ b/modelscope/utils/nlp/space/args.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse + +import json + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Unsupported value encountered.') + + +class HParams(dict): + """ Hyper-parameters class + + Store hyper-parameters in training / infer / ... scripts. + """ + + def __getattr__(self, name): + if name in self.keys(): + return self[name] + for v in self.values(): + if isinstance(v, HParams): + if name in v: + return v[name] + raise AttributeError(f"'HParams' object has no attribute '{name}'") + + def __setattr__(self, name, value): + self[name] = value + + def save(self, filename): + with open(filename, 'w', encoding='utf-8') as fp: + json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) + + def load(self, filename): + with open(filename, 'r', encoding='utf-8') as fp: + params_dict = json.load(fp) + for k, v in params_dict.items(): + if isinstance(v, dict): + self[k].update(HParams(v)) + else: + self[k] = v + + +def parse_args(parser): + """ Parse hyper-parameters from cmdline. """ + parsed = parser.parse_args() + args = HParams() + optional_args = parser._action_groups[1] + for action in optional_args._group_actions[1:]: + arg_name = action.dest + args[arg_name] = getattr(parsed, arg_name) + for group in parser._action_groups[2:]: + group_args = HParams() + for action in group._group_actions: + arg_name = action.dest + group_args[arg_name] = getattr(parsed, arg_name) + if len(group_args) > 0: + args[group.title] = group_args + return args diff --git a/modelscope/utils/nlp/space/clean_dataset.py b/modelscope/utils/nlp/space/clean_dataset.py new file mode 100644 index 00000000..2c971b10 --- /dev/null +++ b/modelscope/utils/nlp/space/clean_dataset.py @@ -0,0 +1,335 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import re + +from . import ontology + + +def clean_text_split_dot(text): + text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', + text) # 'abc.xyz' -> 'abc . xyz' + text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' + return text + + +def clean_text(data_dir, text): + text = text.strip() + text = text.lower() + text = text.replace(u'’', "'") + text = text.replace(u'‘', "'") + text = text.replace(';', ',') + text = text.replace('"', ' ') + text = text.replace('/', ' and ') + text = text.replace("don't", "do n't") + text = clean_time(text) + baddata = { + r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4', + 'c.b. 1 7 d.y': 'cb17dy', + 'c.b.1 7 d.y': 'cb17dy', + 'c.b 25, 9 a.q': 'cb259aq', + 'isc.b 25, 9 a.q': 'is cb259aq', + 'c.b2, 1 u.f': 'cb21uf', + 'c.b 1,2 q.a': 'cb12qa', + '0-122-336-5664': '01223365664', + 'postcodecb21rs': 'postcode cb21rs', + r'i\.d': 'id', + ' i d ': 'id', + 'Telephone:01223358966': 'Telephone: 01223358966', + 'depature': 'departure', + 'depearting': 'departing', + '-type': ' type', + r'b[\s]?&[\s]?b': 'bed and breakfast', + 'b and b': 'bed and breakfast', + r'guesthouse[s]?': 'guest house', + r'swimmingpool[s]?': 'swimming pool', + "wo n\'t": 'will not', + " \'d ": ' would ', + " \'m ": ' am ', + " \'re' ": ' are ', + " \'ll' ": ' will ', + " \'ve ": ' have ', + r'^\'': '', + r'\'$': '', + } + for tmpl, good in baddata.items(): + text = re.sub(tmpl, good, text) + + text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', + text) # 'abc.xyz' -> 'abc . xyz' + text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' + + with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin: + for line in fin.readlines(): + fromx, tox = line.replace('\n', '').split('\t') + text = ' ' + text + ' ' + text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1] + + return text + + +def clean_time(utter): + utter = re.sub(r'(\d+) ([ap]\.?m)', lambda x: x.group(1) + x.group(2), + utter) # 9 am -> 9am + utter = re.sub(r'((?3'} + else: + nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} + if vector[:4] == [0, 0, 0, 0]: + report = '' + else: + num = vector.index(1) + report = domain + ': ' + nummap[num] + '; ' + + if vector[-2] == 0 and vector[-1] == 1: + report += 'booking: ok' + if vector[-2] == 1 and vector[-1] == 0: + report += 'booking: unable' + + return report + + def queryJsons(self, + domain, + constraints, + exactly_match=True, + return_name=False): + """Returns the list of entities for a given domain + based on the annotation of the belief state + constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'} + """ + # query the db + if domain == 'taxi': + return [{ + 'taxi_colors': + random.choice(self.dbs[domain]['taxi_colors']), + 'taxi_types': + random.choice(self.dbs[domain]['taxi_types']), + 'taxi_phone': [random.randint(1, 9) for _ in range(10)] + }] + if domain == 'police': + return self.dbs['police'] + if domain == 'hospital': + if constraints.get('department'): + for entry in self.dbs['hospital']: + if entry.get('department') == constraints.get( + 'department'): + return [entry] + else: + return [] + + valid_cons = False + for v in constraints.values(): + if v not in ['not mentioned', '']: + valid_cons = True + if not valid_cons: + return [] + + match_result = [] + + if 'name' in constraints: + for db_ent in self.dbs[domain]: + if 'name' in db_ent: + cons = constraints['name'] + dbn = db_ent['name'] + if cons == dbn: + db_ent = db_ent if not return_name else db_ent['name'] + match_result.append(db_ent) + return match_result + + for db_ent in self.dbs[domain]: + match = True + for s, v in constraints.items(): + if s == 'name': + continue + if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ + (domain == 'restaurant' and s in ['day', 'time']): + # These inform slots belong to "book info",which do not exist in DB + # "book" is according to the user goal,not DB + continue + + skip_case = { + "don't care": 1, + "do n't care": 1, + 'dont care': 1, + 'not mentioned': 1, + 'dontcare': 1, + '': 1 + } + if skip_case.get(v): + continue + + if s not in db_ent: + # logging.warning('Searching warning: slot %s not in %s db'%(s, domain)) + match = False + break + + # v = 'guesthouse' if v == 'guest house' else v + # v = 'swimmingpool' if v == 'swimming pool' else v + v = 'yes' if v == 'free' else v + + if s in ['arrive', 'leave']: + try: + h, m = v.split( + ':' + ) # raise error if time value is not xx:xx format + v = int(h) * 60 + int(m) + except Exception: + match = False + break + time = int(db_ent[s].split(':')[0]) * 60 + int( + db_ent[s].split(':')[1]) + if s == 'arrive' and v > time: + match = False + if s == 'leave' and v < time: + match = False + else: + if exactly_match and v != db_ent[s]: + match = False + break + elif v not in db_ent[s]: + match = False + break + + if match: + match_result.append(db_ent) + + if not return_name: + return match_result + else: + if domain == 'train': + match_result = [e['id'] for e in match_result] + else: + match_result = [e['name'] for e in match_result] + return match_result + + def querySQL(self, domain, constraints): + if not self.sql_dbs: + for dom in db_domains: + db = 'db/{}-dbase.db'.format(dom) + conn = sqlite3.connect(db) + c = conn.cursor() + self.sql_dbs[dom] = c + + sql_query = 'select * from {}'.format(domain) + + flag = True + for key, val in constraints.items(): + if val == '' \ + or val == 'dontcare' \ + or val == 'not mentioned' \ + or val == "don't care" \ + or val == 'dont care' \ + or val == "do n't care": + pass + else: + if flag: + sql_query += ' where ' + val2 = val.replace("'", "''") + # val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'" + else: + sql_query += r' ' + key + '=' + r"'" + val2 + r"'" + flag = False + else: + val2 = val.replace("'", "''") + # val2 = normalize(val2) + if key == 'leaveAt': + sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'" + elif key == 'arriveBy': + sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'" + else: + sql_query += r' and ' + key + '=' + r"'" + val2 + r"'" + + try: # "select * from attraction where name = 'queens college'" + print(sql_query) + return self.sql_dbs[domain].execute(sql_query).fetchall() + except Exception: + return [] # TODO test it + + +if __name__ == '__main__': + dbPATHs = { + 'attraction': 'db/attraction_db_processed.json', + 'hospital': 'db/hospital_db_processed.json', + 'hotel': 'db/hotel_db_processed.json', + 'police': 'db/police_db_processed.json', + 'restaurant': 'db/restaurant_db_processed.json', + 'taxi': 'db/taxi_db_processed.json', + 'train': 'db/train_db_processed.json', + } + db = MultiWozDB(dbPATHs) + while True: + constraints = {} + inp = input( + 'input belief state in fomat: domain-slot1=value1;slot2=value2...\n' + ) + domain, cons = inp.split('-') + for sv in cons.split(';'): + s, v = sv.split('=') + constraints[s] = v + # res = db.querySQL(domain, constraints) + res = db.queryJsons(domain, constraints, return_name=True) + report = [] + reidx = { + 'hotel': 8, + 'restaurant': 6, + 'attraction': 5, + 'train': 1, + } + print(constraints) + print(res) + print('count:', len(res), '\nnames:', report) diff --git a/modelscope/utils/nlp/space/ontology.py b/modelscope/utils/nlp/space/ontology.py new file mode 100644 index 00000000..c55d12e1 --- /dev/null +++ b/modelscope/utils/nlp/space/ontology.py @@ -0,0 +1,206 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +all_domains = [ + 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' +] +all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] +db_domains = ['restaurant', 'hotel', 'attraction', 'train'] +placeholder_tokens = [ + '', '', '', '', '', '', '', + '', '', '', '', '', '', + '', '', '' +] + +normlize_slot_names = { + 'car type': 'car', + 'entrance fee': 'price', + 'duration': 'time', + 'leaveat': 'leave', + 'arriveby': 'arrive', + 'trainid': 'id' +} + +requestable_slots = { + 'taxi': ['car', 'phone'], + 'police': ['postcode', 'address', 'phone'], + 'hospital': ['address', 'phone', 'postcode'], + 'hotel': [ + 'address', 'postcode', 'internet', 'phone', 'parking', 'type', + 'pricerange', 'stars', 'area', 'reference' + ], + 'attraction': + ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'], + 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'], + 'restaurant': [ + 'phone', 'postcode', 'address', 'pricerange', 'food', 'area', + 'reference' + ] +} +all_reqslot = [ + 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type', + 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave', + 'price', 'arrive', 'id' +] + +informable_slots = { + 'taxi': ['leave', 'destination', 'departure', 'arrive'], + 'police': [], + 'hospital': ['department'], + 'hotel': [ + 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', + 'area', 'stars', 'name' + ], + 'attraction': ['area', 'type', 'name'], + 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'], + 'restaurant': + ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people'] +} +all_infslot = [ + 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', + 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive', + 'department', 'food', 'time' +] + +all_slots = all_reqslot + [ + 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department' +] +get_slot = {} +for s in all_slots: + get_slot[s] = 1 + +# mapping slots in dialogue act to original goal slot names +da_abbr_to_slot_name = { + 'addr': 'address', + 'fee': 'price', + 'post': 'postcode', + 'ref': 'reference', + 'ticket': 'price', + 'depart': 'departure', + 'dest': 'destination', +} + +dialog_acts = { + 'restaurant': [ + 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', + 'offerbooked', 'nobook' + ], + 'hotel': [ + 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', + 'offerbooked', 'nobook' + ], + 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'], + 'train': + ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'], + 'taxi': ['inform', 'request'], + 'police': ['inform', 'request'], + 'hospital': ['inform', 'request'], + # 'booking': ['book', 'inform', 'nobook', 'request'], + 'general': ['bye', 'greet', 'reqmore', 'welcome'], +} +all_acts = [] +for acts in dialog_acts.values(): + for act in acts: + if act not in all_acts: + all_acts.append(act) + +dialog_act_params = { + 'inform': all_slots + ['choice', 'open'], + 'request': all_infslot + ['choice', 'price'], + 'nooffer': all_slots + ['choice'], + 'recommend': all_reqslot + ['choice', 'open'], + 'select': all_slots + ['choice'], + # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], + 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], + 'offerbook': all_slots + ['choice'], + 'offerbooked': all_slots + ['choice'], + 'reqmore': [], + 'welcome': [], + 'bye': [], + 'greet': [], +} + +dialog_act_all_slots = all_slots + ['choice', 'open'] + +# special slot tokens in belief span +# no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] +slot_name_to_slot_token = {} + +# eos tokens definition +eos_tokens = { + 'user': '', + 'user_delex': '', + 'resp': '', + 'resp_gen': '', + 'pv_resp': '', + 'bspn': '', + 'bspn_gen': '', + 'pv_bspn': '', + 'bsdx': '', + 'bsdx_gen': '', + 'pv_bsdx': '', + 'qspn': '', + 'qspn_gen': '', + 'pv_qspn': '', + 'aspn': '', + 'aspn_gen': '', + 'pv_aspn': '', + 'dspn': '', + 'dspn_gen': '', + 'pv_dspn': '' +} + +# sos tokens definition +sos_tokens = { + 'user': '', + 'user_delex': '', + 'resp': '', + 'resp_gen': '', + 'pv_resp': '', + 'bspn': '', + 'bspn_gen': '', + 'pv_bspn': '', + 'bsdx': '', + 'bsdx_gen': '', + 'pv_bsdx': '', + 'qspn': '', + 'qspn_gen': '', + 'pv_qspn': '', + 'aspn': '', + 'aspn_gen': '', + 'pv_aspn': '', + 'dspn': '', + 'dspn_gen': '', + 'pv_dspn': '' +} + +# db tokens definition +db_tokens = [ + '', '', '[book_nores]', '[book_fail]', '[book_success]', + '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' +] + + +# understand tokens definition +def get_understand_tokens(prompt_num_for_understand): + understand_tokens = [] + for i in range(prompt_num_for_understand): + understand_tokens.append(f'') + return understand_tokens + + +# policy tokens definition +def get_policy_tokens(prompt_num_for_policy): + policy_tokens = [] + for i in range(prompt_num_for_policy): + policy_tokens.append(f'') + return policy_tokens + + +# all special tokens definition +def get_special_tokens(other_tokens): + special_tokens = [ + '', '', '', '', '', '', + '', '', '', '', '', '', + '', '', '', '' + ] + db_tokens + other_tokens + return special_tokens diff --git a/modelscope/utils/nlp/space/scores.py b/modelscope/utils/nlp/space/scores.py new file mode 100644 index 00000000..eb6dd41c --- /dev/null +++ b/modelscope/utils/nlp/space/scores.py @@ -0,0 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + +def hierarchical_set_score(frame1, frame2): + # deal with empty frame + if not (frame1 and frame2): + return 0. + pass + return 0. diff --git a/modelscope/utils/nlp/space/utils.py b/modelscope/utils/nlp/space/utils.py new file mode 100644 index 00000000..56e67671 --- /dev/null +++ b/modelscope/utils/nlp/space/utils.py @@ -0,0 +1,194 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +from collections import OrderedDict + +import json +import numpy as np + +from modelscope.utils.logger import get_logger +from . import ontology + +logger = get_logger() + + +def max_lens(X): + lens = [len(X)] + while isinstance(X[0], list): + lens.append(max(map(len, X))) + X = [x for xs in X for x in xs] + return lens + + +def list2np(X: object, padding: object = 0, dtype: object = 'int64') -> object: + shape = max_lens(X) + ret = np.full(shape, padding, dtype=np.int32) + + if len(shape) == 1: + ret = np.array(X) + elif len(shape) == 2: + for i, x in enumerate(X): + ret[i, :len(x)] = np.array(x) + elif len(shape) == 3: + for i, xs in enumerate(X): + for j, x in enumerate(xs): + ret[i, j, :len(x)] = np.array(x) + return ret.astype(dtype) + + +def clean_replace(s, r, t, forward=True, backward=False): + + def clean_replace_single(s, r, t, forward, backward, sidx=0): + # idx = s[sidx:].find(r) + idx = s.find(r) + if idx == -1: + return s, -1 + idx_r = idx + len(r) + if backward: + while idx > 0 and s[idx - 1]: + idx -= 1 + elif idx > 0 and s[idx - 1] != ' ': + return s, -1 + + if forward: + while \ + idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + idx_r += 1 + elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): + return s, -1 + return s[:idx] + t + s[idx_r:], idx_r + + sidx = 0 + while sidx != -1: + s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) + return s + + +def py2np(list): + return np.array(list) + + +def write_dict(fn, dic): + with open(fn, 'w') as f: + json.dump(dic, f, indent=2) + + +def f1_score(label_list, pred_list): + tp = len([t for t in pred_list if t in label_list]) + fp = max(0, len(pred_list) - tp) + fn = max(0, len(label_list) - tp) + precision = tp / (tp + fp + 1e-10) + recall = tp / (tp + fn + 1e-10) + f1 = 2 * precision * recall / (precision + recall + 1e-10) + return f1 + + +class MultiWOZVocab(object): + + def __init__(self, vocab_size=0): + """ + vocab for multiwoz dataset + """ + self.vocab_size = vocab_size + self.vocab_size_oov = 0 # get after construction + self._idx2word = {} # word + oov + self._word2idx = {} # word + self._freq_dict = {} # word + oov + for w in [ + '[PAD]', '', '[UNK]', '', '', '', + '', '', '', '', '' + ]: + self._absolute_add_word(w) + + def _absolute_add_word(self, w): + idx = len(self._idx2word) + self._idx2word[idx] = w + self._word2idx[w] = idx + + def add_word(self, word): + if word not in self._freq_dict: + self._freq_dict[word] = 0 + self._freq_dict[word] += 1 + + def has_word(self, word): + return self._freq_dict.get(word) + + def _add_to_vocab(self, word): + if word not in self._word2idx: + idx = len(self._idx2word) + self._idx2word[idx] = word + self._word2idx[word] = idx + + def construct(self): + freq_dict_sorted = sorted( + self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) + logger.info('Vocabulary size including oov: %d' % + (len(freq_dict_sorted) + len(self._idx2word))) + if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: + logging.warning( + 'actual label set smaller than that configured: {}/{}'.format( + len(freq_dict_sorted) + len(self._idx2word), + self.vocab_size)) + for word in ontology.all_domains + ['general']: + word = '[' + word + ']' + self._add_to_vocab(word) + for word in ontology.all_acts: + word = '[' + word + ']' + self._add_to_vocab(word) + for word in ontology.all_slots: + self._add_to_vocab(word) + for word in freq_dict_sorted: + if word.startswith('[value_') and word.endswith(']'): + self._add_to_vocab(word) + for word in freq_dict_sorted: + self._add_to_vocab(word) + self.vocab_size_oov = len(self._idx2word) + + def load_vocab(self, vocab_path): + self._freq_dict = json.loads( + open(vocab_path + '.freq.json', 'r').read()) + self._word2idx = json.loads( + open(vocab_path + '.word2idx.json', 'r').read()) + self._idx2word = {} + for w, idx in self._word2idx.items(): + self._idx2word[idx] = w + self.vocab_size_oov = len(self._idx2word) + logger.info('vocab file loaded from "' + vocab_path + '"') + logger.info('Vocabulary size including oov: %d' % + (self.vocab_size_oov)) + + def save_vocab(self, vocab_path): + _freq_dict = OrderedDict( + sorted( + self._freq_dict.items(), key=lambda kv: kv[1], reverse=True)) + write_dict(vocab_path + '.word2idx.json', self._word2idx) + write_dict(vocab_path + '.freq.json', _freq_dict) + + def encode(self, word, include_oov=True): + if include_oov: + if self._word2idx.get(word, None) is None: + raise ValueError( + 'Unknown word: %s. Vocabulary should include oovs here.' + % word) + return self._word2idx[word] + else: + word = '' if word not in self._word2idx else word + return self._word2idx[word] + + def sentence_encode(self, word_list): + return [self.encode(_) for _ in word_list] + + def oov_idx_map(self, idx): + return 2 if idx > self.vocab_size else idx + + def sentence_oov_map(self, index_list): + return [self.oov_idx_map(_) for _ in index_list] + + def decode(self, idx, indicate_oov=False): + if not self._idx2word.get(idx): + raise ValueError( + 'Error idx: %d. Vocabulary should include oovs here.' % idx) + if not indicate_oov or idx < self.vocab_size: + return self._idx2word[idx] + else: + return self._idx2word[idx] + '(o)' diff --git a/modelscope/utils/nlp/space/utils_dst.py b/modelscope/utils/nlp/space/utils_dst.py new file mode 100644 index 00000000..6277172e --- /dev/null +++ b/modelscope/utils/nlp/space/utils_dst.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import List + +from modelscope.outputs import OutputKeys +from modelscope.pipelines.nlp import DialogStateTrackingPipeline + + +def tracking_and_print_dialog_states( + test_case, pipelines: List[DialogStateTrackingPipeline]): + import json + pipelines_len = len(pipelines) + history_states = [{}] + utter = {} + for step, item in enumerate(test_case): + utter.update(item) + result = pipelines[step % pipelines_len]({ + 'utter': + utter, + 'history_states': + history_states + }) + print(json.dumps(result)) + + history_states.extend([result[OutputKeys.OUTPUT], {}]) + + +def batch_to_device(batch, device): + batch_on_device = [] + for element in batch: + if isinstance(element, dict): + batch_on_device.append( + {k: v.to(device) + for k, v in element.items()}) + else: + batch_on_device.append(element.to(device)) + return tuple(batch_on_device) diff --git a/modelscope/utils/nlp/space_T_en/__init__.py b/modelscope/utils/nlp/space_T_en/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/nlp/space_T_en/utils.py b/modelscope/utils/nlp/space_T_en/utils.py new file mode 100644 index 00000000..d884c241 --- /dev/null +++ b/modelscope/utils/nlp/space_T_en/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import List + +from modelscope.outputs import OutputKeys +from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline + + +def text2sql_tracking_and_print_results( + test_case, pipelines: List[ConversationalTextToSqlPipeline]): + for p in pipelines: + last_sql, history = '', [] + for item in test_case['utterance']: + case = { + 'utterance': item, + 'history': history, + 'last_sql': last_sql, + 'database_id': test_case['database_id'], + 'local_db_path': test_case['local_db_path'] + } + results = p(case) + print({'question': item}) + print(results) + last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT] + history.append(item) diff --git a/modelscope/utils/nlp/utils.py b/modelscope/utils/nlp/utils.py new file mode 100644 index 00000000..13a21480 --- /dev/null +++ b/modelscope/utils/nlp/utils.py @@ -0,0 +1,20 @@ +import os.path as osp + + +def import_external_nltk_data(nltk_data_dir, package_name): + """import external nltk_data, and extract nltk zip package. + + Args: + nltk_data_dir (str): external nltk_data dir path, eg. /home/xx/nltk_data + package_name (str): nltk package name, eg. tokenizers/punkt + """ + import nltk + nltk.data.path.append(nltk_data_dir) + + filepath = osp.join(nltk_data_dir, package_name + '.zip') + zippath = osp.join(nltk_data_dir, package_name) + packagepath = osp.dirname(zippath) + if not osp.exists(zippath): + import zipfile + with zipfile.ZipFile(filepath) as zf: + zf.extractall(osp.join(packagepath)) diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py new file mode 100644 index 00000000..5284aa43 --- /dev/null +++ b/modelscope/utils/registry.py @@ -0,0 +1,214 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import importlib +import inspect +from typing import List, Tuple, Union + +from modelscope.utils.logger import get_logger + +TYPE_NAME = 'type' +default_group = 'default' +logger = get_logger() +AST_INDEX = None + + +class Registry(object): + """ Registry which support registering modules and group them by a keyname + + If group name is not provided, modules will be registered to default group. + """ + + def __init__(self, name: str): + self._name = name + self._modules = {default_group: {}} + + def __repr__(self): + format_str = self.__class__.__name__ + f' ({self._name})\n' + for group_name, group in self._modules.items(): + format_str += f'group_name={group_name}, '\ + f'modules={list(group.keys())}\n' + + return format_str + + @property + def name(self): + return self._name + + @property + def modules(self): + return self._modules + + def list(self): + """ logging the list of module in current registry + """ + for group_name, group in self._modules.items(): + logger.info(f'group_name={group_name}') + for m in group.keys(): + logger.info(f'\t{m}') + logger.info('') + + def get(self, module_key, group_key=default_group): + if group_key not in self._modules: + return None + else: + return self._modules[group_key].get(module_key, None) + + def _register_module(self, + group_key=default_group, + module_name=None, + module_cls=None, + force=False): + assert isinstance(group_key, + str), 'group_key is required and must be str' + + if group_key not in self._modules: + self._modules[group_key] = dict() + + if not inspect.isclass(module_cls): + raise TypeError(f'module is not a class type: {type(module_cls)}') + + if module_name is None: + module_name = module_cls.__name__ + + if module_name in self._modules[group_key] and not force: + raise KeyError(f'{module_name} is already registered in ' + f'{self._name}[{group_key}]') + self._modules[group_key][module_name] = module_cls + module_cls.group_key = group_key + + def register_module(self, + group_key: str = default_group, + module_name: str = None, + module_cls: type = None, + force=False): + """ Register module + + Example: + >>> models = Registry('models') + >>> @models.register_module('image-classification', 'SwinT') + >>> class SwinTransformer: + >>> pass + + >>> @models.register_module('SwinDefault') + >>> class SwinTransformerDefaultGroup: + >>> pass + + >>> class SwinTransformer2: + >>> pass + >>> MODELS.register_module('image-classification', + module_name='SwinT2', + module_cls=SwinTransformer2) + + Args: + group_key: Group name of which module will be registered, + default group name is 'default' + module_name: Module name + module_cls: Module class object + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + + """ + if not (module_name is None or isinstance(module_name, str)): + raise TypeError(f'module_name must be either of None, str,' + f'got {type(module_name)}') + if module_cls is not None: + self._register_module( + group_key=group_key, + module_name=module_name, + module_cls=module_cls, + force=force) + return module_cls + + # if module_cls is None, should return a decorator function + def _register(module_cls): + self._register_module( + group_key=group_key, + module_name=module_name, + module_cls=module_cls, + force=force) + return module_cls + + return _register + + +def build_from_cfg(cfg, + registry: Registry, + group_key: str = default_group, + default_args: dict = None) -> object: + """Build a module from config dict when it is a class configuration, or + call a function from config dict when it is a function configuration. + + Example: + >>> models = Registry('models') + >>> @models.register_module('image-classification', 'SwinT') + >>> class SwinTransformer: + >>> pass + >>> swint = build_from_cfg(dict(type='SwinT'), MODELS, + >>> 'image-classification') + >>> # Returns an instantiated object + >>> + >>> @MODELS.register_module() + >>> def swin_transformer(): + >>> pass + >>> = build_from_cfg(dict(type='swin_transformer'), MODELS) + >>> # Return a result of the calling function + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + group_key (str, optional): The name of registry group from which + module should be searched. + default_args (dict, optional): Default initialization arguments. + type_name (str, optional): The name of the type in the config. + Returns: + object: The constructed object. + """ + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if TYPE_NAME not in cfg: + if default_args is None or TYPE_NAME not in default_args: + raise KeyError( + f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", ' + f'but got {cfg}\n{default_args}') + if not isinstance(registry, Registry): + raise TypeError('registry must be an modelscope.Registry object, ' + f'but got {type(registry)}') + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError('default_args must be a dict or None, ' + f'but got {type(default_args)}') + + # dynamic load installation requirements for this module + from modelscope.utils.import_utils import LazyImportModule + sig = (registry.name.upper(), group_key, cfg['type']) + LazyImportModule.import_module(sig) + + args = cfg.copy() + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + if group_key is None: + group_key = default_group + + obj_type = args.pop(TYPE_NAME) + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type, group_key=group_key) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name}' + f' registry group {group_key}. Please make' + f' sure the correct version of ModelScope library is used.') + obj_cls.group_key = group_key + elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + try: + if hasattr(obj_cls, '_instantiate'): + return obj_cls._instantiate(**args) + else: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py new file mode 100644 index 00000000..58b5b1a3 --- /dev/null +++ b/modelscope/utils/regress_test_utils.py @@ -0,0 +1,779 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import contextlib +import hashlib +import os +import pickle +import random +import re +import shutil +import tempfile +from collections import OrderedDict +from collections.abc import Mapping +from pathlib import Path +from types import FunctionType +from typing import Any, Dict, Union + +import json +import numpy as np +import torch +import torch.optim +from torch import nn + +from modelscope.utils.service_utils import NumpyEncoder + + +class RegressTool: + """This class is used to stop inference/training results from changing by some unaware affections by unittests. + + Firstly, run a baseline test to create a result file, then changes can be observed between + the latest version and the baseline file. + """ + + def __init__(self, + baseline: bool = None, + store_func: FunctionType = None, + load_func: FunctionType = None): + """A func to store the baseline file and a func to load the baseline file. + """ + self.baseline = baseline + self.store_func = store_func + self.load_func = load_func + print(f'Current working dir is: {Path.cwd()}') + + def store(self, local, remote): + if self.store_func is not None: + self.store_func(local, remote) + else: + path = os.path.abspath( + os.path.join(Path.cwd(), 'data', 'test', 'regression')) + os.makedirs(path, exist_ok=True) + shutil.copy(local, os.path.join(path, remote)) + + def load(self, local, remote): + if self.load_func is not None: + self.load_func(local, remote) + else: + path = os.path.abspath( + os.path.join(Path.cwd(), 'data', 'test', 'regression')) + baseline = os.path.join(path, remote) + if not os.path.exists(baseline): + raise ValueError(f'base line file {baseline} not exist') + print( + f'local file found:{baseline}, md5:{hashlib.md5(open(baseline,"rb").read()).hexdigest()}' + ) + if os.path.exists(local): + os.remove(local) + os.symlink(baseline, local, target_is_directory=False) + + @contextlib.contextmanager + def monitor_module_single_forward(self, + module: nn.Module, + file_name: str, + compare_fn=None, + **kwargs): + """Monitor a pytorch module in a single forward. + + Args: + module: A torch module + file_name: The file_name to store or load file + compare_fn: A custom fn used to compare the results manually. + + >>> def compare_fn(v1, v2, key, type): + >>> return None + + v1 is the baseline value + v2 is the value of current version + key is the key of submodules + type is in one of 'input', 'output' + + kwargs: + atol: The absolute gap between two np arrays. + rtol: The relative gap between two np arrays. + """ + baseline = os.getenv('REGRESSION_BASELINE') + if baseline is None or self.baseline is None: + yield + return + + baseline = self.baseline + io_json = {} + absolute_path = f'./{file_name}.bin' + if not isinstance(module, nn.Module): + assert hasattr(module, 'model') + module = module.model + + hack_forward(module, file_name, io_json) + intercept_module(module, io_json) + yield + hack_forward(module, None, None, restore=True) + intercept_module(module, None, restore=True) + if baseline: + with open(absolute_path, 'wb') as f: + pickle.dump(io_json, f) + self.store(absolute_path, f'{file_name}.bin') + os.remove(absolute_path) + else: + name = os.path.basename(absolute_path) + baseline = os.path.join(tempfile.gettempdir(), name) + self.load(baseline, name) + with open(baseline, 'rb') as f: + base = pickle.load(f) + + print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') + print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') + if not compare_io_and_print(base, io_json, compare_fn, **kwargs): + raise ValueError('Result not match!') + + @contextlib.contextmanager + def monitor_module_train(self, + trainer: Union[Dict, Any], + file_name, + level='config', + compare_fn=None, + ignore_keys=None, + compare_random=True, + reset_dropout=True, + lazy_stop_callback=None, + **kwargs): + """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. + + This is usually useful when you try to change some dangerous code + which has the risk of affecting the training loop. + + Args: + trainer: A dict or an object contains the model/optimizer/lr_scheduler + file_name: The file_name to store or load file + level: The regression level. + 'strict' for matching every single tensor. + Please make sure the parameters of head are fixed + and the drop-out rate is zero. + 'config' for matching the initial config, like cfg file, optimizer param_groups, + lr_scheduler params and the random seed. + 'metric' for compare the best metrics in the evaluation loop. + compare_fn: A custom fn used to compare the results manually. + ignore_keys: The keys to ignore of the named_parameters. + compare_random: If to compare random setttings, default True. + reset_dropout: Reset all dropout modules to 0.0. + lazy_stop_callback: A callback passed in, when the moniting is over, this callback will be called. + kwargs: + atol: The absolute gap between two np arrays. + rtol: The relative gap between two np arrays. + + >>> def compare_fn(v1, v2, key, type): + >>> return None + + v1 is the baseline value + v2 is the value of current version + key is the key of modules/parameters + type is in one of 'input', 'output', 'backward', 'optimizer', 'lr_scheduler', 'cfg', 'state' + """ + baseline = os.getenv('REGRESSION_BASELINE') + if baseline is None or self.baseline is None: + yield + return + + baseline = self.baseline + + io_json = {} + bw_json = {} + absolute_path = f'./{file_name}.bin' + + if level == 'strict': + print( + "[Important] The level of regression is 'strict', please make sure your model's parameters are " + 'fixed and all drop-out rates have been set to zero.') + + assert hasattr( + trainer, 'model') or 'model' in trainer, 'model must be in trainer' + module = trainer['model'] if isinstance(trainer, + dict) else trainer.model + if not isinstance(module, nn.Module): + assert hasattr(module, 'model') + module = module.model + + assert hasattr( + trainer, 'optimizer' + ) or 'optimizer' in trainer, 'optimizer must be in trainer' + assert hasattr( + trainer, 'lr_scheduler' + ) or 'lr_scheduler' in trainer, 'lr_scheduler must be in trainer' + optimizer: torch.optim.Optimizer = trainer['optimizer'] if isinstance( + trainer, dict) else trainer.optimizer + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = trainer['lr_scheduler'] if isinstance(trainer, dict) \ + else trainer.lr_scheduler + torch_state = numpify_tensor_nested(torch.get_rng_state()) + np_state = np.random.get_state() + random_seed = random.getstate() + seed = trainer._seed if hasattr( + trainer, + '_seed') else trainer.seed if hasattr(trainer, 'seed') else None + + if reset_dropout: + with torch.no_grad(): + + def reinit_dropout(_module): + for name, submodule in _module.named_children(): + if isinstance(submodule, torch.nn.Dropout): + setattr(_module, name, torch.nn.Dropout(0.)) + else: + reinit_dropout(submodule) + + reinit_dropout(module) + + if level == 'strict': + hack_forward(module, file_name, io_json) + intercept_module(module, io_json) + hack_backward( + module, optimizer, bw_json, lazy_stop_callback=lazy_stop_callback) + yield + hack_backward(module, optimizer, None, restore=True) + if level == 'strict': + hack_forward(module, None, None, restore=True) + intercept_module(module, None, restore=True) + + optimizer_dict = optimizer.state_dict() + optimizer_dict.pop('state', None) + summary = { + 'forward': io_json, + 'backward': bw_json, + 'optimizer': { + 'type': optimizer.__class__.__name__, + 'defaults': optimizer.defaults, + 'state_dict': optimizer_dict + }, + 'lr_scheduler': { + 'type': lr_scheduler.__class__.__name__, + 'state_dict': lr_scheduler.state_dict() + }, + 'cfg': trainer.cfg.to_dict() if hasattr(trainer, 'cfg') else None, + 'state': { + 'torch_state': torch_state, + 'np_state': np_state, + 'random_seed': random_seed, + 'seed': seed, + } + } + + if baseline: + with open(absolute_path, 'wb') as f: + pickle.dump(summary, f) + self.store(absolute_path, f'{file_name}.bin') + os.remove(absolute_path) + else: + name = os.path.basename(absolute_path) + baseline = os.path.join(tempfile.gettempdir(), name) + self.load(baseline, name) + with open(baseline, 'rb') as f: + baseline_json = pickle.load(f) + + if level == 'strict' and not compare_io_and_print( + baseline_json['forward'], io_json, compare_fn, **kwargs): + raise RuntimeError('Forward not match!') + if not compare_backward_and_print( + baseline_json['backward'], + bw_json, + compare_fn=compare_fn, + ignore_keys=ignore_keys, + level=level, + **kwargs): + raise RuntimeError('Backward not match!') + cfg_opt1 = { + 'optimizer': baseline_json['optimizer'], + 'lr_scheduler': baseline_json['lr_scheduler'], + 'cfg': baseline_json['cfg'], + 'state': None if not compare_random else baseline_json['state'] + } + cfg_opt2 = { + 'optimizer': summary['optimizer'], + 'lr_scheduler': summary['lr_scheduler'], + 'cfg': summary['cfg'], + 'state': None if not compare_random else summary['state'] + } + if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, + **kwargs): + raise RuntimeError('Cfg or optimizers not match!') + + +class MsRegressTool(RegressTool): + + class EarlyStopError(Exception): + pass + + @contextlib.contextmanager + def monitor_ms_train(self, + trainer, + file_name, + level='config', + compare_fn=None, + ignore_keys=None, + compare_random=True, + lazy_stop_callback=None, + **kwargs): + + if lazy_stop_callback is None: + + def lazy_stop_callback(): + + from modelscope.trainers.hooks.hook import Hook, Priority + + class EarlyStopHook(Hook): + PRIORITY = Priority.VERY_LOW + + def after_iter(self, trainer): + raise MsRegressTool.EarlyStopError('Test finished.') + + trainer.register_hook(EarlyStopHook()) + + def _train_loop(trainer, *args_train, **kwargs_train): + with self.monitor_module_train( + trainer, + file_name, + level, + compare_fn=compare_fn, + ignore_keys=ignore_keys, + compare_random=compare_random, + lazy_stop_callback=lazy_stop_callback, + **kwargs): + try: + return trainer.train_loop_origin(*args_train, + **kwargs_train) + except MsRegressTool.EarlyStopError: + pass + + trainer.train_loop_origin, trainer.train_loop = \ + trainer.train_loop, type(trainer.train_loop)(_train_loop, trainer) + yield + + +def compare_module(module1: nn.Module, module2: nn.Module): + for p1, p2 in zip(module1.parameters(), module2.parameters()): + if p1.data.ne(p2.data).sum() > 0: + return False + return True + + +def numpify_tensor_nested(tensors, reduction=None, clip_value=10000): + try: + from modelscope.outputs import ModelOutputBase + except ImportError: + ModelOutputBase = dict + "Numpify `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (Mapping, ModelOutputBase)): + return OrderedDict({ + k: numpify_tensor_nested(t, reduction, clip_value) + for k, t in tensors.items() + }) + if isinstance(tensors, list): + return list( + numpify_tensor_nested(t, reduction, clip_value) for t in tensors) + if isinstance(tensors, tuple): + return tuple( + numpify_tensor_nested(t, reduction, clip_value) for t in tensors) + if isinstance(tensors, torch.Tensor): + t: np.ndarray = tensors.cpu().numpy() + if clip_value is not None: + t = np.where(t > clip_value, clip_value, t) + t = np.where(t < -clip_value, -clip_value, t) + if reduction == 'sum': + return t.sum(dtype=np.float) + elif reduction == 'mean': + return t.mean(dtype=np.float) + return t + return tensors + + +def detach_tensor_nested(tensors): + try: + from modelscope.outputs import ModelOutputBase + except ImportError: + ModelOutputBase = dict + "Detach `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (Mapping, ModelOutputBase)): + return OrderedDict( + {k: detach_tensor_nested(t) + for k, t in tensors.items()}) + if isinstance(tensors, list): + return list(detach_tensor_nested(t) for t in tensors) + if isinstance(tensors, tuple): + return tuple(detach_tensor_nested(t) for t in tensors) + if isinstance(tensors, torch.Tensor): + return tensors.detach() + return tensors + + +def hack_forward(module: nn.Module, + name, + io_json, + restore=False, + keep_tensors=False): + + def _forward(self, *args, **kwargs): + ret = self.forward_origin(*args, **kwargs) + if keep_tensors: + args = numpify_tensor_nested(detach_tensor_nested(args)) + kwargs = numpify_tensor_nested(detach_tensor_nested(kwargs)) + output = numpify_tensor_nested(detach_tensor_nested(ret)) + else: + args = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(args), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(args), reduction='mean'), + } + kwargs = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(kwargs), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(kwargs), reduction='mean'), + } + output = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(ret), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(ret), reduction='mean'), + } + + io_json[name] = { + 'input': { + 'args': args, + 'kwargs': kwargs, + }, + 'output': output, + } + return ret + + if not restore and not hasattr(module, 'forward_origin'): + module.forward_origin, module.forward = module.forward, type( + module.forward)(_forward, module) + if restore and hasattr(module, 'forward_origin'): + module.forward = module.forward_origin + del module.forward_origin + + +def hack_backward(module: nn.Module, + optimizer, + io_json, + restore=False, + lazy_stop_callback=None): + + def _step(self, *args, **kwargs): + for name, param in module.named_parameters(): + io_json[name] = { + 'data': { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='mean'), + }, + 'grad': { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.grad), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.grad), reduction='mean'), + } + } + ret = self.step_origin(*args, **kwargs) + for name, param in module.named_parameters(): + io_json[name]['data_after'] = { + 'sum': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='sum'), + 'mean': + numpify_tensor_nested( + detach_tensor_nested(param.data), reduction='mean'), + } + if lazy_stop_callback is not None: + lazy_stop_callback() + return ret + + if not restore and not hasattr(optimizer, 'step_origin'): + optimizer.step_origin, optimizer.step = optimizer.step, type( + optimizer.state_dict)(_step, optimizer) + if restore and hasattr(optimizer, 'step_origin'): + optimizer.step = optimizer.step_origin + del optimizer.step_origin + + +def intercept_module(module: nn.Module, + io_json, + parent_name=None, + restore=False): + for name, module in module.named_children(): + full_name = parent_name + '.' + name if parent_name is not None else name + hack_forward(module, full_name, io_json, restore) + intercept_module(module, io_json, full_name, restore) + + +def compare_arguments_nested(print_content, + arg1, + arg2, + rtol=1.e-3, + atol=1.e-8): + type1 = type(arg1) + type2 = type(arg2) + if type1.__name__ != type2.__name__: + if print_content is not None: + print( + f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}' + ) + return False + + if arg1 is None: + return True + elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)): + if arg1 != arg2: + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (float, np.floating)): + if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True): + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (tuple, list)): + if len(arg1) != len(arg2): + if print_content is not None: + print( + f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}' + ) + return False + if not all([ + compare_arguments_nested( + None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) + for sub_arg1, sub_arg2 in zip(arg1, arg2) + ]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, Mapping): + keys1 = arg1.keys() + keys2 = arg2.keys() + if len(keys1) != len(keys2): + if print_content is not None: + print( + f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}' + ) + return False + if len(set(keys1) - set(keys2)) > 0: + if print_content is not None: + print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') + return False + if not all([ + compare_arguments_nested( + None, arg1[key], arg2[key], rtol=rtol, atol=atol) + for key in keys1 + ]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, np.ndarray): + arg1 = np.where(np.equal(arg1, None), np.NaN, + arg1).astype(dtype=np.float) + arg2 = np.where(np.equal(arg2, None), np.NaN, + arg2).astype(dtype=np.float) + if not all( + np.isclose(arg1, arg2, rtol=rtol, atol=atol, + equal_nan=True).flatten()): + if print_content is not None: + print(f'{print_content}') + return False + return True + else: + raise ValueError(f'type not supported: {type1}') + + +def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + keys1 = set(baseline_json.keys()) + keys2 = set(io_json.keys()) + added = keys1 - keys2 + removed = keys2 - keys1 + print(f'unmatched keys: {added}, {removed}') + shared_keys = keys1.intersection(keys2) + match = True + for key in shared_keys: + v1 = baseline_json[key] + v2 = io_json[key] + + v1input = numpify_tensor_nested(v1['input']) + v2input = numpify_tensor_nested(v2['input']) + res = compare_fn(v1input, v2input, key, 'input') + if res is not None: + print( + f'input of {key} compared with user compare_fn with result:{res}\n' + ) + match = match and res + else: + match = compare_arguments_nested( + f'unmatched module {key} input args', v1input['args'], + v2input['args'], **kwargs) and match + match = compare_arguments_nested( + f'unmatched module {key} input kwargs', v1input['kwargs'], + v2input['kwargs'], **kwargs) and match + v1output = numpify_tensor_nested(v1['output']) + v2output = numpify_tensor_nested(v2['output']) + res = compare_fn(v1output, v2output, key, 'output') + if res is not None: + print( + f'output of {key} compared with user compare_fn with result:{res}\n' + ) + match = match and res + else: + match = compare_arguments_nested( + f'unmatched module {key} outputs', + arg1=v1output, + arg2=v2output, + **kwargs) and match + return match + + +def compare_backward_and_print(baseline_json, + bw_json, + level, + ignore_keys=None, + compare_fn=None, + **kwargs): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + keys1 = set(baseline_json.keys()) + keys2 = set(bw_json.keys()) + added = keys1 - keys2 + removed = keys2 - keys1 + print(f'unmatched backward keys: {added}, {removed}') + shared_keys = keys1.intersection(keys2) + match = True + for key in shared_keys: + if ignore_keys is not None and key in ignore_keys: + continue + + res = compare_fn(baseline_json[key], bw_json[key], key, 'backward') + if res is not None: + print(f'backward data of {key} compared with ' + f'user compare_fn with result:{res}\n') + match = match and res + else: + data1, grad1, data_after1 = baseline_json[key][ + 'data'], baseline_json[key]['grad'], baseline_json[key][ + 'data_after'] + data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ + 'grad'], bw_json[key]['data_after'] + match = compare_arguments_nested( + f'unmatched module {key} tensor data', + arg1=data1, + arg2=data2, + **kwargs) and match + if level == 'strict': + match = compare_arguments_nested( + f'unmatched module {key} grad data', + arg1=grad1, + arg2=grad2, + **kwargs) and match + match = compare_arguments_nested( + f'unmatched module {key} data after step', data_after1, + data_after2, **kwargs) and match + return match + + +def compare_cfg_and_optimizers(baseline_json, + cfg_json, + compare_fn=None, + **kwargs): + if compare_fn is None: + + def compare_fn(*args, **kwargs): + return None + + optimizer1, lr_scheduler1, cfg1, state1 = baseline_json[ + 'optimizer'], baseline_json['lr_scheduler'], baseline_json[ + 'cfg'], baseline_json['state'] + optimizer2, lr_scheduler2, cfg2, state2 = cfg_json['optimizer'], cfg_json[ + 'lr_scheduler'], cfg_json['cfg'], baseline_json['state'] + + match = True + res = compare_fn(optimizer1, optimizer2, None, 'optimizer') + if res is not None: + print(f'optimizer compared with user compare_fn with result:{res}\n') + match = match and res + else: + if optimizer1['type'] != optimizer2['type']: + print( + f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" + ) + match = compare_arguments_nested( + 'unmatched optimizer defaults', optimizer1['defaults'], + optimizer2['defaults'], **kwargs) and match + match = compare_arguments_nested( + 'unmatched optimizer state_dict', optimizer1['state_dict'], + optimizer2['state_dict'], **kwargs) and match + + res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') + if res is not None: + print( + f'lr_scheduler compared with user compare_fn with result:{res}\n') + match = match and res + else: + if lr_scheduler1['type'] != lr_scheduler2['type']: + print( + f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" + ) + match = compare_arguments_nested( + 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], + lr_scheduler2['state_dict'], **kwargs) and match + + res = compare_fn(cfg1, cfg2, None, 'cfg') + if res is not None: + print(f'cfg compared with user compare_fn with result:{res}\n') + match = match and res + else: + match = compare_arguments_nested( + 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match + + res = compare_fn(state1, state2, None, 'state') + if res is not None: + print( + f'random state compared with user compare_fn with result:{res}\n') + match = match and res + else: + match = compare_arguments_nested('unmatched random state', state1, + state2, **kwargs) and match + + return match + + +class IgnoreKeyFn: + + def __init__(self, keys): + if isinstance(keys, str): + keys = [keys] + self.keys = keys if isinstance(keys, list) else [] + + def __call__(self, v1output, v2output, key, type): + if key == 'encoder.encoder.layer.0.intermediate.intermediate_act_fn': + print() + for _key in self.keys: + pattern = re.compile(_key) + if key is not None and pattern.fullmatch(key): + return True + return None diff --git a/modelscope/utils/service_utils.py b/modelscope/utils/service_utils.py new file mode 100644 index 00000000..29c111f8 --- /dev/null +++ b/modelscope/utils/service_utils.py @@ -0,0 +1,179 @@ +import base64 +import mimetypes +from io import BytesIO + +import json +import numpy as np +import requests +from PIL import Image + +from modelscope.outputs import TASK_OUTPUTS, OutputKeys +from modelscope.pipeline_inputs import TASK_INPUTS, InputType +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks, TasksIODescriptions + + +# service data decoder func decodes data from network and convert it to pipeline's input +# for example +def ExampleDecoder(data): + # Assuming the pipeline inputs is a dict contains an image and a text, + # to decode the data from network we decode the image as base64 + data_json = json.loads(data) + # data: {"image": "xxxxxxxx=="(base64 str), "text": "a question"} + # pipeline(inputs) as follows: + # pipeline({'image': image, 'text': text}) + inputs = { + 'image': decode_base64_to_image(data_json.get('image')), + 'text': data_json.get('text') + } + return inputs + + +# service data encoder func encodes data from pipeline outputs and convert to network response (such as json) +# for example +def ExampleEncoder(data): + # Assuming the pipeline outputs is a dict contains an image and a text, + # and transmit it through network, this func encode image to base64 and dumps into json + # data (for e.g. python dict): + # {"image": a numpy array represents a image, "text": "output"} + image = data['image'] + text = data['text'] + data = {'image': encode_array_to_img_base64(image), 'text': text} + return json.dumps(data, cls=NumpyEncoder) + + +CustomEncoder = { + # Tasks.visual_question_answering: ExampleEncoder +} + +CustomDecoder = { + # Tasks.visual_question_answering: ExampleDecoder +} + + +class NumpyEncoder(json.JSONEncoder): + + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + + if isinstance(obj, np.floating): + return float(obj) + + if isinstance(obj, np.integer): + return int(obj) + + return json.JSONEncoder.default(self, obj) + + +def get_extension(encoding): + encoding = encoding.replace('audio/wav', 'audio/x-wav') + tp = mimetypes.guess_type(encoding)[0] + if tp == 'audio/flac': # flac is not supported by mimetypes + return 'flac' + extension = mimetypes.guess_extension(tp) + if extension is not None and extension.startswith('.'): + extension = extension[1:] + return extension + + +def get_mimetype(filename): + mimetype = mimetypes.guess_type(filename)[0] + if mimetype is not None: + mimetype = mimetype.replace('x-wav', 'wav').replace('x-flac', 'flac') + return mimetype + + +def decode_base64_to_binary(encoding): + extension = get_extension(encoding) + data = encoding.split(',')[1] + return base64.b64decode(data), extension + + +def decode_base64_to_image(encoding): + content = encoding.split(';')[1] + image_encoded = content.split(',')[1] + return Image.open(BytesIO(base64.b64decode(image_encoded))) + + +def encode_array_to_img_base64(image_array): + with BytesIO() as output_bytes: + pil_image = Image.fromarray(image_array.astype(np.uint8)) + pil_image.save(output_bytes, 'PNG') + bytes_data = output_bytes.getvalue() + base64_str = str(base64.b64encode(bytes_data), 'utf-8') + return 'data:image/png;base64,' + base64_str + + +def encode_pcm_to_base64(bytes_data): + from scipy.io.wavfile import write + with BytesIO() as out_mem_file: + write(out_mem_file, 16000, bytes_data) + base64_str = str(base64.b64encode(out_mem_file.getvalue()), 'utf-8') + return 'data:audio/pcm;base64,' + base64_str + + +def encode_url_to_base64(url): + encoded_string = base64.b64encode(requests.get(url).content) + base64_str = str(encoded_string, 'utf-8') + mimetype = get_mimetype(url) + return ('data:' + (mimetype if mimetype is not None else '') + ';base64,' + + base64_str) + + +def encode_file_to_base64(f): + with open(f, 'rb') as file: + encoded_string = base64.b64encode(file.read()) + base64_str = str(encoded_string, 'utf-8') + mimetype = get_mimetype(f) + return ('data:' + (mimetype if mimetype is not None else '') + + ';base64,' + base64_str) + + +def encode_url_or_file_to_base64(path): + try: + requests.get(path) + return encode_url_to_base64(path) + except (requests.exceptions.MissingSchema, + requests.exceptions.InvalidSchema): + return encode_file_to_base64(path) + + +def service_data_decoder(task, data): + if CustomDecoder.get(task) is not None: + return CustomDecoder[task](data) + input_type = TASK_INPUTS[task] + input_data = data.decode('utf-8') + if input_type == InputType.IMAGE: + return decode_base64_to_image(input_data) + elif input_type == InputType.AUDIO: + return decode_base64_to_binary(input_data)[0] + elif input_type == InputType.TEXT: + return input_data + elif isinstance(input_type, dict): + input_data = {} + for key, val in input_type.items(): + if val == InputType.IMAGE: + input_data[key] = decode_base64_to_image(data[key]) + elif val == InputType.AUDIO: + input_data[key] = decode_base64_to_binary(data[key])[0] + elif val == InputType.TEXT: + input_data[key] = data[key] + + return input_data + + +def service_data_encoder(task, data): + if CustomEncoder.get(task) is not None: + return CustomEncoder[task](data) + output_keys = TASK_OUTPUTS[task] + result = data + for output_key in output_keys: + if output_key == OutputKeys.OUTPUT_IMG: + result[OutputKeys.OUTPUT_IMG] = encode_array_to_img_base64( + data[OutputKeys.OUTPUT_IMG][..., ::-1]) + elif output_key == OutputKeys.OUTPUT_PCM: + result[OutputKeys.OUTPUT_PCM] = encode_pcm_to_base64( + data[OutputKeys.OUTPUT_PCM]) + result = bytes(json.dumps(result, cls=NumpyEncoder), encoding='utf8') + return result diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py new file mode 100644 index 00000000..8f580d19 --- /dev/null +++ b/modelscope/utils/tensor_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/transformers. +from collections.abc import Mapping + + +def torch_nested_numpify(tensors): + """ Numpify nested torch tensors. + + NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. + + Args: + tensors: Nested torch tensors. + + Returns: + The numpify tensors. + """ + + import torch + "Numpify `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(torch_nested_numpify(t) for t in tensors) + if isinstance(tensors, Mapping): + # return dict + return {k: torch_nested_numpify(t) for k, t in tensors.items()} + if isinstance(tensors, torch.Tensor): + t = tensors.cpu() + return t.numpy() + return tensors + + +def torch_nested_detach(tensors): + """ Detach nested torch tensors. + + NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. + + Args: + tensors: Nested torch tensors. + + Returns: + The detached tensors. + """ + + import torch + "Detach `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(torch_nested_detach(t) for t in tensors) + if isinstance(tensors, Mapping): + return {k: torch_nested_detach(t) for k, t in tensors.items()} + if isinstance(tensors, torch.Tensor): + return tensors.detach() + return tensors diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py new file mode 100644 index 00000000..5109db11 --- /dev/null +++ b/modelscope/utils/test_utils.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import os +import pickle +import shutil +import socket +import subprocess +import sys +import tarfile +import tempfile +import unittest +from collections import OrderedDict + +import requests +import torch +from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE +from torch.utils.data import Dataset + +from .torch_utils import _find_free_port + +TEST_LEVEL = 2 +TEST_LEVEL_STR = 'TEST_LEVEL' + + +def test_level(): + global TEST_LEVEL + if TEST_LEVEL_STR in os.environ: + TEST_LEVEL = int(os.environ[TEST_LEVEL_STR]) + + return TEST_LEVEL + + +def require_tf(test_case): + if not TF_AVAILABLE: + test_case = unittest.skip('test requires TensorFlow')(test_case) + return test_case + + +def require_torch(test_case): + if not TORCH_AVAILABLE: + test_case = unittest.skip('test requires PyTorch')(test_case) + return test_case + + +def set_test_level(level: int): + global TEST_LEVEL + TEST_LEVEL = level + + +class DummyTorchDataset(Dataset): + + def __init__(self, feat, label, num) -> None: + self.feat = feat + self.label = label + self.num = num + + def __getitem__(self, index): + return { + 'feat': torch.Tensor(self.feat), + 'labels': torch.Tensor(self.label) + } + + def __len__(self): + return self.num + + +def create_dummy_test_dataset(feat, label, num): + return DummyTorchDataset(feat, label, num) + + +def download_and_untar(fpath, furl, dst) -> str: + if not os.path.exists(fpath): + r = requests.get(furl) + with open(fpath, 'wb') as f: + f.write(r.content) + + file_name = os.path.basename(fpath) + root_dir = os.path.dirname(fpath) + target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0] + target_dir_path = os.path.join(root_dir, target_dir_name) + + # untar the file + t = tarfile.open(fpath) + t.extractall(path=dst) + + return target_dir_path + + +def get_case_model_info(): + status_code, result = subprocess.getstatusoutput( + 'grep -rn "damo/" tests/ | grep -v ".pyc" | grep -v "Binary file" | grep -v run.py ' + ) + lines = result.split('\n') + test_cases = OrderedDict() + model_cases = OrderedDict() + for line in lines: + # "tests/msdatasets/test_ms_dataset.py:92: model_id = 'damo/bert-base-sst2'" + line = line.strip() + elements = line.split(':') + test_file = elements[0] + model_pos = line.find('damo') + left_quote = line[model_pos - 1] + rquote_idx = line.rfind(left_quote) + model_name = line[model_pos:rquote_idx] + if test_file not in test_cases: + test_cases[test_file] = set() + model_info = test_cases[test_file] + model_info.add(model_name) + + if model_name not in model_cases: + model_cases[model_name] = set() + case_info = model_cases[model_name] + case_info.add( + test_file.replace('tests/', '').replace('.py', + '').replace('/', '.')) + + return model_cases + + +_DIST_SCRIPT_TEMPLATE = """ +import ast +import argparse +import pickle +import torch +from torch import distributed as dist +from modelscope.utils.torch_utils import get_dist_info +import {} + +parser = argparse.ArgumentParser() +parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results') +parser.add_argument('--save_file', type=str, help='save file') +parser.add_argument('--local_rank', type=int, default=0) +args = parser.parse_args() + + +def main(): + results = {}.{}({}) # module.func(params) + if args.save_all_ranks: + save_file = args.save_file + str(dist.get_rank()) + with open(save_file, 'wb') as f: + pickle.dump(results, f) + else: + rank, _ = get_dist_info() + if rank == 0: + with open(args.save_file, 'wb') as f: + pickle.dump(results, f) + + +if __name__ == '__main__': + main() +""" + + +class DistributedTestCase(unittest.TestCase): + """Distributed TestCase for test function with distributed mode. + Examples: + import torch + from torch import distributed as dist + from modelscope.utils.torch_utils import init_dist + + def _test_func(*args, **kwargs): + init_dist(launcher='pytorch') + rank = dist.get_rank() + if rank == 0: + value = torch.tensor(1.0).cuda() + else: + value = torch.tensor(2.0).cuda() + dist.all_reduce(value) + return value.cpu().numpy() + + class DistTest(DistributedTestCase): + def test_function_dist(self): + args = () # args should be python builtin type + kwargs = {} # kwargs should be python builtin type + self.start( + _test_func, + num_gpus=2, + assert_callback=lambda x: self.assertEqual(x, 3.0), + *args, + **kwargs, + ) + """ + + def _start(self, + dist_start_cmd, + func, + num_gpus, + assert_callback=None, + save_all_ranks=False, + *args, + **kwargs): + script_path = func.__code__.co_filename + script_dir, script_name = os.path.split(script_path) + script_name = os.path.splitext(script_name)[0] + func_name = func.__qualname__ + + func_params = [] + for arg in args: + if isinstance(arg, str): + arg = ('\'{}\''.format(arg)) + func_params.append(str(arg)) + + for k, v in kwargs.items(): + if isinstance(v, str): + v = ('\'{}\''.format(v)) + func_params.append('{}={}'.format(k, v)) + + func_params = ','.join(func_params).strip(',') + + tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name + tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name + + with open(tmp_run_file, 'w') as f: + print('save temporary run file to : {}'.format(tmp_run_file)) + print('save results to : {}'.format(tmp_res_file)) + run_file_content = _DIST_SCRIPT_TEMPLATE.format( + script_name, script_name, func_name, func_params) + f.write(run_file_content) + + tmp_res_files = [] + if save_all_ranks: + for i in range(num_gpus): + tmp_res_files.append(tmp_res_file + str(i)) + else: + tmp_res_files = [tmp_res_file] + self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files) + + tmp_env = copy.deepcopy(os.environ) + tmp_env['PYTHONPATH'] = ':'.join( + (tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':') + script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, + tmp_res_file) + script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params) + print('script command: %s' % script_cmd) + res = subprocess.call(script_cmd, shell=True, env=tmp_env) + + script_res = [] + for res_file in tmp_res_files: + with open(res_file, 'rb') as f: + script_res.append(pickle.load(f)) + if not save_all_ranks: + script_res = script_res[0] + + if assert_callback: + assert_callback(script_res) + + self.assertEqual( + res, + 0, + msg='The test function ``{}`` in ``{}`` run failed!'.format( + func_name, script_name)) + + return script_res + + def start(self, + func, + num_gpus, + assert_callback=None, + save_all_ranks=False, + *args, + **kwargs): + ip = socket.gethostbyname(socket.gethostname()) + dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % ( + sys.executable, num_gpus, ip, _find_free_port()) + + return self._start( + dist_start_cmd=dist_start_cmd, + func=func, + num_gpus=num_gpus, + assert_callback=assert_callback, + save_all_ranks=save_all_ranks, + *args, + **kwargs) + + def clean_tmp(self, tmp_file_list): + for file in tmp_file_list: + if os.path.exists(file): + if os.path.isdir(file): + shutil.rmtree(file) + else: + os.remove(file) diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py new file mode 100644 index 00000000..74d9bb7b --- /dev/null +++ b/modelscope/utils/torch_utils.py @@ -0,0 +1,202 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Following code is partialy borrowed from openmmlab/mmcv +import functools +import os +import pickle +import random +import socket +import subprocess +import tempfile +from typing import Callable, List, Optional, Tuple + +import numpy as np +import torch +import torch.multiprocessing as mp +from torch import distributed as dist + + +def _find_free_port() -> str: + # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(('', 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def _is_free_port(port: int) -> bool: + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append('localhost') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend: str, **kwargs) -> None: + # rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend: str, **kwargs) -> None: + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + torch.cuda.set_device(local_rank) + if 'MASTER_PORT' not in os.environ: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + if 'MASTER_ADDR' not in os.environ: + raise KeyError('The environment variable MASTER_ADDR is not set') + os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] + os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None: + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(29500): + os.environ['MASTER_PORT'] = '29500' + else: + os.environ['MASTER_PORT'] = str(_find_free_port()) + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info() -> Tuple[int, int]: + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def get_local_rank(): + return int(os.environ.get('LOCAL_RANK', 0)) + + +def is_master(): + rank, _ = get_dist_info() + return rank == 0 + + +def master_only(func: Callable) -> Callable: + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def make_tmp_dir(): + """Make sure each rank has the same temporary directory on the distributed mode. + """ + rank, world_size = get_dist_info() + if world_size <= 1: + return tempfile.mkdtemp() + + tmpdir = None + if rank == 0: + tmpdir = tempfile.mkdtemp() + + dist.barrier() + tmpdir = broadcast(tmpdir, 0) + + return tmpdir + + +def broadcast(inputs, src): + """ + Broadcasts the inputs to all ranks. + + Arguments: + inputs : Any objects that can be serialized by pickle. + src (int): Source rank. + Returns: + Each rank returns the same value as src. + """ + rank, _ = get_dist_info() + shape_tensor = torch.tensor([0], device='cuda') + + if rank == src: + inputs_tensor = torch.tensor( + bytearray(pickle.dumps(inputs)), dtype=torch.uint8, device='cuda') + shape_tensor = torch.tensor(inputs_tensor.shape, device='cuda') + + dist.barrier() + dist.broadcast(shape_tensor, src) + + if rank != src: + inputs_tensor = torch.full((shape_tensor.item(), ), + 0, + dtype=torch.uint8, + device='cuda') + + dist.barrier() + dist.broadcast(inputs_tensor, src) + + return pickle.loads(inputs_tensor.cpu().numpy().tobytes()) + + +def set_random_seed(seed): + if seed is not None and seed >= 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + else: + raise ValueError( + f'Random seed should be positive, current seed is {seed}') + + +def set_random_seed_mpu(seed): + from megatron import mpu + set_random_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) diff --git a/modelscope/utils/trie.py b/modelscope/utils/trie.py new file mode 100644 index 00000000..77f7e971 --- /dev/null +++ b/modelscope/utils/trie.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from collections import defaultdict + + +class TreeNode: + + def __init__(self): + self.child = defaultdict(TreeNode) + + +class Trie: + + def __init__(self, eos): + self.root = TreeNode() + self.eos = eos + + def insert(self, word): + cur = self.root + for c in word: + cur = cur.child[c] + + def get_next_layer(self, word): + cur = self.root + for c in word: + cur = cur.child.get(c) + if cur is None: + return [self.eos] + return list(cur.child.keys()) diff --git a/modelscope/utils/type_assert.py b/modelscope/utils/type_assert.py new file mode 100644 index 00000000..f732a81a --- /dev/null +++ b/modelscope/utils/type_assert.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from functools import wraps +from inspect import signature + + +def type_assert(*ty_args, **ty_kwargs): + """a decorator which is used to check the types of arguments in a function or class + Examples: + >>> @type_assert(str) + ... def main(a: str, b: list): + ... print(a, b) + >>> main(1) + Argument a must be a str + + >>> @type_assert(str, (int, str)) + ... def main(a: str, b: int | str): + ... print(a, b) + >>> main('1', [1]) + Argument b must be (, ) + + >>> @type_assert(str, (int, str)) + ... class A: + ... def __init__(self, a: str, b: int | str) + ... print(a, b) + >>> a = A('1', [1]) + Argument b must be (, ) + """ + + def decorate(func): + # If in optimized mode, disable type checking + if not __debug__: + return func + + # Map function argument names to supplied types + sig = signature(func) + bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments + + @wraps(func) + def wrapper(*args, **kwargs): + bound_values = sig.bind(*args, **kwargs) + # Enforce type assertions across supplied arguments + for name, value in bound_values.arguments.items(): + if name in bound_types: + if not isinstance(value, bound_types[name]): + raise TypeError('Argument {} must be {}'.format( + name, bound_types[name])) + return func(*args, **kwargs) + + return wrapper + + return decorate diff --git a/modelscope/version.py b/modelscope/version.py new file mode 100644 index 00000000..ca813cc0 --- /dev/null +++ b/modelscope/version.py @@ -0,0 +1,5 @@ +# Make sure to modify __release_datetime__ to release time when making official release. +__version__ = '1.0.0' +# default release datetime for branches under active development is set +# to be a time far-far-away-into-the-future +__release_datetime__ = '2099-10-13 08:56:12' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..0832e6ab --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +-r requirements/framework.txt diff --git a/requirements/audio.txt b/requirements/audio.txt new file mode 100644 index 00000000..bef32121 --- /dev/null +++ b/requirements/audio.txt @@ -0,0 +1,27 @@ +easyasr>=0.0.2 +espnet==202204 +h5py +inflect +keras +kwsbp>=0.0.2 +librosa +lxml +matplotlib +MinDAEC +nara_wpe +nltk +# tensorflow 1.15 requires numpy<=1.18 +numpy<=1.18 +# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. +protobuf>3,<3.21.0 +ptflops +py_sound_connect>=0.1 +pytorch_wavelets +PyWavelets>=1.0.0 +scikit-learn +SoundFile>0.10 +sox +torchaudio +tqdm +ttsfrd>=0.0.3 +unidecode diff --git a/requirements/cv.txt b/requirements/cv.txt new file mode 100644 index 00000000..f29b296b --- /dev/null +++ b/requirements/cv.txt @@ -0,0 +1,34 @@ +albumentations>=1.0.3 +av>=9.2.0 +easydict +fairscale>=0.4.1 +fastai>=1.0.51 +ffmpeg>=1.4 +ffmpeg-python>=0.2.0 +ftfy +imageio>=2.9.0 +imageio-ffmpeg>=0.4.2 +imgaug>=0.4.0 +kornia>=0.5.0 +lmdb +lpips +ml_collections +mmcls>=0.21.0 +mmdet>=2.25.0 +moviepy>=1.0.3 +networkx>=2.5 +numba +onnxruntime>=1.10 +pai-easycv>=0.6.3.9 +pandas +psutil +regex +scikit-image>=0.19.3 +scikit-learn>=0.20.1 +shapely +shotdetect_scenedetect_lgss +tensorflow-estimator>=1.15.1 +tf_slim +timm>=0.4.9 +torchmetrics>=0.6.2 +torchvision diff --git a/requirements/docs.txt b/requirements/docs.txt new file mode 100644 index 00000000..f51d1565 --- /dev/null +++ b/requirements/docs.txt @@ -0,0 +1,7 @@ +docutils>=0.16.0 +myst_parser +recommonmark +sphinx>=4.0.2 +sphinx-book-theme +sphinx-copybutton +sphinx_markdown_tables diff --git a/requirements/framework.txt b/requirements/framework.txt new file mode 100644 index 00000000..a86c0cc5 --- /dev/null +++ b/requirements/framework.txt @@ -0,0 +1,22 @@ +addict +attrs +# version beyond 2.5.2 introduces compatbility issue and is being resolved +datasets<=2.5.2 +easydict +einops +filelock>=3.3.0 +gast>=0.2.2 +jsonplus +numpy +opencv-python +oss2 +Pillow>=6.2.0 +# for pyarrow 9.0.0 event_loop core dump +pyarrow>=6.0.0,!=9.0.0 +pyyaml +requests +scipy +setuptools +tensorboard +tqdm>=4.64.0 +yapf diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt new file mode 100644 index 00000000..31e9601d --- /dev/null +++ b/requirements/multi-modal.txt @@ -0,0 +1,17 @@ +ftfy>=6.0.3 +ofa>=0.0.2 +pycocoevalcap>=1.2 +pycocotools>=2.0.4 +# compatible with taming-transformers-rom1504 +pytorch_lightning<=1.7.7 +# rough-score was just recently updated from 0.0.4 to 0.0.7 +# which introduced compatability issues that are being investigated +rouge_score<=0.0.4 +sacrebleu +taming-transformers-rom1504 +timm +tokenizers +torchvision +transformers>=4.12.0 +unicodedata2 +zhconv diff --git a/requirements/nlp.txt b/requirements/nlp.txt new file mode 100644 index 00000000..433f70f7 --- /dev/null +++ b/requirements/nlp.txt @@ -0,0 +1,24 @@ +boto3 +en_core_web_sm>=2.3.5 +filelock +ftfy +jieba>=0.42.1 +matplotlib +nltk +pai-easynlp +pandas +# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. +protobuf>=3.19.0,<3.21.0 +pythainlp +pyvi +regex +sacremoses>=0.0.41 +scikit_learn +sentencepiece +seqeval +spacy>=2.3.5 +subword_nmt>=0.3.8 +termcolor +text2sql_lgesql +tokenizers +transformers>=4.12.0 diff --git a/requirements/science.txt b/requirements/science.txt new file mode 100644 index 00000000..c30ff644 --- /dev/null +++ b/requirements/science.txt @@ -0,0 +1,8 @@ +biopython +iopath +ipdb +lmdb +ml_collections +scipy +tensorboardX +tokenizers diff --git a/requirements/tensorflow1x.txt b/requirements/tensorflow1x.txt new file mode 100644 index 00000000..b139efe1 --- /dev/null +++ b/requirements/tensorflow1x.txt @@ -0,0 +1 @@ +numpy==1.18.5 diff --git a/requirements/tests.txt b/requirements/tests.txt new file mode 100644 index 00000000..5ec4df7e --- /dev/null +++ b/requirements/tests.txt @@ -0,0 +1,5 @@ +expecttest +flake8 +isort>=4.3.21 +pre-commit +yapf==0.30.0 # use fix version to ensure consistent auto-styling diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..3dc64f86 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,25 @@ +[isort] +line_length = 79 +multi_line_output = 0 +known_standard_library = setuptools +known_first_party = modelscope +known_third_party = json,yaml +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[yapf] +BASED_ON_STYLE = pep8 +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true +SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true +SPLIT_BEFORE_ARITHMETIC_OPERATOR = true + +[codespell] +skip = *.ipynb +quiet-level = 3 +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids + +[flake8] +max-line-length = 120 +select = B,C,E,F,P,T4,W,B9 +ignore = F401,F405,F821,W503,E251 +exclude = docs/src,*.pyi,.git diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..eff2f8ba --- /dev/null +++ b/setup.py @@ -0,0 +1,216 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# !/usr/bin/env python +import os +import shutil +import subprocess +from setuptools import find_packages, setup + +from modelscope.utils.constant import Fields + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +version_file = 'modelscope/version.py' + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + assert os.path.exists('.git'), '.git directory does not exist' + sha = get_git_hash()[:7] + return sha + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """ + Parse the package dependencies listed in a requirements file but strips + specific versioning information. + + Args: + fname (str): path to requirements file + with_version (bool, default=False): if True include version specs + + Returns: + List[str]: list of requirements items + + CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + require_fpath = fname + + def parse_line(line): + """ + Parse information from a line in a requirements text file + """ + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line.startswith('http'): + print('skip http requirements %s' % line) + continue + if line and not line.startswith('#') and not line.startswith( + '--'): + for info in parse_line(line): + yield info + elif line and line.startswith('--find-links'): + eles = line.split() + for e in eles: + e = e.strip() + if 'http' in e: + info = dict(dependency_links=e) + yield info + + def gen_packages_items(): + items = [] + deps_link = [] + if exists(require_fpath): + for info in parse_require_file(require_fpath): + if 'dependency_links' not in info: + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + items.append(item) + else: + deps_link.append(info['dependency_links']) + return items, deps_link + + return gen_packages_items() + + +def pack_resource(): + # pack resource such as configs and tools + root_dir = 'package/' + if os.path.isdir(root_dir): + shutil.rmtree(root_dir) + os.makedirs(root_dir) + + proj_dir = root_dir + 'modelscope/' + shutil.copytree('./modelscope', proj_dir) + shutil.copytree('./configs', proj_dir + 'configs') + shutil.copytree('./requirements', 'package/requirements') + shutil.copy('./requirements.txt', 'package/requirements.txt') + shutil.copy('./MANIFEST.in', 'package/MANIFEST.in') + shutil.copy('./README.md', 'package/README.md') + + +if __name__ == '__main__': + # write_version_py() + pack_resource() + os.chdir('package') + install_requires, deps_link = parse_requirements('requirements.txt') + extra_requires = {} + all_requires = [] + for field in dir(Fields): + if field.startswith('_'): + continue + field = getattr(Fields, field) + extra_requires[field], _ = parse_requirements( + f'requirements/{field}.txt') + + # skip audio requirements due to its hard dependency which + # result in mac/windows compatibility problems + if field != Fields.audio: + all_requires.append(extra_requires[field]) + + extra_requires['all'] = all_requires + + setup( + name='modelscope', + version=get_version(), + description='', + long_description=readme(), + long_description_content_type='text/markdown', + author='Alibaba ModelScope team', + author_email='modelscope@list.alibaba-inc.com', + keywords='', + url='TBD', + packages=find_packages(exclude=('configs', 'tools', 'demo')), + include_package_data=True, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + ], + license='Apache License 2.0', + tests_require=parse_requirements('requirements/tests.txt'), + install_requires=install_requires, + extras_require=extra_requires, + dependency_links=deps_link, + zip_safe=False) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/export/__init__.py b/tests/export/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/export/test_export_sbert_sequence_classification.py b/tests/export/test_export_sbert_sequence_classification.py new file mode 100644 index 00000000..7533732d --- /dev/null +++ b/tests/export/test_export_sbert_sequence_classification.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +from collections import OrderedDict + +from modelscope.exporters import Exporter, TorchModelExporter +from modelscope.models import Model +from modelscope.utils.test_utils import test_level + + +class TestExportSbertSequenceClassification(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_export_sbert_sequence_classification(self): + model = Model.from_pretrained(self.model_id) + print( + Exporter.from_model(model).export_onnx( + shape=(2, 256), output_dir=self.tmp_dir)) + print( + TorchModelExporter.from_model(model).export_torch_script( + shape=(2, 256), output_dir=self.tmp_dir)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_export_outer_module(self): + from transformers import BertForSequenceClassification, BertTokenizerFast + model = BertForSequenceClassification.from_pretrained( + 'bert-base-uncased') + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + dummy_inputs = tokenizer( + tokenizer.unk_token, + padding='max_length', + max_length=256, + return_tensors='pt') + dynamic_axis = {0: 'batch', 1: 'sequence'} + inputs = OrderedDict([ + ('input_ids', dynamic_axis), + ('attention_mask', dynamic_axis), + ('token_type_ids', dynamic_axis), + ]) + outputs = OrderedDict({'logits': {0: 'batch'}}) + output_files = TorchModelExporter().export_onnx( + model=model, + dummy_inputs=dummy_inputs, + inputs=inputs, + outputs=outputs, + output_dir='/tmp') + print(output_files) + output_files = TorchModelExporter().export_torch_script( + model=model, + dummy_inputs=dummy_inputs, + output_dir='/tmp', + strict=False) + print(output_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/fileio/__init__.py b/tests/fileio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fileio/test_file.py b/tests/fileio/test_file.py new file mode 100644 index 00000000..ded8ece7 --- /dev/null +++ b/tests/fileio/test_file.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest + +from requests import HTTPError + +from modelscope.fileio.file import File, HTTPStorage, LocalStorage + + +class FileTest(unittest.TestCase): + + def test_local_storage(self): + storage = LocalStorage() + temp_name = tempfile.gettempdir() + '/' + next( + tempfile._get_candidate_names()) + binary_content = b'12345' + storage.write(binary_content, temp_name) + self.assertEqual(binary_content, storage.read(temp_name)) + + content = '12345' + storage.write_text(content, temp_name) + self.assertEqual(content, storage.read_text(temp_name)) + + os.remove(temp_name) + + def test_http_storage(self): + storage = HTTPStorage() + url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/texts/data.txt' + content = 'this is test data' + self.assertEqual(content.encode('utf8'), storage.read(url)) + self.assertEqual(content, storage.read_text(url)) + + with storage.as_local_path(url) as local_file: + with open(local_file, 'r') as infile: + self.assertEqual(content, infile.read()) + + with self.assertRaises(NotImplementedError): + storage.write('dfad', url) + + with self.assertRaises(HTTPError): + storage.read(url + 'df') + + def test_file(self): + url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/texts/data.txt' + content = 'this is test data' + self.assertEqual(content.encode('utf8'), File.read(url)) + + with File.as_local_path(url) as local_file: + with open(local_file, 'r') as infile: + self.assertEqual(content, infile.read()) + + with self.assertRaises(NotImplementedError): + File.write('dfad', url) + + with self.assertRaises(HTTPError): + File.read(url + 'df') + + temp_name = tempfile.gettempdir() + '/' + next( + tempfile._get_candidate_names()) + binary_content = b'12345' + File.write(binary_content, temp_name) + self.assertEqual(binary_content, File.read(temp_name)) + os.remove(temp_name) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/fileio/test_io.py b/tests/fileio/test_io.py new file mode 100644 index 00000000..0a80d3f7 --- /dev/null +++ b/tests/fileio/test_io.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import tempfile +import unittest + +from modelscope.fileio.io import dump, dumps, load + + +class FileIOTest(unittest.TestCase): + + def test_format(self, format='json'): + obj = [1, 2, 3, 'str', {'model': 'resnet'}] + result_str = dumps(obj, format) + temp_name = tempfile.gettempdir() + '/' + next( + tempfile._get_candidate_names()) + '.' + format + dump(obj, temp_name) + obj_load = load(temp_name) + self.assertEqual(obj_load, obj) + with open(temp_name, 'r') as infile: + self.assertEqual(result_str, infile.read()) + + with self.assertRaises(TypeError): + obj_load = load(temp_name + 's') + + with self.assertRaises(TypeError): + dump(obj, temp_name + 's') + + def test_yaml(self): + self.test_format('yaml') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/__init__.py b/tests/hub/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/hub/test_download_dataset.py b/tests/hub/test_download_dataset.py new file mode 100644 index 00000000..29b5d1ab --- /dev/null +++ b/tests/hub/test_download_dataset.py @@ -0,0 +1,709 @@ +import unittest + +from modelscope.msdatasets import MsDataset +from modelscope.utils.test_utils import test_level + + +class DownloadDatasetTest(unittest.TestCase): + + def setUp(self): + self.subset_count = 10 + + def download_subset(self, dataset, subset_name): + dataset = MsDataset.load(dataset, subset_name=subset_name) + if isinstance(dataset, MsDataset): + lens = len(dataset) + print(f'dataset {subset_name} len: {lens}') + self.assertTrue(lens > 0) + else: + assert isinstance(dataset, dict) + lens = {key: len(subset) for key, subset in dataset.items()} + print(f'dataset {subset_name} len: {lens}') + self.assertTrue(all([_len > 0 for _len in lens.values()])) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_glue(self): + subset = [ + 'cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', 'mnli_mismatched', + 'mnli_matched', 'qnli', 'rte', 'wnli', 'ax' + ] + for subset_name in subset: + self.download_subset('glue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_super_glue(self): + subset = [ + 'boolq', 'cb', 'copa', 'multirc', 'record', 'rte', 'wic', 'wsc', + 'wsc.fixed', 'axb', 'axg' + ] + for subset_name in subset: + self.download_subset('super_glue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_nllb(self): + subset = [ + 'ace_Latn-ban_Latn', 'ace_Latn-bjn_Latn', 'ace_Latn-bug_Latn', + 'ace_Latn-ceb_Latn', 'ace_Latn-eng_Latn', 'ace_Latn-fij_Latn', + 'ace_Latn-ilo_Latn', 'ace_Latn-jav_Latn', 'ace_Latn-min_Latn', + 'ace_Latn-mri_Latn', 'ace_Latn-pag_Latn', 'ace_Latn-plt_Latn', + 'ace_Latn-smo_Latn', 'ace_Latn-sun_Latn', 'ace_Latn-war_Latn', + 'afr_Latn-aka_Latn', 'afr_Latn-amh_Ethi', 'afr_Latn-bam_Latn', + 'afr_Latn-bem_Latn', 'afr_Latn-cjk_Latn', 'afr_Latn-dik_Latn', + 'afr_Latn-dyu_Latn', 'afr_Latn-eng_Latn', 'afr_Latn-ewe_Latn', + 'afr_Latn-fon_Latn', 'afr_Latn-fra_Latn', 'afr_Latn-fuv_Latn', + 'afr_Latn-gaz_Latn', 'afr_Latn-hau_Latn', 'afr_Latn-ibo_Latn', + 'afr_Latn-kam_Latn', 'afr_Latn-kik_Latn', 'afr_Latn-kin_Latn', + 'afr_Latn-kmb_Latn', 'afr_Latn-knc_Arab', 'afr_Latn-knc_Latn', + 'afr_Latn-kon_Latn', 'afr_Latn-lin_Latn', 'afr_Latn-lua_Latn', + 'afr_Latn-lug_Latn', 'afr_Latn-luo_Latn', 'afr_Latn-nso_Latn', + 'afr_Latn-nus_Latn', 'afr_Latn-nya_Latn', 'afr_Latn-run_Latn', + 'afr_Latn-sna_Latn', 'afr_Latn-som_Latn', 'afr_Latn-sot_Latn', + 'afr_Latn-ssw_Latn', 'afr_Latn-swh_Latn', 'afr_Latn-tir_Ethi', + 'afr_Latn-tsn_Latn', 'afr_Latn-tso_Latn', 'afr_Latn-tum_Latn', + 'afr_Latn-twi_Latn', 'afr_Latn-umb_Latn', 'afr_Latn-wol_Latn', + 'afr_Latn-xho_Latn', 'afr_Latn-yor_Latn', 'afr_Latn-zul_Latn', + 'aka_Latn-amh_Ethi', 'aka_Latn-bam_Latn', 'aka_Latn-bem_Latn', + 'aka_Latn-cjk_Latn', 'aka_Latn-dik_Latn', 'aka_Latn-dyu_Latn', + 'aka_Latn-eng_Latn', 'aka_Latn-ewe_Latn', 'aka_Latn-fon_Latn', + 'aka_Latn-fra_Latn', 'aka_Latn-fuv_Latn', 'aka_Latn-gaz_Latn', + 'aka_Latn-hau_Latn', 'aka_Latn-ibo_Latn', 'aka_Latn-kam_Latn', + 'aka_Latn-kik_Latn', 'aka_Latn-kin_Latn', 'aka_Latn-kmb_Latn', + 'aka_Latn-knc_Arab', 'aka_Latn-knc_Latn', 'aka_Latn-kon_Latn', + 'aka_Latn-lin_Latn', 'aka_Latn-lua_Latn', 'aka_Latn-lug_Latn', + 'aka_Latn-luo_Latn', 'aka_Latn-nso_Latn', 'aka_Latn-nus_Latn', + 'aka_Latn-nya_Latn', 'aka_Latn-run_Latn', 'aka_Latn-sna_Latn', + 'aka_Latn-som_Latn', 'aka_Latn-sot_Latn', 'aka_Latn-ssw_Latn', + 'aka_Latn-swh_Latn', 'aka_Latn-tir_Ethi', 'aka_Latn-tsn_Latn', + 'aka_Latn-tso_Latn', 'aka_Latn-tum_Latn', 'aka_Latn-twi_Latn', + 'aka_Latn-umb_Latn', 'aka_Latn-wol_Latn', 'aka_Latn-xho_Latn', + 'aka_Latn-yor_Latn', 'aka_Latn-zul_Latn', 'amh_Ethi-bam_Latn', + 'amh_Ethi-bem_Latn', 'amh_Ethi-cjk_Latn', 'amh_Ethi-dik_Latn', + 'amh_Ethi-dyu_Latn', 'amh_Ethi-eng_Latn', 'amh_Ethi-ewe_Latn', + 'amh_Ethi-fon_Latn', 'amh_Ethi-fra_Latn', 'amh_Ethi-fuv_Latn', + 'amh_Ethi-gaz_Latn', 'amh_Ethi-hau_Latn', 'amh_Ethi-ibo_Latn', + 'amh_Ethi-kam_Latn', 'amh_Ethi-kik_Latn', 'amh_Ethi-kin_Latn', + 'amh_Ethi-kmb_Latn', 'amh_Ethi-knc_Arab', 'amh_Ethi-knc_Latn', + 'amh_Ethi-kon_Latn', 'amh_Ethi-lin_Latn', 'amh_Ethi-lua_Latn', + 'amh_Ethi-lug_Latn', 'amh_Ethi-luo_Latn', 'amh_Ethi-nso_Latn', + 'amh_Ethi-nus_Latn', 'amh_Ethi-nya_Latn', 'amh_Ethi-run_Latn', + 'amh_Ethi-sna_Latn', 'amh_Ethi-som_Latn', 'amh_Ethi-sot_Latn', + 'amh_Ethi-ssw_Latn', 'amh_Ethi-swh_Latn', 'amh_Ethi-tir_Ethi', + 'amh_Ethi-tsn_Latn', 'amh_Ethi-tso_Latn', 'amh_Ethi-tum_Latn', + 'amh_Ethi-twi_Latn', 'amh_Ethi-umb_Latn', 'amh_Ethi-wol_Latn', + 'amh_Ethi-xho_Latn', 'amh_Ethi-yor_Latn', 'amh_Ethi-zul_Latn', + 'arb_Arab-ckb_Arab', 'arb_Arab-crh_Latn', 'arb_Arab-dik_Latn', + 'arb_Arab-diq_Latn', 'arb_Arab-fuv_Latn', 'arb_Arab-kmr_Latn', + 'arb_Arab-knc_Latn', 'arb_Arab-nus_Latn', 'arb_Arab-som_Latn', + 'arb_Arab-tat_Cyrl', 'arb_Arab-tzm_Tfng', 'arb_Arab-urd_Arab', + 'arb_Arab-wol_Latn', 'asm_Beng-awa_Deva', 'asm_Beng-ben_Beng', + 'asm_Beng-bho_Deva', 'asm_Beng-eng_Latn', 'asm_Beng-guj_Gujr', + 'asm_Beng-hin_Deva', 'asm_Beng-hne_Deva', 'asm_Beng-kan_Knda', + 'asm_Beng-kas_Arab', 'asm_Beng-kas_Deva', 'asm_Beng-mag_Deva', + 'asm_Beng-mai_Deva', 'asm_Beng-mal_Mlym', 'asm_Beng-mar_Deva', + 'asm_Beng-npi_Deva', 'asm_Beng-ory_Orya', 'asm_Beng-pan_Guru', + 'asm_Beng-san_Deva', 'asm_Beng-sat_Beng', 'asm_Beng-sin_Sinh', + 'asm_Beng-snd_Arab', 'asm_Beng-tam_Taml', 'asm_Beng-tel_Telu', + 'asm_Beng-urd_Arab', 'awa_Deva-ben_Beng', 'awa_Deva-bho_Deva', + 'awa_Deva-eng_Latn', 'awa_Deva-guj_Gujr', 'awa_Deva-hin_Deva', + 'awa_Deva-hne_Deva', 'awa_Deva-kan_Knda', 'awa_Deva-kas_Arab', + 'awa_Deva-kas_Deva', 'awa_Deva-mag_Deva', 'awa_Deva-mai_Deva', + 'awa_Deva-mal_Mlym', 'awa_Deva-mar_Deva', 'awa_Deva-npi_Deva', + 'awa_Deva-ory_Orya', 'awa_Deva-pan_Guru', 'awa_Deva-san_Deva', + 'awa_Deva-sat_Beng', 'awa_Deva-sin_Sinh', 'awa_Deva-snd_Arab', + 'awa_Deva-tam_Taml', 'awa_Deva-tel_Telu', 'awa_Deva-urd_Arab', + 'ayr_Latn-eng_Latn', 'ayr_Latn-spa_Latn', 'azb_Arab-eng_Latn', + 'azj_Latn-eng_Latn', 'azj_Latn-rus_Cyrl', 'bak_Cyrl-crh_Latn', + 'bak_Cyrl-eng_Latn', 'bak_Cyrl-kir_Cyrl', 'bak_Cyrl-rus_Cyrl', + 'bak_Cyrl-tat_Cyrl', 'bak_Cyrl-tuk_Latn', 'bak_Cyrl-uig_Arab', + 'bak_Cyrl-uzn_Latn', 'bam_Latn-bem_Latn', 'bam_Latn-cjk_Latn', + 'bam_Latn-dik_Latn', 'bam_Latn-dyu_Latn', 'bam_Latn-eng_Latn', + 'bam_Latn-ewe_Latn', 'bam_Latn-fon_Latn', 'bam_Latn-fra_Latn', + 'bam_Latn-fuv_Latn', 'bam_Latn-gaz_Latn', 'bam_Latn-hau_Latn', + 'bam_Latn-ibo_Latn', 'bam_Latn-kam_Latn', 'bam_Latn-kik_Latn', + 'bam_Latn-kin_Latn', 'bam_Latn-kmb_Latn', 'bam_Latn-knc_Arab', + 'bam_Latn-knc_Latn', 'bam_Latn-kon_Latn', 'bam_Latn-lin_Latn', + 'bam_Latn-lua_Latn', 'bam_Latn-lug_Latn', 'bam_Latn-luo_Latn', + 'bam_Latn-nso_Latn', 'bam_Latn-nus_Latn', 'bam_Latn-nya_Latn', + 'bam_Latn-run_Latn', 'bam_Latn-sna_Latn', 'bam_Latn-som_Latn', + 'bam_Latn-sot_Latn', 'bam_Latn-ssw_Latn', 'bam_Latn-swh_Latn', + 'bam_Latn-tir_Ethi', 'bam_Latn-tsn_Latn', 'bam_Latn-tso_Latn', + 'bam_Latn-tum_Latn', 'bam_Latn-twi_Latn', 'bam_Latn-umb_Latn', + 'bam_Latn-wol_Latn', 'bam_Latn-xho_Latn', 'bam_Latn-yor_Latn', + 'bam_Latn-zul_Latn', 'ban_Latn-bjn_Latn', 'ban_Latn-bug_Latn', + 'ban_Latn-ceb_Latn', 'ban_Latn-eng_Latn', 'ban_Latn-fij_Latn', + 'ban_Latn-ilo_Latn', 'ban_Latn-jav_Latn', 'ban_Latn-min_Latn', + 'ban_Latn-mri_Latn', 'ban_Latn-pag_Latn', 'ban_Latn-plt_Latn', + 'ban_Latn-smo_Latn', 'ban_Latn-sun_Latn', 'ban_Latn-war_Latn', + 'bel_Cyrl-eng_Latn', 'bel_Cyrl-rus_Cyrl', 'bem_Latn-cjk_Latn', + 'bem_Latn-dik_Latn', 'bem_Latn-dyu_Latn', 'bem_Latn-eng_Latn', + 'bem_Latn-ewe_Latn', 'bem_Latn-fon_Latn', 'bem_Latn-fra_Latn', + 'bem_Latn-fuv_Latn', 'bem_Latn-gaz_Latn', 'bem_Latn-hau_Latn', + 'bem_Latn-ibo_Latn', 'bem_Latn-kam_Latn', 'bem_Latn-kik_Latn', + 'bem_Latn-kin_Latn', 'bem_Latn-kmb_Latn', 'bem_Latn-knc_Arab', + 'bem_Latn-knc_Latn', 'bem_Latn-kon_Latn', 'bem_Latn-lin_Latn', + 'bem_Latn-lua_Latn', 'bem_Latn-lug_Latn', 'bem_Latn-luo_Latn', + 'bem_Latn-nso_Latn', 'bem_Latn-nus_Latn', 'bem_Latn-nya_Latn', + 'bem_Latn-run_Latn', 'bem_Latn-sna_Latn', 'bem_Latn-som_Latn', + 'bem_Latn-sot_Latn', 'bem_Latn-ssw_Latn', 'bem_Latn-swh_Latn', + 'bem_Latn-tir_Ethi', 'bem_Latn-tsn_Latn', 'bem_Latn-tso_Latn', + 'bem_Latn-tum_Latn', 'bem_Latn-twi_Latn', 'bem_Latn-umb_Latn', + 'bem_Latn-wol_Latn', 'bem_Latn-xho_Latn', 'bem_Latn-yor_Latn', + 'bem_Latn-zul_Latn', 'ben_Beng-bho_Deva', 'ben_Beng-eng_Latn', + 'ben_Beng-guj_Gujr', 'ben_Beng-hin_Deva', 'ben_Beng-hne_Deva', + 'ben_Beng-kan_Knda', 'ben_Beng-kas_Arab', 'ben_Beng-kas_Deva', + 'ben_Beng-mag_Deva', 'ben_Beng-mai_Deva', 'ben_Beng-mal_Mlym', + 'ben_Beng-mar_Deva', 'ben_Beng-npi_Deva', 'ben_Beng-ory_Orya', + 'ben_Beng-pan_Guru', 'ben_Beng-pbt_Arab', 'ben_Beng-san_Deva', + 'ben_Beng-sat_Beng', 'ben_Beng-sin_Sinh', 'ben_Beng-snd_Arab', + 'ben_Beng-tam_Taml', 'ben_Beng-tel_Telu', 'ben_Beng-urd_Arab', + 'bho_Deva-eng_Latn', 'bho_Deva-guj_Gujr', 'bho_Deva-hin_Deva', + 'bho_Deva-hne_Deva', 'bho_Deva-kan_Knda', 'bho_Deva-kas_Arab', + 'bho_Deva-kas_Deva', 'bho_Deva-mag_Deva', 'bho_Deva-mai_Deva', + 'bho_Deva-mal_Mlym', 'bho_Deva-mar_Deva', 'bho_Deva-npi_Deva', + 'bho_Deva-ory_Orya', 'bho_Deva-pan_Guru', 'bho_Deva-san_Deva', + 'bho_Deva-sat_Beng', 'bho_Deva-sin_Sinh', 'bho_Deva-snd_Arab', + 'bho_Deva-tam_Taml', 'bho_Deva-tel_Telu', 'bho_Deva-urd_Arab', + 'bjn_Latn-bug_Latn', 'bjn_Latn-ceb_Latn', 'bjn_Latn-eng_Latn', + 'bjn_Latn-fij_Latn', 'bjn_Latn-ilo_Latn', 'bjn_Latn-ind_Latn', + 'bjn_Latn-jav_Latn', 'bjn_Latn-min_Latn', 'bjn_Latn-mri_Latn', + 'bjn_Latn-pag_Latn', 'bjn_Latn-plt_Latn', 'bjn_Latn-smo_Latn', + 'bjn_Latn-sun_Latn', 'bjn_Latn-war_Latn', 'bod_Tibt-eng_Latn', + 'bos_Latn-eng_Latn', 'bug_Latn-ceb_Latn', 'bug_Latn-eng_Latn', + 'bug_Latn-fij_Latn', 'bug_Latn-ilo_Latn', 'bug_Latn-jav_Latn', + 'bug_Latn-min_Latn', 'bug_Latn-mri_Latn', 'bug_Latn-pag_Latn', + 'bug_Latn-plt_Latn', 'bug_Latn-smo_Latn', 'bug_Latn-sun_Latn', + 'bug_Latn-war_Latn', 'ceb_Latn-eng_Latn', 'ceb_Latn-fij_Latn', + 'ceb_Latn-ilo_Latn', 'ceb_Latn-jav_Latn', 'ceb_Latn-min_Latn', + 'ceb_Latn-mri_Latn', 'ceb_Latn-pag_Latn', 'ceb_Latn-plt_Latn', + 'ceb_Latn-smo_Latn', 'ceb_Latn-sun_Latn', 'ceb_Latn-war_Latn', + 'cjk_Latn-dik_Latn', 'cjk_Latn-dyu_Latn', 'cjk_Latn-eng_Latn', + 'cjk_Latn-ewe_Latn', 'cjk_Latn-fon_Latn', 'cjk_Latn-fra_Latn', + 'cjk_Latn-fuv_Latn', 'cjk_Latn-gaz_Latn', 'cjk_Latn-hau_Latn', + 'cjk_Latn-ibo_Latn', 'cjk_Latn-kam_Latn', 'cjk_Latn-kik_Latn', + 'cjk_Latn-kin_Latn', 'cjk_Latn-kmb_Latn', 'cjk_Latn-knc_Arab', + 'cjk_Latn-knc_Latn', 'cjk_Latn-kon_Latn', 'cjk_Latn-lin_Latn', + 'cjk_Latn-lua_Latn', 'cjk_Latn-lug_Latn', 'cjk_Latn-luo_Latn', + 'cjk_Latn-nso_Latn', 'cjk_Latn-nus_Latn', 'cjk_Latn-nya_Latn', + 'cjk_Latn-por_Latn', 'cjk_Latn-run_Latn', 'cjk_Latn-sna_Latn', + 'cjk_Latn-som_Latn', 'cjk_Latn-sot_Latn', 'cjk_Latn-ssw_Latn', + 'cjk_Latn-swh_Latn', 'cjk_Latn-tir_Ethi', 'cjk_Latn-tsn_Latn', + 'cjk_Latn-tso_Latn', 'cjk_Latn-tum_Latn', 'cjk_Latn-twi_Latn', + 'cjk_Latn-umb_Latn', 'cjk_Latn-wol_Latn', 'cjk_Latn-xho_Latn', + 'cjk_Latn-yor_Latn', 'cjk_Latn-zul_Latn', 'ckb_Arab-diq_Latn', + 'ckb_Arab-eng_Latn', 'ckb_Arab-kmr_Latn', 'ckb_Arab-pbt_Arab', + 'ckb_Arab-prs_Arab', 'ckb_Arab-tgk_Cyrl', 'crh_Latn-eng_Latn', + 'crh_Latn-kir_Cyrl', 'crh_Latn-rus_Cyrl', 'crh_Latn-tat_Cyrl', + 'crh_Latn-tuk_Latn', 'crh_Latn-uig_Arab', 'crh_Latn-uzn_Latn', + 'cym_Latn-eng_Latn', 'dik_Latn-dyu_Latn', 'dik_Latn-eng_Latn', + 'dik_Latn-ewe_Latn', 'dik_Latn-fon_Latn', 'dik_Latn-fra_Latn', + 'dik_Latn-fuv_Latn', 'dik_Latn-gaz_Latn', 'dik_Latn-hau_Latn', + 'dik_Latn-ibo_Latn', 'dik_Latn-kam_Latn', 'dik_Latn-kik_Latn', + 'dik_Latn-kin_Latn', 'dik_Latn-kmb_Latn', 'dik_Latn-knc_Arab', + 'dik_Latn-knc_Latn', 'dik_Latn-kon_Latn', 'dik_Latn-lin_Latn', + 'dik_Latn-lua_Latn', 'dik_Latn-lug_Latn', 'dik_Latn-luo_Latn', + 'dik_Latn-nso_Latn', 'dik_Latn-nus_Latn', 'dik_Latn-nya_Latn', + 'dik_Latn-run_Latn', 'dik_Latn-sna_Latn', 'dik_Latn-som_Latn', + 'dik_Latn-sot_Latn', 'dik_Latn-ssw_Latn', 'dik_Latn-swh_Latn', + 'dik_Latn-tir_Ethi', 'dik_Latn-tsn_Latn', 'dik_Latn-tso_Latn', + 'dik_Latn-tum_Latn', 'dik_Latn-twi_Latn', 'dik_Latn-umb_Latn', + 'dik_Latn-wol_Latn', 'dik_Latn-xho_Latn', 'dik_Latn-yor_Latn', + 'dik_Latn-zul_Latn', 'diq_Latn-eng_Latn', 'diq_Latn-kmr_Latn', + 'diq_Latn-pbt_Arab', 'diq_Latn-prs_Arab', 'diq_Latn-tgk_Cyrl', + 'dyu_Latn-eng_Latn', 'dyu_Latn-ewe_Latn', 'dyu_Latn-fon_Latn', + 'dyu_Latn-fra_Latn', 'dyu_Latn-fuv_Latn', 'dyu_Latn-gaz_Latn', + 'dyu_Latn-hau_Latn', 'dyu_Latn-ibo_Latn', 'dyu_Latn-kam_Latn', + 'dyu_Latn-kik_Latn', 'dyu_Latn-kin_Latn', 'dyu_Latn-kmb_Latn', + 'dyu_Latn-knc_Arab', 'dyu_Latn-knc_Latn', 'dyu_Latn-kon_Latn', + 'dyu_Latn-lin_Latn', 'dyu_Latn-lua_Latn', 'dyu_Latn-lug_Latn', + 'dyu_Latn-luo_Latn', 'dyu_Latn-nso_Latn', 'dyu_Latn-nus_Latn', + 'dyu_Latn-nya_Latn', 'dyu_Latn-run_Latn', 'dyu_Latn-sna_Latn', + 'dyu_Latn-som_Latn', 'dyu_Latn-sot_Latn', 'dyu_Latn-ssw_Latn', + 'dyu_Latn-swh_Latn', 'dyu_Latn-tir_Ethi', 'dyu_Latn-tsn_Latn', + 'dyu_Latn-tso_Latn', 'dyu_Latn-tum_Latn', 'dyu_Latn-twi_Latn', + 'dyu_Latn-umb_Latn', 'dyu_Latn-wol_Latn', 'dyu_Latn-xho_Latn', + 'dyu_Latn-yor_Latn', 'dyu_Latn-zul_Latn', 'dzo_Tibt-eng_Latn', + 'eng_Latn-als_Latn', 'eng_Latn-epo_Latn', 'eng_Latn-ewe_Latn', + 'eng_Latn-fao_Latn', 'eng_Latn-fij_Latn', 'eng_Latn-fon_Latn', + 'eng_Latn-fur_Latn', 'eng_Latn-fuv_Latn', 'eng_Latn-gaz_Latn', + 'eng_Latn-gla_Latn', 'eng_Latn-gle_Latn', 'eng_Latn-grn_Latn', + 'eng_Latn-guj_Gujr', 'eng_Latn-hat_Latn', 'eng_Latn-hau_Latn', + 'eng_Latn-hin_Deva', 'eng_Latn-hne_Deva', 'eng_Latn-hye_Armn', + 'eng_Latn-ibo_Latn', 'eng_Latn-ilo_Latn', 'eng_Latn-jav_Latn', + 'eng_Latn-kab_Latn', 'eng_Latn-kac_Latn', 'eng_Latn-kam_Latn', + 'eng_Latn-kan_Knda', 'eng_Latn-kas_Arab', 'eng_Latn-kas_Deva', + 'eng_Latn-kat_Geor', 'eng_Latn-kaz_Cyrl', 'eng_Latn-kbp_Latn', + 'eng_Latn-kea_Latn', 'eng_Latn-khk_Cyrl', 'eng_Latn-khm_Khmr', + 'eng_Latn-kik_Latn', 'eng_Latn-kin_Latn', 'eng_Latn-kir_Cyrl', + 'eng_Latn-kmb_Latn', 'eng_Latn-kmr_Latn', 'eng_Latn-knc_Arab', + 'eng_Latn-knc_Latn', 'eng_Latn-kon_Latn', 'eng_Latn-lao_Laoo', + 'eng_Latn-lij_Latn', 'eng_Latn-lim_Latn', 'eng_Latn-lin_Latn', + 'eng_Latn-lmo_Latn', 'eng_Latn-ltg_Latn', 'eng_Latn-ltz_Latn', + 'eng_Latn-lua_Latn', 'eng_Latn-lug_Latn', 'eng_Latn-luo_Latn', + 'eng_Latn-lus_Latn', 'eng_Latn-mag_Deva', 'eng_Latn-mai_Deva', + 'eng_Latn-mal_Mlym', 'eng_Latn-mar_Deva', 'eng_Latn-min_Latn', + 'eng_Latn-mlt_Latn', 'eng_Latn-mni_Beng', 'eng_Latn-mos_Latn', + 'eng_Latn-mri_Latn', 'eng_Latn-mya_Mymr', 'eng_Latn-npi_Deva', + 'eng_Latn-nso_Latn', 'eng_Latn-nus_Latn', 'eng_Latn-nya_Latn', + 'eng_Latn-ory_Orya', 'eng_Latn-pag_Latn', 'eng_Latn-pan_Guru', + 'eng_Latn-pap_Latn', 'eng_Latn-pbt_Arab', 'eng_Latn-plt_Latn', + 'eng_Latn-prs_Arab', 'eng_Latn-quy_Latn', 'eng_Latn-run_Latn', + 'eng_Latn-sag_Latn', 'eng_Latn-san_Deva', 'eng_Latn-sat_Beng', + 'eng_Latn-scn_Latn', 'eng_Latn-shn_Mymr', 'eng_Latn-sin_Sinh', + 'eng_Latn-smo_Latn', 'eng_Latn-sna_Latn', 'eng_Latn-snd_Arab', + 'eng_Latn-som_Latn', 'eng_Latn-sot_Latn', 'eng_Latn-srd_Latn', + 'eng_Latn-ssw_Latn', 'eng_Latn-sun_Latn', 'eng_Latn-swh_Latn', + 'eng_Latn-szl_Latn', 'eng_Latn-tam_Taml', 'eng_Latn-taq_Latn', + 'eng_Latn-tat_Cyrl', 'eng_Latn-tel_Telu', 'eng_Latn-tgk_Cyrl', + 'eng_Latn-tgl_Latn', 'eng_Latn-tir_Ethi', 'eng_Latn-tpi_Latn', + 'eng_Latn-tsn_Latn', 'eng_Latn-tso_Latn', 'eng_Latn-tuk_Latn', + 'eng_Latn-tum_Latn', 'eng_Latn-twi_Latn', 'eng_Latn-tzm_Tfng', + 'eng_Latn-uig_Arab', 'eng_Latn-umb_Latn', 'eng_Latn-urd_Arab', + 'eng_Latn-uzn_Latn', 'eng_Latn-vec_Latn', 'eng_Latn-war_Latn', + 'eng_Latn-wol_Latn', 'eng_Latn-xho_Latn', 'eng_Latn-ydd_Hebr', + 'eng_Latn-yor_Latn', 'eng_Latn-zho_Hant', 'eng_Latn-zsm_Latn', + 'eng_Latn-zul_Latn', 'epo_Latn-fra_Latn', 'ewe_Latn-fon_Latn', + 'ewe_Latn-fra_Latn', 'ewe_Latn-fuv_Latn', 'ewe_Latn-gaz_Latn', + 'ewe_Latn-hau_Latn', 'ewe_Latn-ibo_Latn', 'ewe_Latn-kam_Latn', + 'ewe_Latn-kik_Latn', 'ewe_Latn-kin_Latn', 'ewe_Latn-kmb_Latn', + 'ewe_Latn-knc_Arab', 'ewe_Latn-knc_Latn', 'ewe_Latn-kon_Latn', + 'ewe_Latn-lin_Latn', 'ewe_Latn-lua_Latn', 'ewe_Latn-lug_Latn', + 'ewe_Latn-luo_Latn', 'ewe_Latn-nso_Latn', 'ewe_Latn-nus_Latn', + 'ewe_Latn-nya_Latn', 'ewe_Latn-run_Latn', 'ewe_Latn-sna_Latn', + 'ewe_Latn-som_Latn', 'ewe_Latn-sot_Latn', 'ewe_Latn-ssw_Latn', + 'ewe_Latn-swh_Latn', 'ewe_Latn-tir_Ethi', 'ewe_Latn-tsn_Latn', + 'ewe_Latn-tso_Latn', 'ewe_Latn-tum_Latn', 'ewe_Latn-twi_Latn', + 'ewe_Latn-umb_Latn', 'ewe_Latn-wol_Latn', 'ewe_Latn-xho_Latn', + 'ewe_Latn-yor_Latn', 'ewe_Latn-zul_Latn', 'fij_Latn-hin_Deva', + 'fij_Latn-ilo_Latn', 'fij_Latn-jav_Latn', 'fij_Latn-min_Latn', + 'fij_Latn-mri_Latn', 'fij_Latn-pag_Latn', 'fij_Latn-plt_Latn', + 'fij_Latn-smo_Latn', 'fij_Latn-sun_Latn', 'fij_Latn-war_Latn', + 'fon_Latn-fra_Latn', 'fon_Latn-fuv_Latn', 'fon_Latn-gaz_Latn', + 'fon_Latn-hau_Latn', 'fon_Latn-ibo_Latn', 'fon_Latn-kam_Latn', + 'fon_Latn-kik_Latn', 'fon_Latn-kin_Latn', 'fon_Latn-kmb_Latn', + 'fon_Latn-knc_Arab', 'fon_Latn-knc_Latn', 'fon_Latn-kon_Latn', + 'fon_Latn-lin_Latn', 'fon_Latn-lua_Latn', 'fon_Latn-lug_Latn', + 'fon_Latn-luo_Latn', 'fon_Latn-nso_Latn', 'fon_Latn-nus_Latn', + 'fon_Latn-nya_Latn', 'fon_Latn-run_Latn', 'fon_Latn-sna_Latn', + 'fon_Latn-som_Latn', 'fon_Latn-sot_Latn', 'fon_Latn-ssw_Latn', + 'fon_Latn-swh_Latn', 'fon_Latn-tir_Ethi', 'fon_Latn-tsn_Latn', + 'fon_Latn-tso_Latn', 'fon_Latn-tum_Latn', 'fon_Latn-twi_Latn', + 'fon_Latn-umb_Latn', 'fon_Latn-wol_Latn', 'fon_Latn-xho_Latn', + 'fon_Latn-yor_Latn', 'fon_Latn-zul_Latn', 'fra_Latn-fuv_Latn', + 'fra_Latn-gaz_Latn', 'fra_Latn-glg_Latn', 'fra_Latn-hat_Latn', + 'fra_Latn-hau_Latn', 'fra_Latn-ibo_Latn', 'fra_Latn-kab_Latn', + 'fra_Latn-kam_Latn', 'fra_Latn-kik_Latn', 'fra_Latn-kin_Latn', + 'fra_Latn-kmb_Latn', 'fra_Latn-knc_Arab', 'fra_Latn-knc_Latn', + 'fra_Latn-kon_Latn', 'fra_Latn-lin_Latn', 'fra_Latn-ltz_Latn', + 'fra_Latn-lua_Latn', 'fra_Latn-lug_Latn', 'fra_Latn-luo_Latn', + 'fra_Latn-nso_Latn', 'fra_Latn-nus_Latn', 'fra_Latn-nya_Latn', + 'fra_Latn-oci_Latn', 'fra_Latn-plt_Latn', 'fra_Latn-run_Latn', + 'fra_Latn-sag_Latn', 'fra_Latn-scn_Latn', 'fra_Latn-sna_Latn', + 'fra_Latn-som_Latn', 'fra_Latn-sot_Latn', 'fra_Latn-ssw_Latn', + 'fra_Latn-swh_Latn', 'fra_Latn-tir_Ethi', 'fra_Latn-tsn_Latn', + 'fra_Latn-tso_Latn', 'fra_Latn-tum_Latn', 'fra_Latn-twi_Latn', + 'fra_Latn-tzm_Tfng', 'fra_Latn-umb_Latn', 'fra_Latn-wol_Latn', + 'fra_Latn-xho_Latn', 'fra_Latn-yor_Latn', 'fra_Latn-zul_Latn', + 'fuv_Latn-gaz_Latn', 'fuv_Latn-hau_Latn', 'fuv_Latn-ibo_Latn', + 'fuv_Latn-kam_Latn', 'fuv_Latn-kik_Latn', 'fuv_Latn-kin_Latn', + 'fuv_Latn-kmb_Latn', 'fuv_Latn-knc_Arab', 'fuv_Latn-knc_Latn', + 'fuv_Latn-kon_Latn', 'fuv_Latn-lin_Latn', 'fuv_Latn-lua_Latn', + 'fuv_Latn-lug_Latn', 'fuv_Latn-luo_Latn', 'fuv_Latn-nso_Latn', + 'fuv_Latn-nus_Latn', 'fuv_Latn-nya_Latn', 'fuv_Latn-run_Latn', + 'fuv_Latn-sna_Latn', 'fuv_Latn-som_Latn', 'fuv_Latn-sot_Latn', + 'fuv_Latn-ssw_Latn', 'fuv_Latn-swh_Latn', 'fuv_Latn-tir_Ethi', + 'fuv_Latn-tsn_Latn', 'fuv_Latn-tso_Latn', 'fuv_Latn-tum_Latn', + 'fuv_Latn-twi_Latn', 'fuv_Latn-umb_Latn', 'fuv_Latn-wol_Latn', + 'fuv_Latn-xho_Latn', 'fuv_Latn-yor_Latn', 'fuv_Latn-zul_Latn', + 'gaz_Latn-run_Latn', 'gaz_Latn-sna_Latn', 'gaz_Latn-som_Latn', + 'gaz_Latn-sot_Latn', 'gaz_Latn-ssw_Latn', 'gaz_Latn-swh_Latn', + 'gaz_Latn-tir_Ethi', 'gaz_Latn-tsn_Latn', 'gaz_Latn-tso_Latn', + 'gaz_Latn-tum_Latn', 'gaz_Latn-twi_Latn', 'gaz_Latn-umb_Latn', + 'gaz_Latn-wol_Latn', 'gaz_Latn-xho_Latn', 'gaz_Latn-yor_Latn', + 'gaz_Latn-zul_Latn', 'glg_Latn-por_Latn', 'grn_Latn-por_Latn', + 'guj_Gujr-hin_Deva', 'guj_Gujr-hne_Deva', 'guj_Gujr-kan_Knda', + 'guj_Gujr-kas_Arab', 'guj_Gujr-kas_Deva', 'guj_Gujr-mag_Deva', + 'guj_Gujr-mai_Deva', 'guj_Gujr-mal_Mlym', 'guj_Gujr-mar_Deva', + 'guj_Gujr-npi_Deva', 'guj_Gujr-ory_Orya', 'guj_Gujr-pan_Guru', + 'guj_Gujr-san_Deva', 'guj_Gujr-sat_Beng', 'guj_Gujr-sin_Sinh', + 'guj_Gujr-snd_Arab', 'guj_Gujr-tam_Taml', 'guj_Gujr-tel_Telu', + 'guj_Gujr-urd_Arab', 'hau_Latn-gaz_Latn', 'hau_Latn-ibo_Latn', + 'hau_Latn-kam_Latn', 'hau_Latn-kik_Latn', 'hau_Latn-kin_Latn', + 'hau_Latn-kmb_Latn', 'hau_Latn-knc_Arab', 'hau_Latn-knc_Latn', + 'hau_Latn-kon_Latn', 'hau_Latn-lin_Latn', 'hau_Latn-lua_Latn', + 'hau_Latn-lug_Latn', 'hau_Latn-luo_Latn', 'hau_Latn-nso_Latn', + 'hau_Latn-nus_Latn', 'hau_Latn-nya_Latn', 'hau_Latn-run_Latn', + 'hau_Latn-sna_Latn', 'hau_Latn-som_Latn', 'hau_Latn-sot_Latn', + 'hau_Latn-ssw_Latn', 'hau_Latn-swh_Latn', 'hau_Latn-tir_Ethi', + 'hau_Latn-tsn_Latn', 'hau_Latn-tso_Latn', 'hau_Latn-tum_Latn', + 'hau_Latn-twi_Latn', 'hau_Latn-umb_Latn', 'hau_Latn-wol_Latn', + 'hau_Latn-xho_Latn', 'hau_Latn-yor_Latn', 'hau_Latn-zul_Latn', + 'hin_Deva-hne_Deva', 'hin_Deva-kan_Knda', 'hin_Deva-kas_Arab', + 'hin_Deva-kas_Deva', 'hin_Deva-mag_Deva', 'hin_Deva-mai_Deva', + 'hin_Deva-mal_Mlym', 'hin_Deva-mar_Deva', 'hin_Deva-npi_Deva', + 'hin_Deva-ory_Orya', 'hin_Deva-pan_Guru', 'hin_Deva-pbt_Arab', + 'hin_Deva-san_Deva', 'hin_Deva-sat_Beng', 'hin_Deva-sin_Sinh', + 'hin_Deva-snd_Arab', 'hin_Deva-tam_Taml', 'hin_Deva-tel_Telu', + 'hin_Deva-urd_Arab', 'hne_Deva-kan_Knda', 'hne_Deva-kas_Arab', + 'hne_Deva-kas_Deva', 'hne_Deva-mag_Deva', 'hne_Deva-mai_Deva', + 'hne_Deva-mal_Mlym', 'hne_Deva-mar_Deva', 'hne_Deva-npi_Deva', + 'hne_Deva-ory_Orya', 'hne_Deva-pan_Guru', 'hne_Deva-san_Deva', + 'hne_Deva-sat_Beng', 'hne_Deva-sin_Sinh', 'hne_Deva-snd_Arab', + 'hne_Deva-tam_Taml', 'hne_Deva-tel_Telu', 'hne_Deva-urd_Arab', + 'hye_Armn-rus_Cyrl', 'ibo_Latn-gaz_Latn', 'ibo_Latn-kam_Latn', + 'ibo_Latn-kik_Latn', 'ibo_Latn-kin_Latn', 'ibo_Latn-kmb_Latn', + 'ibo_Latn-knc_Arab', 'ibo_Latn-knc_Latn', 'ibo_Latn-kon_Latn', + 'ibo_Latn-lin_Latn', 'ibo_Latn-lua_Latn', 'ibo_Latn-lug_Latn', + 'ibo_Latn-luo_Latn', 'ibo_Latn-nso_Latn', 'ibo_Latn-nus_Latn', + 'ibo_Latn-nya_Latn', 'ibo_Latn-run_Latn', 'ibo_Latn-sna_Latn', + 'ibo_Latn-som_Latn', 'ibo_Latn-sot_Latn', 'ibo_Latn-ssw_Latn', + 'ibo_Latn-swh_Latn', 'ibo_Latn-tir_Ethi', 'ibo_Latn-tsn_Latn', + 'ibo_Latn-tso_Latn', 'ibo_Latn-tum_Latn', 'ibo_Latn-twi_Latn', + 'ibo_Latn-umb_Latn', 'ibo_Latn-wol_Latn', 'ibo_Latn-xho_Latn', + 'ibo_Latn-yor_Latn', 'ibo_Latn-zul_Latn', 'ilo_Latn-jav_Latn', + 'ilo_Latn-min_Latn', 'ilo_Latn-mri_Latn', 'ilo_Latn-pag_Latn', + 'ilo_Latn-plt_Latn', 'ilo_Latn-smo_Latn', 'ilo_Latn-sun_Latn', + 'ilo_Latn-war_Latn', 'ind_Latn-ace_Latn', 'ind_Latn-ban_Latn', + 'ind_Latn-jav_Latn', 'ind_Latn-khm_Khmr', 'ind_Latn-lao_Laoo', + 'ind_Latn-min_Latn', 'ind_Latn-mya_Mymr', 'ind_Latn-shn_Mymr', + 'ind_Latn-sun_Latn', 'jav_Latn-min_Latn', 'jav_Latn-mri_Latn', + 'jav_Latn-pag_Latn', 'jav_Latn-plt_Latn', 'jav_Latn-smo_Latn', + 'jav_Latn-sun_Latn', 'jav_Latn-war_Latn', 'kam_Latn-gaz_Latn', + 'kam_Latn-kik_Latn', 'kam_Latn-kin_Latn', 'kam_Latn-kmb_Latn', + 'kam_Latn-knc_Arab', 'kam_Latn-knc_Latn', 'kam_Latn-kon_Latn', + 'kam_Latn-lin_Latn', 'kam_Latn-lua_Latn', 'kam_Latn-lug_Latn', + 'kam_Latn-luo_Latn', 'kam_Latn-nso_Latn', 'kam_Latn-nus_Latn', + 'kam_Latn-nya_Latn', 'kam_Latn-run_Latn', 'kam_Latn-sna_Latn', + 'kam_Latn-som_Latn', 'kam_Latn-sot_Latn', 'kam_Latn-ssw_Latn', + 'kam_Latn-swh_Latn', 'kam_Latn-tir_Ethi', 'kam_Latn-tsn_Latn', + 'kam_Latn-tso_Latn', 'kam_Latn-tum_Latn', 'kam_Latn-twi_Latn', + 'kam_Latn-umb_Latn', 'kam_Latn-wol_Latn', 'kam_Latn-xho_Latn', + 'kam_Latn-yor_Latn', 'kam_Latn-zul_Latn', 'kan_Knda-kas_Arab', + 'kan_Knda-kas_Deva', 'kan_Knda-mag_Deva', 'kan_Knda-mai_Deva', + 'kan_Knda-mal_Mlym', 'kan_Knda-mar_Deva', 'kan_Knda-npi_Deva', + 'kan_Knda-ory_Orya', 'kan_Knda-pan_Guru', 'kan_Knda-san_Deva', + 'kan_Knda-sat_Beng', 'kan_Knda-sin_Sinh', 'kan_Knda-snd_Arab', + 'kan_Knda-tam_Taml', 'kan_Knda-tel_Telu', 'kan_Knda-urd_Arab', + 'kas_Arab-kas_Deva', 'kas_Arab-mag_Deva', 'kas_Arab-mai_Deva', + 'kas_Arab-mal_Mlym', 'kas_Arab-mar_Deva', 'kas_Arab-npi_Deva', + 'kas_Arab-ory_Orya', 'kas_Arab-pan_Guru', 'kas_Arab-san_Deva', + 'kas_Arab-sat_Beng', 'kas_Arab-sin_Sinh', 'kas_Arab-snd_Arab', + 'kas_Arab-tam_Taml', 'kas_Arab-tel_Telu', 'kas_Arab-urd_Arab', + 'kas_Deva-mag_Deva', 'kas_Deva-mai_Deva', 'kas_Deva-mal_Mlym', + 'kas_Deva-mar_Deva', 'kas_Deva-npi_Deva', 'kas_Deva-ory_Orya', + 'kas_Deva-pan_Guru', 'kas_Deva-san_Deva', 'kas_Deva-sat_Beng', + 'kas_Deva-sin_Sinh', 'kas_Deva-snd_Arab', 'kas_Deva-tam_Taml', + 'kas_Deva-tel_Telu', 'kas_Deva-urd_Arab', 'kat_Geor-rus_Cyrl', + 'kea_Latn-por_Latn', 'kik_Latn-gaz_Latn', 'kik_Latn-kin_Latn', + 'kik_Latn-kmb_Latn', 'kik_Latn-kon_Latn', 'kik_Latn-lin_Latn', + 'kik_Latn-lua_Latn', 'kik_Latn-lug_Latn', 'kik_Latn-luo_Latn', + 'kik_Latn-nso_Latn', 'kik_Latn-nus_Latn', 'kik_Latn-nya_Latn', + 'kik_Latn-run_Latn', 'kik_Latn-sna_Latn', 'kik_Latn-som_Latn', + 'kik_Latn-sot_Latn', 'kik_Latn-ssw_Latn', 'kik_Latn-swh_Latn', + 'kik_Latn-tir_Ethi', 'kik_Latn-tsn_Latn', 'kik_Latn-tso_Latn', + 'kik_Latn-tum_Latn', 'kik_Latn-twi_Latn', 'kik_Latn-umb_Latn', + 'kik_Latn-wol_Latn', 'kik_Latn-xho_Latn', 'kik_Latn-yor_Latn', + 'kik_Latn-zul_Latn', 'kin_Latn-gaz_Latn', 'kin_Latn-kmb_Latn', + 'kin_Latn-kon_Latn', 'kin_Latn-lin_Latn', 'kin_Latn-lua_Latn', + 'kin_Latn-lug_Latn', 'kin_Latn-luo_Latn', 'kin_Latn-nso_Latn', + 'kin_Latn-nus_Latn', 'kin_Latn-nya_Latn', 'kin_Latn-run_Latn', + 'kin_Latn-sna_Latn', 'kin_Latn-som_Latn', 'kin_Latn-sot_Latn', + 'kin_Latn-ssw_Latn', 'kin_Latn-swh_Latn', 'kin_Latn-tir_Ethi', + 'kin_Latn-tsn_Latn', 'kin_Latn-tso_Latn', 'kin_Latn-tum_Latn', + 'kin_Latn-twi_Latn', 'kin_Latn-umb_Latn', 'kin_Latn-wol_Latn', + 'kin_Latn-xho_Latn', 'kin_Latn-yor_Latn', 'kin_Latn-zul_Latn', + 'kir_Cyrl-rus_Cyrl', 'kir_Cyrl-tat_Cyrl', 'kir_Cyrl-tuk_Latn', + 'kir_Cyrl-uig_Arab', 'kir_Cyrl-uzn_Latn', 'kmb_Latn-gaz_Latn', + 'kmb_Latn-kon_Latn', 'kmb_Latn-lin_Latn', 'kmb_Latn-lua_Latn', + 'kmb_Latn-lug_Latn', 'kmb_Latn-luo_Latn', 'kmb_Latn-nso_Latn', + 'kmb_Latn-nus_Latn', 'kmb_Latn-nya_Latn', 'kmb_Latn-por_Latn', + 'kmb_Latn-run_Latn', 'kmb_Latn-sna_Latn', 'kmb_Latn-som_Latn', + 'kmb_Latn-sot_Latn', 'kmb_Latn-ssw_Latn', 'kmb_Latn-swh_Latn', + 'kmb_Latn-tir_Ethi', 'kmb_Latn-tsn_Latn', 'kmb_Latn-tso_Latn', + 'kmb_Latn-tum_Latn', 'kmb_Latn-twi_Latn', 'kmb_Latn-umb_Latn', + 'kmb_Latn-wol_Latn', 'kmb_Latn-xho_Latn', 'kmb_Latn-yor_Latn', + 'kmb_Latn-zul_Latn', 'kmr_Latn-pbt_Arab', 'kmr_Latn-prs_Arab', + 'kmr_Latn-tgk_Cyrl', 'knc_Arab-gaz_Latn', 'knc_Arab-kik_Latn', + 'knc_Arab-kin_Latn', 'knc_Arab-kmb_Latn', 'knc_Arab-knc_Latn', + 'knc_Arab-kon_Latn', 'knc_Arab-lin_Latn', 'knc_Arab-lua_Latn', + 'knc_Arab-lug_Latn', 'knc_Arab-luo_Latn', 'knc_Arab-nso_Latn', + 'knc_Arab-nus_Latn', 'knc_Arab-nya_Latn', 'knc_Arab-run_Latn', + 'knc_Arab-sna_Latn', 'knc_Arab-som_Latn', 'knc_Arab-sot_Latn', + 'knc_Arab-ssw_Latn', 'knc_Arab-swh_Latn', 'knc_Arab-tir_Ethi', + 'knc_Arab-tsn_Latn', 'knc_Arab-tso_Latn', 'knc_Arab-tum_Latn', + 'knc_Arab-twi_Latn', 'knc_Arab-umb_Latn', 'knc_Arab-wol_Latn', + 'knc_Arab-xho_Latn', 'knc_Arab-yor_Latn', 'knc_Arab-zul_Latn', + 'knc_Latn-gaz_Latn', 'knc_Latn-kik_Latn', 'knc_Latn-kin_Latn', + 'knc_Latn-kmb_Latn', 'knc_Latn-kon_Latn', 'knc_Latn-lin_Latn', + 'knc_Latn-lua_Latn', 'knc_Latn-lug_Latn', 'knc_Latn-luo_Latn', + 'knc_Latn-nso_Latn', 'knc_Latn-nus_Latn', 'knc_Latn-nya_Latn', + 'knc_Latn-run_Latn', 'knc_Latn-sna_Latn', 'knc_Latn-som_Latn', + 'knc_Latn-sot_Latn', 'knc_Latn-ssw_Latn', 'knc_Latn-swh_Latn', + 'knc_Latn-tir_Ethi', 'knc_Latn-tsn_Latn', 'knc_Latn-tso_Latn', + 'knc_Latn-tum_Latn', 'knc_Latn-twi_Latn', 'knc_Latn-umb_Latn', + 'knc_Latn-wol_Latn', 'knc_Latn-xho_Latn', 'knc_Latn-yor_Latn', + 'knc_Latn-zul_Latn', 'kon_Latn-gaz_Latn', 'kon_Latn-lin_Latn', + 'kon_Latn-lua_Latn', 'kon_Latn-lug_Latn', 'kon_Latn-luo_Latn', + 'kon_Latn-nso_Latn', 'kon_Latn-nus_Latn', 'kon_Latn-nya_Latn', + 'kon_Latn-run_Latn', 'kon_Latn-sna_Latn', 'kon_Latn-som_Latn', + 'kon_Latn-sot_Latn', 'kon_Latn-ssw_Latn', 'kon_Latn-swh_Latn', + 'kon_Latn-tir_Ethi', 'kon_Latn-tsn_Latn', 'kon_Latn-tso_Latn', + 'kon_Latn-tum_Latn', 'kon_Latn-twi_Latn', 'kon_Latn-umb_Latn', + 'kon_Latn-wol_Latn', 'kon_Latn-xho_Latn', 'kon_Latn-yor_Latn', + 'kon_Latn-zul_Latn', 'lao_Laoo-rus_Cyrl', 'lin_Latn-gaz_Latn', + 'lin_Latn-lua_Latn', 'lin_Latn-lug_Latn', 'lin_Latn-luo_Latn', + 'lin_Latn-nso_Latn', 'lin_Latn-nus_Latn', 'lin_Latn-nya_Latn', + 'lin_Latn-run_Latn', 'lin_Latn-sna_Latn', 'lin_Latn-som_Latn', + 'lin_Latn-sot_Latn', 'lin_Latn-ssw_Latn', 'lin_Latn-swh_Latn', + 'lin_Latn-tir_Ethi', 'lin_Latn-tsn_Latn', 'lin_Latn-tso_Latn', + 'lin_Latn-tum_Latn', 'lin_Latn-twi_Latn', 'lin_Latn-umb_Latn', + 'lin_Latn-wol_Latn', 'lin_Latn-xho_Latn', 'lin_Latn-yor_Latn', + 'lin_Latn-zul_Latn', 'ltg_Latn-rus_Cyrl', 'lua_Latn-gaz_Latn', + 'lua_Latn-lug_Latn', 'lua_Latn-luo_Latn', 'lua_Latn-nso_Latn', + 'lua_Latn-nus_Latn', 'lua_Latn-nya_Latn', 'lua_Latn-run_Latn', + 'lua_Latn-sna_Latn', 'lua_Latn-som_Latn', 'lua_Latn-sot_Latn', + 'lua_Latn-ssw_Latn', 'lua_Latn-swh_Latn', 'lua_Latn-tir_Ethi', + 'lua_Latn-tsn_Latn', 'lua_Latn-tso_Latn', 'lua_Latn-tum_Latn', + 'lua_Latn-twi_Latn', 'lua_Latn-umb_Latn', 'lua_Latn-wol_Latn', + 'lua_Latn-xho_Latn', 'lua_Latn-yor_Latn', 'lua_Latn-zul_Latn', + 'lug_Latn-gaz_Latn', 'lug_Latn-luo_Latn', 'lug_Latn-nso_Latn', + 'lug_Latn-nus_Latn', 'lug_Latn-nya_Latn', 'lug_Latn-run_Latn', + 'lug_Latn-sna_Latn', 'lug_Latn-som_Latn', 'lug_Latn-sot_Latn', + 'lug_Latn-ssw_Latn', 'lug_Latn-swh_Latn', 'lug_Latn-tir_Ethi', + 'lug_Latn-tsn_Latn', 'lug_Latn-tso_Latn', 'lug_Latn-tum_Latn', + 'lug_Latn-twi_Latn', 'lug_Latn-umb_Latn', 'lug_Latn-wol_Latn', + 'lug_Latn-xho_Latn', 'lug_Latn-yor_Latn', 'lug_Latn-zul_Latn', + 'luo_Latn-gaz_Latn', 'luo_Latn-nso_Latn', 'luo_Latn-nus_Latn', + 'luo_Latn-nya_Latn', 'luo_Latn-run_Latn', 'luo_Latn-sna_Latn', + 'luo_Latn-som_Latn', 'luo_Latn-sot_Latn', 'luo_Latn-ssw_Latn', + 'luo_Latn-swh_Latn', 'luo_Latn-tir_Ethi', 'luo_Latn-tsn_Latn', + 'luo_Latn-tso_Latn', 'luo_Latn-tum_Latn', 'luo_Latn-twi_Latn', + 'luo_Latn-umb_Latn', 'luo_Latn-wol_Latn', 'luo_Latn-xho_Latn', + 'luo_Latn-yor_Latn', 'luo_Latn-zul_Latn', 'mag_Deva-mai_Deva', + 'mag_Deva-mal_Mlym', 'mag_Deva-mar_Deva', 'mag_Deva-npi_Deva', + 'mag_Deva-ory_Orya', 'mag_Deva-pan_Guru', 'mag_Deva-san_Deva', + 'mag_Deva-sat_Beng', 'mag_Deva-sin_Sinh', 'mag_Deva-snd_Arab', + 'mag_Deva-tam_Taml', 'mag_Deva-tel_Telu', 'mag_Deva-urd_Arab', + 'mai_Deva-mal_Mlym', 'mai_Deva-mar_Deva', 'mai_Deva-npi_Deva', + 'mai_Deva-ory_Orya', 'mai_Deva-pan_Guru', 'mai_Deva-san_Deva', + 'mai_Deva-sat_Beng', 'mai_Deva-sin_Sinh', 'mai_Deva-snd_Arab', + 'mai_Deva-tam_Taml', 'mai_Deva-tel_Telu', 'mai_Deva-urd_Arab', + 'mal_Mlym-mar_Deva', 'mal_Mlym-npi_Deva', 'mal_Mlym-ory_Orya', + 'mal_Mlym-pan_Guru', 'mal_Mlym-san_Deva', 'mal_Mlym-sat_Beng', + 'mal_Mlym-sin_Sinh', 'mal_Mlym-snd_Arab', 'mal_Mlym-tam_Taml', + 'mal_Mlym-tel_Telu', 'mal_Mlym-urd_Arab', 'mar_Deva-npi_Deva', + 'mar_Deva-ory_Orya', 'mar_Deva-pan_Guru', 'mar_Deva-san_Deva', + 'mar_Deva-sat_Beng', 'mar_Deva-sin_Sinh', 'mar_Deva-snd_Arab', + 'mar_Deva-tam_Taml', 'mar_Deva-tel_Telu', 'mar_Deva-urd_Arab', + 'min_Latn-mri_Latn', 'min_Latn-pag_Latn', 'min_Latn-plt_Latn', + 'min_Latn-smo_Latn', 'min_Latn-sun_Latn', 'min_Latn-war_Latn', + 'mri_Latn-pag_Latn', 'mri_Latn-smo_Latn', 'mri_Latn-sun_Latn', + 'mri_Latn-war_Latn', 'npi_Deva-ory_Orya', 'npi_Deva-pan_Guru', + 'npi_Deva-san_Deva', 'npi_Deva-sat_Beng', 'npi_Deva-sin_Sinh', + 'npi_Deva-snd_Arab', 'npi_Deva-tam_Taml', 'npi_Deva-tel_Telu', + 'npi_Deva-urd_Arab', 'nso_Latn-gaz_Latn', 'nso_Latn-nus_Latn', + 'nso_Latn-nya_Latn', 'nso_Latn-run_Latn', 'nso_Latn-sna_Latn', + 'nso_Latn-som_Latn', 'nso_Latn-sot_Latn', 'nso_Latn-ssw_Latn', + 'nso_Latn-swh_Latn', 'nso_Latn-tir_Ethi', 'nso_Latn-tsn_Latn', + 'nso_Latn-tso_Latn', 'nso_Latn-tum_Latn', 'nso_Latn-twi_Latn', + 'nso_Latn-umb_Latn', 'nso_Latn-wol_Latn', 'nso_Latn-xho_Latn', + 'nso_Latn-yor_Latn', 'nso_Latn-zul_Latn', 'nus_Latn-gaz_Latn', + 'nus_Latn-nya_Latn', 'nus_Latn-run_Latn', 'nus_Latn-sna_Latn', + 'nus_Latn-som_Latn', 'nus_Latn-sot_Latn', 'nus_Latn-ssw_Latn', + 'nus_Latn-swh_Latn', 'nus_Latn-tir_Ethi', 'nus_Latn-tsn_Latn', + 'nus_Latn-tso_Latn', 'nus_Latn-tum_Latn', 'nus_Latn-twi_Latn', + 'nus_Latn-umb_Latn', 'nus_Latn-wol_Latn', 'nus_Latn-xho_Latn', + 'nus_Latn-yor_Latn', 'nus_Latn-zul_Latn', 'nya_Latn-gaz_Latn', + 'nya_Latn-run_Latn', 'nya_Latn-sna_Latn', 'nya_Latn-som_Latn', + 'nya_Latn-sot_Latn', 'nya_Latn-ssw_Latn', 'nya_Latn-swh_Latn', + 'nya_Latn-tir_Ethi', 'nya_Latn-tsn_Latn', 'nya_Latn-tso_Latn', + 'nya_Latn-tum_Latn', 'nya_Latn-twi_Latn', 'nya_Latn-umb_Latn', + 'nya_Latn-wol_Latn', 'nya_Latn-xho_Latn', 'nya_Latn-yor_Latn', + 'nya_Latn-zul_Latn', 'oci_Latn-por_Latn', 'ory_Orya-pan_Guru', + 'ory_Orya-san_Deva', 'ory_Orya-sat_Beng', 'ory_Orya-sin_Sinh', + 'ory_Orya-snd_Arab', 'ory_Orya-tam_Taml', 'ory_Orya-tel_Telu', + 'ory_Orya-urd_Arab', 'pag_Latn-smo_Latn', 'pag_Latn-sun_Latn', + 'pan_Guru-san_Deva', 'pan_Guru-sat_Beng', 'pan_Guru-sin_Sinh', + 'pan_Guru-snd_Arab', 'pan_Guru-tam_Taml', 'pan_Guru-tel_Telu', + 'pan_Guru-urd_Arab', 'pbt_Arab-tam_Taml', 'pbt_Arab-tgk_Cyrl', + 'plt_Latn-mri_Latn', 'plt_Latn-pag_Latn', 'plt_Latn-smo_Latn', + 'plt_Latn-sun_Latn', 'plt_Latn-war_Latn', 'por_Latn-ayr_Latn', + 'por_Latn-quy_Latn', 'prs_Arab-pbt_Arab', 'prs_Arab-tgk_Cyrl', + 'quy_Latn-spa_Latn', 'run_Latn-sna_Latn', 'run_Latn-som_Latn', + 'run_Latn-sot_Latn', 'run_Latn-ssw_Latn', 'run_Latn-swh_Latn', + 'run_Latn-tir_Ethi', 'run_Latn-tsn_Latn', 'run_Latn-tso_Latn', + 'run_Latn-tum_Latn', 'run_Latn-twi_Latn', 'run_Latn-umb_Latn', + 'run_Latn-wol_Latn', 'run_Latn-xho_Latn', 'run_Latn-yor_Latn', + 'run_Latn-zul_Latn', 'rus_Cyrl-tat_Cyrl', 'rus_Cyrl-tgk_Cyrl', + 'san_Deva-sat_Beng', 'san_Deva-sin_Sinh', 'san_Deva-snd_Arab', + 'san_Deva-tam_Taml', 'san_Deva-tel_Telu', 'san_Deva-urd_Arab', + 'sat_Beng-sin_Sinh', 'sat_Beng-snd_Arab', 'sat_Beng-tam_Taml', + 'sat_Beng-tel_Telu', 'sat_Beng-urd_Arab', 'sin_Sinh-snd_Arab', + 'sin_Sinh-tam_Taml', 'sin_Sinh-tel_Telu', 'sin_Sinh-urd_Arab', + 'smo_Latn-sun_Latn', 'smo_Latn-war_Latn', 'sna_Latn-som_Latn', + 'sna_Latn-sot_Latn', 'sna_Latn-ssw_Latn', 'sna_Latn-swh_Latn', + 'sna_Latn-tir_Ethi', 'sna_Latn-tsn_Latn', 'sna_Latn-tso_Latn', + 'sna_Latn-tum_Latn', 'sna_Latn-twi_Latn', 'sna_Latn-umb_Latn', + 'sna_Latn-wol_Latn', 'sna_Latn-xho_Latn', 'sna_Latn-yor_Latn', + 'sna_Latn-zul_Latn', 'snd_Arab-tam_Taml', 'snd_Arab-tel_Telu', + 'snd_Arab-urd_Arab', 'som_Latn-sot_Latn', 'som_Latn-ssw_Latn', + 'som_Latn-swh_Latn', 'som_Latn-tir_Ethi', 'som_Latn-tsn_Latn', + 'som_Latn-tso_Latn', 'som_Latn-tum_Latn', 'som_Latn-twi_Latn', + 'som_Latn-umb_Latn', 'som_Latn-wol_Latn', 'som_Latn-xho_Latn', + 'som_Latn-yor_Latn', 'som_Latn-zul_Latn', 'sot_Latn-ssw_Latn', + 'sot_Latn-swh_Latn', 'sot_Latn-tir_Ethi', 'sot_Latn-tsn_Latn', + 'sot_Latn-tso_Latn', 'sot_Latn-tum_Latn', 'sot_Latn-twi_Latn', + 'sot_Latn-umb_Latn', 'sot_Latn-wol_Latn', 'sot_Latn-xho_Latn', + 'sot_Latn-yor_Latn', 'sot_Latn-zul_Latn', 'ssw_Latn-swh_Latn', + 'ssw_Latn-tir_Ethi', 'ssw_Latn-tsn_Latn', 'ssw_Latn-tso_Latn', + 'ssw_Latn-tum_Latn', 'ssw_Latn-twi_Latn', 'ssw_Latn-umb_Latn', + 'ssw_Latn-wol_Latn', 'ssw_Latn-xho_Latn', 'ssw_Latn-yor_Latn', + 'ssw_Latn-zul_Latn', 'sun_Latn-war_Latn', 'swh_Latn-tir_Ethi', + 'swh_Latn-tsn_Latn', 'swh_Latn-tso_Latn', 'swh_Latn-tum_Latn', + 'swh_Latn-twi_Latn', 'swh_Latn-umb_Latn', 'swh_Latn-wol_Latn', + 'swh_Latn-xho_Latn', 'swh_Latn-yor_Latn', 'swh_Latn-zul_Latn', + 'tam_Taml-tel_Telu', 'tam_Taml-urd_Arab', 'tat_Cyrl-tuk_Latn', + 'tat_Cyrl-uig_Arab', 'tat_Cyrl-uzn_Latn', 'tel_Telu-urd_Arab', + 'tir_Ethi-tsn_Latn', 'tir_Ethi-tso_Latn', 'tir_Ethi-tum_Latn', + 'tir_Ethi-twi_Latn', 'tir_Ethi-umb_Latn', 'tir_Ethi-wol_Latn', + 'tir_Ethi-xho_Latn', 'tir_Ethi-yor_Latn', 'tir_Ethi-zul_Latn', + 'tsn_Latn-tso_Latn', 'tsn_Latn-tum_Latn', 'tsn_Latn-twi_Latn', + 'tsn_Latn-umb_Latn', 'tsn_Latn-wol_Latn', 'tsn_Latn-xho_Latn', + 'tsn_Latn-yor_Latn', 'tsn_Latn-zul_Latn', 'tso_Latn-tum_Latn', + 'tso_Latn-twi_Latn', 'tso_Latn-umb_Latn', 'tso_Latn-wol_Latn', + 'tso_Latn-xho_Latn', 'tso_Latn-yor_Latn', 'tso_Latn-zul_Latn', + 'tuk_Latn-uig_Arab', 'tuk_Latn-uzn_Latn', 'tum_Latn-twi_Latn', + 'tum_Latn-umb_Latn', 'tum_Latn-wol_Latn', 'tum_Latn-xho_Latn', + 'tum_Latn-yor_Latn', 'tum_Latn-zul_Latn', 'twi_Latn-umb_Latn', + 'twi_Latn-wol_Latn', 'twi_Latn-xho_Latn', 'twi_Latn-yor_Latn', + 'twi_Latn-zul_Latn', 'uig_Arab-uzn_Latn', 'umb_Latn-wol_Latn', + 'umb_Latn-xho_Latn', 'umb_Latn-yor_Latn', 'umb_Latn-zul_Latn', + 'wol_Latn-xho_Latn', 'wol_Latn-yor_Latn', 'wol_Latn-zul_Latn', + 'xho_Latn-yor_Latn', 'xho_Latn-zul_Latn', 'yor_Latn-zul_Latn' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('nllb', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_universal_dependencies(self): + subset = [ + 'af_afribooms', 'akk_pisandub', 'akk_riao', 'aqz_tudet', 'sq_tsa', + 'am_att', 'grc_perseus', 'grc_proiel', 'apu_ufpa', 'ar_nyuad', + 'ar_padt', 'ar_pud', 'hy_armtdp', 'aii_as', 'bm_crb', 'eu_bdt', + 'be_hse', 'bho_bhtb', 'br_keb', 'bg_btb', 'bxr_bdt', 'yue_hk', + 'ca_ancora', 'zh_cfl', 'zh_gsd', 'zh_gsdsimp', 'zh_hk', 'zh_pud', + 'ckt_hse', 'lzh_kyoto', 'cop_scriptorium', 'hr_set', 'cs_cac', + 'cs_cltt', 'cs_fictree', 'cs_pdt', 'cs_pud', 'da_ddt', 'nl_alpino', + 'nl_lassysmall', 'en_esl', 'en_ewt', 'en_gum', 'en_gumreddit', + 'en_lines', 'en_partut', 'en_pronouns', 'en_pud', 'myv_jr', + 'et_edt', 'et_ewt', 'fo_farpahc', 'fo_oft', 'fi_ftb', 'fi_ood', + 'fi_pud', 'fi_tdt', 'fr_fqb', 'fr_ftb', 'fr_gsd', 'fr_partut', + 'fr_pud', 'fr_sequoia', 'fr_spoken', 'gl_ctg', 'gl_treegal', + 'de_gsd', 'de_hdt', 'de_lit', 'de_pud', 'got_proiel', 'el_gdt', + 'he_htb', 'qhe_hiencs', 'hi_hdtb', 'hi_pud', 'hu_szeged', + 'is_icepahc', 'is_pud', 'id_csui', 'id_gsd', 'id_pud', 'ga_idt', + 'it_isdt', 'it_partut', 'it_postwita', 'it_pud', 'it_twittiro', + 'it_vit', 'ja_bccwj', 'ja_gsd', 'ja_modern', 'ja_pud', 'krl_kkpp', + 'kk_ktb', 'kfm_aha', 'koi_uh', 'kpv_ikdp', 'kpv_lattice', 'ko_gsd', + 'ko_kaist', 'ko_pud', 'kmr_mg', 'la_ittb', 'la_llct', 'la_perseus', + 'la_proiel', 'lv_lvtb', 'lt_alksnis', 'lt_hse', 'olo_kkpp', + 'mt_mudt', 'gv_cadhan', 'mr_ufal', 'gun_dooley', 'gun_thomas', + 'mdf_jr', 'myu_tudet', 'pcm_nsc', 'nyq_aha', 'sme_giella', + 'no_bokmaal', 'no_nynorsk', 'no_nynorsklia', 'cu_proiel', + 'fro_srcmf', 'orv_rnc', 'orv_torot', 'otk_tonqq', 'fa_perdt', + 'fa_seraji', 'pl_lfg', 'pl_pdb', 'pl_pud', 'pt_bosque', 'pt_gsd', + 'pt_pud', 'ro_nonstandard', 'ro_rrt', 'ro_simonero', 'ru_gsd', + 'ru_pud', 'ru_syntagrus', 'ru_taiga', 'sa_ufal', 'sa_vedic', + 'gd_arcosg', 'sr_set', 'sms_giellagas', 'sk_snk', 'sl_ssj', + 'sl_sst', 'soj_aha', 'ajp_madar', 'es_ancora', 'es_gsd', 'es_pud', + 'swl_sslc', 'sv_lines', 'sv_pud', 'sv_talbanken', 'gsw_uzh', + 'tl_trg', 'tl_ugnayan', 'ta_mwtt', 'ta_ttb', 'te_mtg', 'th_pud', + 'tpn_tudet', 'qtd_sagt', 'tr_boun', 'tr_gb', 'tr_imst', 'tr_pud', + 'uk_iu', 'hsb_ufal', 'ur_udtb', 'ug_udt', 'vi_vtb', 'wbp_ufal', + 'cy_ccg', 'wo_wtb', 'yo_ytb' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('universal_dependencies', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_imdb(self): + dataset = MsDataset.load('imdb') + if isinstance(dataset, MsDataset): + lens = len(dataset) + print(f'dataset imdb len: {lens}') + self.assertTrue(lens > 0) + else: + assert isinstance(dataset, dict) + lens = {key: len(subset) for key, subset in dataset.items()} + print(f'dataset imdb len: {lens}') + self.assertTrue(all([_len > 0 for _len in lens.values()])) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_clue(self): + subset = [ + 'afqmc', 'tnews', 'iflytek', 'cmnli', 'cluewsc2020', 'csl', + 'cmrc2018', 'drcd', 'chid', 'c3', 'ocnli', 'diagnostics' + ] + for subset_name in subset: + self.download_subset('clue', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_wikitext(self): + subset = [ + 'wikitext-103-v1', 'wikitext-2-v1', 'wikitext-103-raw-v1', + 'wikitext-2-raw-v1' + ] + for subset_name in subset: + self.download_subset('wikitext', subset_name) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_download_xnli(self): + subset = [ + 'XNLI', 'tydiqa', 'SQuAD', 'PAN-X.af', 'PAN-X.ar', 'PAN-X.bg', + 'PAN-X.bn', 'PAN-X.de', 'PAN-X.el', 'PAN-X.en', 'PAN-X.es', + 'PAN-X.et', 'PAN-X.eu', 'PAN-X.fa', 'PAN-X.fi', 'PAN-X.fr', + 'PAN-X.he', 'PAN-X.hi', 'PAN-X.hu', 'PAN-X.id', 'PAN-X.it', + 'PAN-X.ja', 'PAN-X.jv', 'PAN-X.ka', 'PAN-X.kk', 'PAN-X.ko', + 'PAN-X.ml', 'PAN-X.mr', 'PAN-X.ms', 'PAN-X.my', 'PAN-X.nl', + 'PAN-X.pt', 'PAN-X.ru', 'PAN-X.sw', 'PAN-X.ta', 'PAN-X.te', + 'PAN-X.th', 'PAN-X.tl', 'PAN-X.tr', 'PAN-X.ur', 'PAN-X.vi', + 'PAN-X.yo', 'PAN-X.zh', 'MLQA.ar.ar', 'MLQA.ar.de', 'MLQA.ar.vi', + 'MLQA.ar.zh', 'MLQA.ar.en', 'MLQA.ar.es', 'MLQA.ar.hi', + 'MLQA.de.ar', 'MLQA.de.de', 'MLQA.de.vi', 'MLQA.de.zh', + 'MLQA.de.en', 'MLQA.de.es', 'MLQA.de.hi', 'MLQA.vi.ar', + 'MLQA.vi.de', 'MLQA.vi.vi', 'MLQA.vi.zh', 'MLQA.vi.en', + 'MLQA.vi.es', 'MLQA.vi.hi', 'MLQA.zh.ar', 'MLQA.zh.de', + 'MLQA.zh.vi', 'MLQA.zh.zh', 'MLQA.zh.en', 'MLQA.zh.es', + 'MLQA.zh.hi', 'MLQA.en.ar', 'MLQA.en.de', 'MLQA.en.vi', + 'MLQA.en.zh', 'MLQA.en.en', 'MLQA.en.es', 'MLQA.en.hi', + 'MLQA.es.ar', 'MLQA.es.de', 'MLQA.es.vi', 'MLQA.es.zh', + 'MLQA.es.en', 'MLQA.es.es', 'MLQA.es.hi', 'MLQA.hi.ar', + 'MLQA.hi.de', 'MLQA.hi.vi', 'MLQA.hi.zh', 'MLQA.hi.en', + 'MLQA.hi.es', 'MLQA.hi.hi', 'XQuAD.ar', 'XQuAD.de', 'XQuAD.vi', + 'XQuAD.zh', 'XQuAD.en', 'XQuAD.es', 'XQuAD.hi', 'XQuAD.el', + 'XQuAD.ru', 'XQuAD.th', 'XQuAD.tr', 'bucc18.de', 'bucc18.fr', + 'bucc18.zh', 'bucc18.ru', 'PAWS-X.de', 'PAWS-X.en', 'PAWS-X.es', + 'PAWS-X.fr', 'PAWS-X.ja', 'PAWS-X.ko', 'PAWS-X.zh', 'tatoeba.afr', + 'tatoeba.ara', 'tatoeba.ben', 'tatoeba.bul', 'tatoeba.deu', + 'tatoeba.cmn', 'tatoeba.ell', 'tatoeba.est', 'tatoeba.eus', + 'tatoeba.fin', 'tatoeba.fra', 'tatoeba.heb', 'tatoeba.hin', + 'tatoeba.hun', 'tatoeba.ind', 'tatoeba.ita', 'tatoeba.jav', + 'tatoeba.jpn', 'tatoeba.kat', 'tatoeba.kaz', 'tatoeba.kor', + 'tatoeba.mal', 'tatoeba.mar', 'tatoeba.nld', 'tatoeba.pes', + 'tatoeba.por', 'tatoeba.rus', 'tatoeba.spa', 'tatoeba.swh', + 'tatoeba.tam', 'tatoeba.tel', 'tatoeba.tgl', 'tatoeba.tha', + 'tatoeba.tur', 'tatoeba.urd', 'tatoeba.vie', 'udpos.Afrikaans', + 'udpos.Arabic', 'udpos.Basque', 'udpos.Bulgarian', 'udpos.Dutch', + 'udpos.English', 'udpos.Estonian', 'udpos.Finnish', 'udpos.French', + 'udpos.German', 'udpos.Greek', 'udpos.Hebrew', 'udpos.Hindi', + 'udpos.Hungarian', 'udpos.Indonesian', 'udpos.Italian', + 'udpos.Japanese', 'udpos.Kazakh', 'udpos.Korean', 'udpos.Chinese', + 'udpos.Marathi', 'udpos.Persian', 'udpos.Portuguese', + 'udpos.Russian', 'udpos.Spanish', 'udpos.Tagalog', 'udpos.Tamil', + 'udpos.Telugu', 'udpos.Thai', 'udpos.Turkish', 'udpos.Urdu', + 'udpos.Vietnamese', 'udpos.Yoruba' + ] + subset = subset[:self.subset_count] + for subset_name in subset: + self.download_subset('xtreme', subset_name) diff --git a/tests/hub/test_hub_examples.py b/tests/hub/test_hub_examples.py new file mode 100644 index 00000000..d1f7594e --- /dev/null +++ b/tests/hub/test_hub_examples.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.hub.api import HubApi +from modelscope.utils.hub import create_model_if_not_exist + +# note this is temporary before official account management is ready +YOUR_ACCESS_TOKEN = 'token' + + +class HubExampleTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(YOUR_ACCESS_TOKEN) + + @unittest.skip('to be used for local test only') + def test_example_model_creation(self): + # ATTENTION:change to proper model names before use + model_name = 'cv_unet_person-image-cartoon_compound-models' + model_chinese_name = '达摩卡通化模型' + model_org = 'damo' + model_id = '%s/%s' % (model_org, model_name) + created = create_model_if_not_exist(self.api, model_id, + model_chinese_name) + if not created: + print('!! NOT created since model already exists !!') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py new file mode 100644 index 00000000..5b6e957d --- /dev/null +++ b/tests/hub/test_hub_operation.py @@ -0,0 +1,151 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid +from shutil import rmtree + +import requests + +from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ModelFile +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG) + +DEFAULT_GIT_PATH = 'git' + +download_model_file_name = 'test.bin' + + +class HubOperationTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'op-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def prepare_case(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + repo.push('add model') + repo.tag_and_push(self.revision, 'Test revision') + + def test_model_repo_creation(self): + # change to proper model names before use + try: + info = self.api.get_model(model_id=self.model_id) + assert info['Name'] == self.model_name + except KeyError as ke: + if ke.args[0] == 'name': + print(f'model {self.model_name} already exists, ignore') + else: + raise + + def test_download_single_file(self): + self.prepare_case() + downloaded_file = model_file_download( + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision) + assert os.path.exists(downloaded_file) + mdtime1 = os.path.getmtime(downloaded_file) + # download again + downloaded_file = model_file_download( + model_id=self.model_id, file_path=download_model_file_name) + mdtime2 = os.path.getmtime(downloaded_file) + assert mdtime1 == mdtime2 + + def test_snapshot_download(self): + self.prepare_case() + snapshot_path = snapshot_download(model_id=self.model_id) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + mdtime1 = os.path.getmtime(downloaded_file_path) + # download again + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) + mdtime2 = os.path.getmtime(downloaded_file_path) + assert mdtime1 == mdtime2 + + def test_download_public_without_login(self): + try: + self.prepare_case() + rmtree(ModelScopeConfig.path_credential) + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + temporary_dir = tempfile.mkdtemp() + downloaded_file = model_file_download( + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision, + cache_dir=temporary_dir) + assert os.path.exists(downloaded_file) + finally: + self.api.login(TEST_ACCESS_TOKEN1) + + def test_snapshot_delete_download_cache_file(self): + self.prepare_case() + snapshot_path = snapshot_download( + model_id=self.model_id, revision=self.revision) + downloaded_file_path = os.path.join(snapshot_path, + download_model_file_name) + assert os.path.exists(downloaded_file_path) + os.remove(downloaded_file_path) + # download again in cache + file_download_path = model_file_download( + model_id=self.model_id, + file_path=ModelFile.README, + revision=self.revision) + assert os.path.exists(file_download_path) + # deleted file need download again + file_download_path = model_file_download( + model_id=self.model_id, + file_path=download_model_file_name, + revision=self.revision) + assert os.path.exists(file_download_path) + + def test_snapshot_download_default_revision(self): + pass # TOTO + + def test_file_download_default_revision(self): + pass # TODO + + def get_model_download_times(self): + url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads' + cookies = ModelScopeConfig.get_cookies() + r = requests.get(url, cookies=cookies) + if r.status_code == 200: + return r.json()['Data']['Downloads'] + else: + r.raise_for_status() + return None + + def test_list_model(self): + data = self.api.list_models(TEST_MODEL_ORG) + assert len(data['Models']) >= 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_private_files.py b/tests/hub/test_hub_private_files.py new file mode 100644 index 00000000..73c4cca3 --- /dev/null +++ b/tests/hub/test_hub_private_files.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid + +from requests.exceptions import HTTPError + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import GitError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ModelFile +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, + TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, + delete_credential) + +download_model_file_name = 'test.bin' + + +class HubPrivateFileDownloadTest(unittest.TestCase): + + def setUp(self): + self.old_cwd = os.getcwd() + self.api = HubApi() + self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'pf-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PRIVATE, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + + def prepare_case(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + repo.push('add model') + repo.tag_and_push(self.revision, 'Test revision') + + def tearDown(self): + # credential may deleted or switch login name, we need re-login here + # to ensure the temporary model is deleted. + self.api.login(TEST_ACCESS_TOKEN1) + os.chdir(self.old_cwd) + self.api.delete_model(model_id=self.model_id) + + def test_snapshot_download_private_model(self): + self.prepare_case() + snapshot_path = snapshot_download(self.model_id, self.revision) + assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) + + def test_snapshot_download_private_model_no_permission(self): + self.prepare_case() + self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) + with self.assertRaises(HTTPError): + snapshot_download(self.model_id, self.revision) + + def test_snapshot_download_private_model_without_login(self): + self.prepare_case() + delete_credential() + with self.assertRaises(HTTPError): + snapshot_download(self.model_id, self.revision) + + def test_download_file_private_model(self): + self.prepare_case() + file_path = model_file_download(self.model_id, ModelFile.README, + self.revision) + assert os.path.exists(file_path) + + def test_download_file_private_model_no_permission(self): + self.prepare_case() + self.token, _ = self.api.login(TEST_ACCESS_TOKEN2) + with self.assertRaises(HTTPError): + model_file_download(self.model_id, ModelFile.README, self.revision) + + def test_download_file_private_model_without_login(self): + self.prepare_case() + delete_credential() + with self.assertRaises(HTTPError): + model_file_download(self.model_id, ModelFile.README, self.revision) + + def test_snapshot_download_local_only(self): + self.prepare_case() + with self.assertRaises(ValueError): + snapshot_download( + self.model_id, self.revision, local_files_only=True) + snapshot_path = snapshot_download(self.model_id, self.revision) + assert os.path.exists(os.path.join(snapshot_path, ModelFile.README)) + snapshot_path = snapshot_download( + self.model_id, self.revision, local_files_only=True) + assert os.path.exists(snapshot_path) + + def test_file_download_local_only(self): + self.prepare_case() + with self.assertRaises(ValueError): + model_file_download( + self.model_id, + ModelFile.README, + self.revision, + local_files_only=True) + file_path = model_file_download(self.model_id, ModelFile.README, + self.revision) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + ModelFile.README, + revision=self.revision, + local_files_only=True) + assert os.path.exists(file_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_private_repository.py b/tests/hub/test_hub_private_repository.py new file mode 100644 index 00000000..271a715c --- /dev/null +++ b/tests/hub/test_hub_private_repository.py @@ -0,0 +1,82 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import GitError +from modelscope.hub.repository import Repository +from modelscope.utils.constant import ModelFile +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2, + TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG, + delete_credential) + +DEFAULT_GIT_PATH = 'git' + + +class HubPrivateRepositoryTest(unittest.TestCase): + + def setUp(self): + self.old_cwd = os.getcwd() + self.api = HubApi() + self.token, _ = self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'pr-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PRIVATE, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + + def tearDown(self): + self.api.login(TEST_ACCESS_TOKEN1) + os.chdir(self.old_cwd) + self.api.delete_model(model_id=self.model_id) + + def test_clone_private_repo_no_permission(self): + token, _ = self.api.login(TEST_ACCESS_TOKEN2) + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + with self.assertRaises(GitError) as cm: + Repository(local_dir, clone_from=self.model_id, auth_token=token) + + print(cm.exception) + assert not os.path.exists(os.path.join(local_dir, ModelFile.README)) + + def test_clone_private_repo_has_permission(self): + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + Repository(local_dir, clone_from=self.model_id, auth_token=self.token) + assert os.path.exists(os.path.join(local_dir, ModelFile.README)) + + def test_initlize_repo_multiple_times(self): + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + repo1 = Repository( + local_dir, clone_from=self.model_id, auth_token=self.token) + print(repo1.model_dir) + assert os.path.exists(os.path.join(local_dir, ModelFile.README)) + repo2 = Repository( + local_dir, clone_from=self.model_id, + auth_token=self.token) # skip clone + print(repo2.model_dir) + assert repo1.model_dir == repo2.model_dir + + def test_clone_private_model_without_token(self): + delete_credential() + temporary_dir = tempfile.mkdtemp() + local_dir = os.path.join(temporary_dir, self.model_name) + with self.assertRaises(GitError) as cm: + Repository(local_dir, clone_from=self.model_id) + + print(cm.exception) + assert not os.path.exists(os.path.join(local_dir, ModelFile.README)) + + self.api.login(TEST_ACCESS_TOKEN1) # re-login for delete + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_repository.py b/tests/hub/test_hub_repository.py new file mode 100644 index 00000000..850d5840 --- /dev/null +++ b/tests/hub/test_hub_repository.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import uuid +from os.path import expanduser + +from requests import delete + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.git import GitCommandWrapper +from modelscope.hub.repository import Repository +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG, delete_credential) + +logger = get_logger() +logger.setLevel('DEBUG') +DEFAULT_GIT_PATH = 'git' +download_model_file_name = 'test.bin' + + +class HubRepositoryTest(unittest.TestCase): + + def setUp(self): + self.old_cwd = os.getcwd() + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'repo-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + + def tearDown(self): + os.chdir(self.old_cwd) + self.api.delete_model(model_id=self.model_id) + + def test_clone_repo(self): + Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) + + def test_clone_public_model_without_token(self): + delete_credential() + Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) + self.api.login(TEST_ACCESS_TOKEN1) # re-login for delete + + def test_push_all(self): + repo = Repository(self.model_dir, clone_from=self.model_id) + assert os.path.exists(os.path.join(self.model_dir, ModelFile.README)) + os.chdir(self.model_dir) + lfs_file1 = 'test1.bin' + lfs_file2 = 'test2.bin' + os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py')) + os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py')) + os.system("echo 'lfs'>%s" % os.path.join(self.model_dir, lfs_file1)) + os.system("echo 'lfs2'>%s" % os.path.join(self.model_dir, lfs_file2)) + repo.push('test') + repo.tag_and_push(self.revision, 'Test revision') + add1 = model_file_download(self.model_id, 'add1.py', self.revision) + assert os.path.exists(add1) + add2 = model_file_download(self.model_id, 'add2.py', self.revision) + assert os.path.exists(add2) + # check lfs files. + git_wrapper = GitCommandWrapper() + lfs_files = git_wrapper.list_lfs_files(self.model_dir) + assert lfs_file1 in lfs_files + assert lfs_file2 in lfs_files + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_revision.py b/tests/hub/test_hub_revision.py new file mode 100644 index 00000000..13ec1c9a --- /dev/null +++ b/tests/hub/test_hub_revision.py @@ -0,0 +1,145 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import unittest +import uuid +from datetime import datetime + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import NotExistError, NoValidRevisionError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG) + +logger = get_logger() +logger.setLevel('DEBUG') +download_model_file_name = 'test.bin' +download_model_file_name2 = 'test2.bin' + + +class HubRevisionTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'rv-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.revision2 = 'v0.2_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def prepare_repo_data(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + self.repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + self.repo.push('add model') + self.repo.tag_and_push(self.revision, 'Test revision') + + def test_no_tag(self): + with self.assertRaises(NoValidRevisionError): + snapshot_download(self.model_id, None) + + with self.assertRaises(NoValidRevisionError): + model_file_download(self.model_id, ModelFile.README) + + def test_with_only_one_tag(self): + self.prepare_repo_data() + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, ModelFile.README, cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def add_new_file_and_tag(self): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file') + self.repo.tag_and_push(self.revision2, 'Test revision') + + def test_snapshot_download_different_revision(self): + self.prepare_repo_data() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + snapshot_path = snapshot_download(self.model_id, self.revision) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + self.add_new_file_and_tag() + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + + def test_file_download_different_revision(self): + self.prepare_repo_data() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + file_path = model_file_download(self.model_id, + download_model_file_name, + self.revision) + assert os.path.exists(file_path) + self.add_new_file_and_tag() + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision, + cache_dir=temp_cache_dir) + + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision2, + cache_dir=temp_cache_dir) + print('Downloaded file path: %s' % file_path) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_revision_release_mode.py b/tests/hub/test_hub_revision_release_mode.py new file mode 100644 index 00000000..73a0625e --- /dev/null +++ b/tests/hub/test_hub_revision_release_mode.py @@ -0,0 +1,270 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import tempfile +import time +import unittest +import uuid +from datetime import datetime +from unittest import mock + +from modelscope import version +from modelscope.hub.api import HubApi +from modelscope.hub.constants import (MODELSCOPE_SDK_DEBUG, Licenses, + ModelVisibility) +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from modelscope.hub.repository import Repository +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.logger import get_logger +from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME, + TEST_MODEL_ORG) + +logger = get_logger() +logger.setLevel('DEBUG') +download_model_file_name = 'test.bin' +download_model_file_name2 = 'test2.bin' + + +class HubRevisionTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.model_name = 'rvr-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + self.revision = 'v0.1_test_revision' + self.revision2 = 'v0.2_test_revision' + self.api.create_model( + model_id=self.model_id, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2, + chinese_name=TEST_MODEL_CHINESE_NAME, + ) + names_to_remove = {MODELSCOPE_SDK_DEBUG} + self.modified_environ = { + k: v + for k, v in os.environ.items() if k not in names_to_remove + } + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + + def prepare_repo_data(self): + temporary_dir = tempfile.mkdtemp() + self.model_dir = os.path.join(temporary_dir, self.model_name) + self.repo = Repository(self.model_dir, clone_from=self.model_id) + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name)) + self.repo.push('add model') + + def prepare_repo_data_and_tag(self): + self.prepare_repo_data() + self.repo.tag_and_push(self.revision, 'Test revision') + + def add_new_file_and_tag_to_repo(self): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file') + self.repo.tag_and_push(self.revision2, 'Test revision') + + def add_new_file_and_branch_to_repo(self, branch_name): + os.system("echo 'testtest'>%s" + % os.path.join(self.model_dir, download_model_file_name2)) + self.repo.push('add new file', remote_branch=branch_name) + + def test_dev_mode_default_master(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data() # no tag, default get master + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def test_dev_mode_specify_branch(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data() # no tag, default get master + branch_name = 'test' + self.add_new_file_and_branch_to_repo(branch_name) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=branch_name, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=branch_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + + def test_snapshot_download_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + # set + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + logger.info('Setting __release_datetime__ to: %s' % t1) + version.__release_datetime__ = t1 + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + version.__release_datetime__ = t2 + logger.info('Setting __release_datetime__ to: %s' % t2) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + finally: + version.__release_datetime__ = release_datetime_backup + + def test_snapshot_download_revision_user_set_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Secnod time: %s' % t2) + # set + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + logger.info('Setting __release_datetime__ to: %s' % t1) + version.__release_datetime__ = t1 + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + finally: + version.__release_datetime__ = release_datetime_backup + + def test_file_download_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + version.__release_datetime__ = t1 + logger.info('Setting __release_datetime__ to: %s' % t1) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + cache_dir=temp_cache_dir) + version.__release_datetime__ = t2 + logger.info('Setting __release_datetime__ to: %s' % t2) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + finally: + version.__release_datetime__ = release_datetime_backup + + def test_file_download_revision_user_set_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + version.__release_datetime__ = t1 + logger.info('Setting __release_datetime__ to: %s' % t1) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision, + cache_dir=temp_cache_dir) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + finally: + version.__release_datetime__ = release_datetime_backup + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_hub_upload.py b/tests/hub/test_hub_upload.py new file mode 100644 index 00000000..835aa62b --- /dev/null +++ b/tests/hub/test_hub_upload.py @@ -0,0 +1,156 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import uuid + +from modelscope.hub.api import HubApi +from modelscope.hub.constants import Licenses, ModelVisibility +from modelscope.hub.errors import GitError, HTTPError, NotLoginException +from modelscope.hub.repository import Repository +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level +from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential + +logger = get_logger() + + +class HubUploadTest(unittest.TestCase): + + def setUp(self): + logger.info('SetUp') + self.api = HubApi() + self.user = TEST_MODEL_ORG + logger.info(self.user) + self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload', + uuid.uuid4().hex) + logger.info('create %s' % self.create_model_name) + temporary_dir = tempfile.mkdtemp() + self.work_dir = temporary_dir + self.model_dir = os.path.join(temporary_dir, self.create_model_name) + self.finetune_path = os.path.join(self.work_dir, 'finetune_path') + self.repo_path = os.path.join(self.work_dir, 'repo_path') + os.mkdir(self.finetune_path) + os.system("echo '{}'>%s" + % os.path.join(self.finetune_path, ModelFile.CONFIGURATION)) + + def tearDown(self): + logger.info('TearDown') + shutil.rmtree(self.model_dir, ignore_errors=True) + try: + self.api.delete_model(model_id=self.create_model_name) + except Exception: + pass + + def test_upload_exits_repo_master(self): + logger.info('basic test for upload!') + self.api.login(TEST_ACCESS_TOKEN1) + self.api.create_model( + model_id=self.create_model_name, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + os.system("echo '111'>%s" + % os.path.join(self.finetune_path, 'add1.py')) + self.api.push_model( + model_id=self.create_model_name, model_dir=self.finetune_path) + Repository(model_dir=self.repo_path, clone_from=self.create_model_name) + assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + os.system("echo '222'>%s" + % os.path.join(self.finetune_path, 'add2.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + os.system("echo '333'>%s" + % os.path.join(self.finetune_path, 'add3.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version2', + commit_message='add add3.py') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version2') + assert os.path.exists(os.path.join(self.repo_path, 'add2.py')) + assert os.path.exists(os.path.join(self.repo_path, 'add3.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + add4_path = os.path.join(self.finetune_path, 'temp') + os.mkdir(add4_path) + os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert os.path.exists(os.path.join(add4_path, 'add4.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + assert os.path.exists(os.path.join(self.finetune_path, 'add3.py')) + os.remove(os.path.join(self.finetune_path, 'add3.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_revision/version1') + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_revision/version1') + assert not os.path.exists(os.path.join(self.repo_path, 'add3.py')) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_upload_non_exists_repo(self): + logger.info('test upload non exists repo!') + self.api.login(TEST_ACCESS_TOKEN1) + os.system("echo '111'>%s" + % os.path.join(self.finetune_path, 'add1.py')) + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + revision='new_model_new_revision', + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + Repository( + model_dir=self.repo_path, + clone_from=self.create_model_name, + revision='new_model_new_revision') + assert os.path.exists(os.path.join(self.repo_path, 'add1.py')) + shutil.rmtree(self.repo_path, ignore_errors=True) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_upload_without_token(self): + logger.info('test upload without login!') + self.api.login(TEST_ACCESS_TOKEN1) + delete_credential() + with self.assertRaises(NotLoginException): + self.api.push_model( + model_id=self.create_model_name, + model_dir=self.finetune_path, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_upload_invalid_repo(self): + logger.info('test upload to invalid repo!') + self.api.login(TEST_ACCESS_TOKEN1) + with self.assertRaises((HTTPError, GitError)): + self.api.push_model( + model_id='%s/%s' % ('speech_tts', 'invalid_model_test'), + model_dir=self.finetune_path, + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/hub/test_utils.py b/tests/hub/test_utils.py new file mode 100644 index 00000000..3d312dc0 --- /dev/null +++ b/tests/hub/test_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +from codecs import ignore_errors +from os.path import expanduser + +from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH + +# for user citest and sdkdev +TEST_ACCESS_TOKEN1 = os.environ['TEST_ACCESS_TOKEN_CITEST'] +TEST_ACCESS_TOKEN2 = os.environ['TEST_ACCESS_TOKEN_SDKDEV'] + +TEST_MODEL_CHINESE_NAME = '内部测试模型' +TEST_MODEL_ORG = 'citest' + + +def delete_credential(): + path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) + shutil.rmtree(path_credential, ignore_errors=True) diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/metrics/test_text_classification_metrics.py b/tests/metrics/test_text_classification_metrics.py new file mode 100644 index 00000000..d0a4cee1 --- /dev/null +++ b/tests/metrics/test_text_classification_metrics.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.metrics.sequence_classification_metric import \ + SequenceClassificationMetric +from modelscope.utils.test_utils import test_level + + +class TestTextClsMetrics(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_value(self): + metric = SequenceClassificationMetric() + outputs = { + 'logits': + np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], + [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], + [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]) + } + inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])} + metric.add(outputs, inputs) + ret = metric.evaluate() + self.assertTrue(np.isclose(ret['f1'], 0.5)) + self.assertTrue(np.isclose(ret['accuracy'], 0.5)) + print(ret) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/metrics/test_token_classification_metrics.py b/tests/metrics/test_token_classification_metrics.py new file mode 100644 index 00000000..b249b227 --- /dev/null +++ b/tests/metrics/test_token_classification_metrics.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.metrics.token_classification_metric import \ + TokenClassificationMetric +from modelscope.utils.test_utils import test_level + + +class TestTokenClsMetrics(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_value(self): + metric = TokenClassificationMetric() + + class Trainer: + pass + + metric.trainer = Trainer() + metric.trainer.label2id = { + 'B-obj': 0, + 'I-obj': 1, + 'O': 2, + } + + outputs = { + 'logits': + np.array([[[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0], + [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7], + [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]]]) + } + inputs = {'labels': np.array([[0, 1, 2, 2, 0, 1, 2, 2]])} + metric.add(outputs, inputs) + ret = metric.evaluate() + self.assertTrue(np.isclose(ret['precision'], 0.25)) + self.assertTrue(np.isclose(ret['recall'], 0.5)) + self.assertTrue(np.isclose(ret['accuracy'], 0.5)) + print(ret) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/test_base_torch.py b/tests/models/test_base_torch.py new file mode 100644 index 00000000..c147259b --- /dev/null +++ b/tests/models/test_base_torch.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.base import TorchModel + + +class TorchBaseTest(unittest.TestCase): + + def test_custom_model(self): + + class MyTorchModel(TorchModel): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, input): + x = F.relu(self.conv1(input)) + return F.relu(self.conv2(x)) + + model = MyTorchModel() + model.train() + model.eval() + out = model.forward(torch.rand(1, 1, 10, 10)) + self.assertEqual((1, 20, 2, 2), out.shape) + + def test_custom_model_with_postprocess(self): + add_bias = 200 + + class MyTorchModel(TorchModel): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, input): + x = F.relu(self.conv1(input)) + return F.relu(self.conv2(x)) + + def postprocess(self, x): + return x + add_bias + + model = MyTorchModel() + model.train() + model.eval() + out = model(torch.rand(1, 1, 10, 10)) + self.assertEqual((1, 20, 2, 2), out.shape) + self.assertTrue(np.all(out.detach().numpy() > (add_bias - 10))) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/models/test_deberta_v2_backbone.py b/tests/models/test_deberta_v2_backbone.py new file mode 100644 index 00000000..706b18f8 --- /dev/null +++ b/tests/models/test_deberta_v2_backbone.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.models import Model +from modelscope.models.nlp.deberta_v2 import (DebertaV2ForMaskedLM, + DebertaV2Model) +from modelscope.utils.constant import Tasks + + +class DebertaV2BackboneTest(unittest.TestCase): + + def test_load_model(self): + model = Model.from_pretrained( + 'damo/nlp_debertav2_fill-mask_chinese-lite') + self.assertTrue(model.__class__ == DebertaV2ForMaskedLM) + model = Model.from_pretrained( + 'damo/nlp_debertav2_fill-mask_chinese-lite', task=Tasks.backbone) + self.assertTrue(model.__class__ == DebertaV2Model) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/__init__.py b/tests/msdatasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/msdatasets/test_dataset_delete.py b/tests/msdatasets/test_dataset_delete.py new file mode 100644 index 00000000..8b3c2426 --- /dev/null +++ b/tests/msdatasets/test_dataset_delete.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.msdatasets import MsDataset +from modelscope.utils import logger as logging +from modelscope.utils.test_utils import test_level + +logger = logging.get_logger(__name__) + +KEY_EXTRACTED = 'extracted' +EXPECTED_MSG = 'success' + + +class DatasetDeleteTest(unittest.TestCase): + + def setUp(self): + self.old_dir = os.getcwd() + self.dataset_name = 'small_coco_for_test' + self.dataset_file_name = self.dataset_name + self.prepared_dataset_name = 'pets_small' + self.token = os.getenv('TEST_UPLOAD_MS_TOKEN') + error_msg = 'The modelscope token can not be empty, please set env variable: TEST_UPLOAD_MS_TOKEN' + self.assertIsNotNone(self.token, msg=error_msg) + from modelscope.hub.api import HubApi + from modelscope.hub.api import ModelScopeConfig + self.api = HubApi() + self.api.login(self.token) + + # get user info + self.namespace, _ = ModelScopeConfig.get_user_info() + + self.temp_dir = tempfile.mkdtemp() + self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name) + if not os.path.exists(self.test_work_dir): + os.makedirs(self.test_work_dir) + + def tearDown(self): + os.chdir(self.old_dir) + shutil.rmtree(self.temp_dir, ignore_errors=True) + logger.info( + f'Temporary directory {self.temp_dir} successfully removed!') + + @staticmethod + def get_raw_downloaded_file_path(extracted_path): + raw_downloaded_file_path = '' + raw_data_dir = os.path.abspath( + os.path.join(extracted_path, '../../..')) + for root, dirs, files in os.walk(raw_data_dir): + if KEY_EXTRACTED in dirs: + for file in files: + curr_file_path = os.path.join(root, file) + if zipfile.is_zipfile(curr_file_path): + raw_downloaded_file_path = curr_file_path + return raw_downloaded_file_path + + def upload_test_file(self): + # Get the prepared data from hub, using default modelscope namespace + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_res = ms_ds_train._hf_ds.config_kwargs + extracted_path = config_res.get('split_config').get('train') + raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path) + + object_name = self.dataset_file_name + '_for_del.zip' + MsDataset.upload( + object_name=object_name, + local_file_path=raw_zipfile_path, + dataset_name=self.dataset_name, + namespace=self.namespace) + + return object_name + + def upload_test_dir(self): + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_train = ms_ds_train._hf_ds.config_kwargs + extracted_path_train = config_train.get('split_config').get('train') + + object_name = 'train_for_del' + MsDataset.upload( + object_name=object_name, + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/train'), + dataset_name=self.dataset_name, + namespace=self.namespace) + + return object_name + '/' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_delete_object(self): + + # upload prepared data + file_name = self.upload_test_file() + dir_name = self.upload_test_dir() + + # delete object + del_file_msg = MsDataset.delete( + object_name=file_name, + dataset_name=self.dataset_name, + namespace=self.namespace) + del_dir_msg = MsDataset.delete( + object_name=dir_name, + dataset_name=self.dataset_name, + namespace=self.namespace) + + assert all([del_file_msg == EXPECTED_MSG, del_dir_msg == EXPECTED_MSG]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_dataset_upload.py b/tests/msdatasets/test_dataset_upload.py new file mode 100644 index 00000000..d91f24d7 --- /dev/null +++ b/tests/msdatasets/test_dataset_upload.py @@ -0,0 +1,137 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects +from modelscope.utils import logger as logging +from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode, + ModelFile) +from modelscope.utils.test_utils import test_level + +logger = logging.get_logger(__name__) + +KEY_EXTRACTED = 'extracted' + + +class DatasetUploadTest(unittest.TestCase): + + def setUp(self): + self.old_dir = os.getcwd() + self.dataset_name = 'small_coco_for_test' + self.dataset_file_name = self.dataset_name + self.prepared_dataset_name = 'pets_small' + self.token = os.getenv('TEST_UPLOAD_MS_TOKEN') + error_msg = 'The modelscope token can not be empty, please set env variable: TEST_UPLOAD_MS_TOKEN' + self.assertIsNotNone(self.token, msg=error_msg) + from modelscope.hub.api import HubApi + from modelscope.hub.api import ModelScopeConfig + self.api = HubApi() + self.api.login(self.token) + + # get user info + self.namespace, _ = ModelScopeConfig.get_user_info() + + self.temp_dir = tempfile.mkdtemp() + self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name) + self.test_meta_dir = os.path.join(self.test_work_dir, 'meta') + if not os.path.exists(self.test_work_dir): + os.makedirs(self.test_work_dir) + + def tearDown(self): + os.chdir(self.old_dir) + shutil.rmtree(self.temp_dir, ignore_errors=True) + logger.info( + f'Temporary directory {self.temp_dir} successfully removed!') + + @staticmethod + def get_raw_downloaded_file_path(extracted_path): + raw_downloaded_file_path = '' + raw_data_dir = os.path.abspath( + os.path.join(extracted_path, '../../..')) + for root, dirs, files in os.walk(raw_data_dir): + if KEY_EXTRACTED in dirs: + for file in files: + curr_file_path = os.path.join(root, file) + if zipfile.is_zipfile(curr_file_path): + raw_downloaded_file_path = curr_file_path + return raw_downloaded_file_path + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_upload(self): + # Get the prepared data from hub, using default modelscope namespace + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_res = ms_ds_train._hf_ds.config_kwargs + extracted_path = config_res.get('split_config').get('train') + raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path) + + MsDataset.upload( + object_name=self.dataset_file_name + '.zip', + local_file_path=raw_zipfile_path, + dataset_name=self.dataset_name, + namespace=self.namespace) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_upload_dir(self): + ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train') + config_train = ms_ds_train._hf_ds.config_kwargs + extracted_path_train = config_train.get('split_config').get('train') + + MsDataset.upload( + object_name='train', + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/train'), + dataset_name=self.dataset_name, + namespace=self.namespace) + MsDataset.upload( + object_name='val', + local_file_path=os.path.join(extracted_path_train, + 'Pets/images/val'), + dataset_name=self.dataset_name, + namespace=self.namespace) + + objects = list_dataset_objects( + hub_api=self.api, + max_limit=-1, + is_recursive=True, + dataset_name=self.dataset_name, + namespace=self.namespace, + version=DEFAULT_DATASET_REVISION) + + logger.info(f'{len(objects)} objects have been uploaded: {objects}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_download_dir(self): + test_ds = MsDataset.load( + self.dataset_name, + namespace=self.namespace, + download_mode=DownloadMode.FORCE_REDOWNLOAD) + assert test_ds.config_kwargs['split_config'].values() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_clone_meta(self): + MsDataset.clone_meta( + dataset_work_dir=self.test_meta_dir, + dataset_id=os.path.join(self.namespace, self.dataset_name)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_upload_meta(self): + # Clone dataset meta repo first. + MsDataset.clone_meta( + dataset_work_dir=self.test_meta_dir, + dataset_id=os.path.join(self.namespace, self.dataset_name)) + + with open(os.path.join(self.test_meta_dir, ModelFile.README), + 'a') as f: + f.write('\nThis is a line for unit test.') + + MsDataset.upload_meta( + dataset_work_dir=self.test_meta_dir, + commit_message='Update for unit test.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py new file mode 100644 index 00000000..dff411f6 --- /dev/null +++ b/tests/msdatasets/test_ms_dataset.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.preprocessors.base import Preprocessor +from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode +from modelscope.utils.test_utils import require_tf, require_torch, test_level + + +class ImgPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.path_field = kwargs.pop('image_path', 'image_path') + self.width = kwargs.pop('width', 'width') + self.height = kwargs.pop('height', 'width') + + def __call__(self, data): + import cv2 + image_path = data.get(self.path_field) + if not image_path: + return None + img = cv2.imread(image_path) + return { + 'image': + cv2.resize(img, + (data.get(self.height, 128), data.get(self.width, 128))) + } + + +class MsDatasetTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_movie_scene_seg_toydata(self): + ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train') + print(ms_ds_train._hf_ds.config_kwargs) + assert next(iter(ms_ds_train.config_kwargs['split_config'].values())) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_coco(self): + ms_ds_train = MsDataset.load( + 'pets_small', + namespace=DEFAULT_DATASET_NAMESPACE, + download_mode=DownloadMode.FORCE_REDOWNLOAD, + split='train') + print(ms_ds_train.config_kwargs) + assert next(iter(ms_ds_train.config_kwargs['split_config'].values())) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ms_csv_basic(self): + ms_ds_train = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(5)) + print(next(iter(ms_ds_train))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ds_basic(self): + ms_ds_full = MsDataset.load( + 'xcopa', subset_name='translation-et', namespace='damotest') + ms_ds = MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') + print(next(iter(ms_ds_full['test']))) + print(next(iter(ms_ds))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_torch + def test_to_torch_dataset_text(self): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + nlp_model = Model.from_pretrained(model_id) + preprocessor = SequenceClassificationPreprocessor( + nlp_model.model_dir, + first_sequence='premise', + second_sequence=None, + padding='max_length') + ms_ds_train = MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') + pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor) + import torch + dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) + print(next(iter(dataloader))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @require_tf + def test_to_tf_dataset_text(self): + import tensorflow as tf + tf.compat.v1.enable_eager_execution() + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + nlp_model = Model.from_pretrained(model_id) + preprocessor = SequenceClassificationPreprocessor( + nlp_model.model_dir, + first_sequence='premise', + second_sequence=None) + ms_ds_train = MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') + tf_dataset = ms_ds_train.to_tf_dataset( + batch_size=5, + shuffle=True, + preprocessors=preprocessor, + drop_remainder=True) + print(next(iter(tf_dataset))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @require_torch + def test_to_torch_dataset_img(self): + ms_image_train = MsDataset.load( + 'fixtures_image_utils', namespace='damotest', split='test') + pt_dataset = ms_image_train.to_torch_dataset( + preprocessors=ImgPreprocessor(image_path='file')) + import torch + dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) + print(next(iter(dataloader))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @require_tf + def test_to_tf_dataset_img(self): + import tensorflow as tf + tf.compat.v1.enable_eager_execution() + ms_image_train = MsDataset.load( + 'fixtures_image_utils', namespace='damotest', split='test') + tf_dataset = ms_image_train.to_tf_dataset( + batch_size=5, + shuffle=True, + preprocessors=ImgPreprocessor(image_path='file'), + drop_remainder=True, + ) + print(next(iter(tf_dataset))) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/outputs/__init__.py b/tests/outputs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/outputs/test_model_outputs.py b/tests/outputs/test_model_outputs.py new file mode 100644 index 00000000..311ce201 --- /dev/null +++ b/tests/outputs/test_model_outputs.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import torch + +from modelscope.outputs import TextClassificationModelOutput +from modelscope.utils.test_utils import test_level + + +class TestModelOutput(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_model_outputs(self): + outputs = TextClassificationModelOutput(logits=torch.Tensor([1])) + self.assertEqual(outputs['logits'], torch.Tensor([1])) + self.assertEqual(outputs[0], torch.Tensor([1])) + self.assertEqual(outputs.logits, torch.Tensor([1])) + outputs.loss = torch.Tensor([2]) + logits, loss = outputs + self.assertEqual(logits, torch.Tensor([1])) + self.assertTrue(loss is not None) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/__init__.py b/tests/pipelines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/easycv_pipelines/__init__.py b/tests/pipelines/easycv_pipelines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py b/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py new file mode 100644 index 00000000..5f6dac4b --- /dev/null +++ b/tests/pipelines/easycv_pipelines/test_segmentation_pipeline.py @@ -0,0 +1,88 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest +from distutils.version import LooseVersion + +import cv2 +import easycv +import numpy as np +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class EasyCVSegmentationPipelineTest(unittest.TestCase, + DemoCompatibilityCheck): + img_path = 'data/test/images/image_segmentation.jpg' + + def setUp(self) -> None: + self.task = Tasks.image_segmentation + self.model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' + + def _internal_test_(self, model_id): + semantic_seg = pipeline(task=Tasks.image_segmentation, model=model_id) + outputs = semantic_seg(self.img_path) + + draw_img = semantic_seg_masks_to_image(outputs[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test ' + model_id + ' DONE') + + def _internal_test_batch_(self, model_id, num_samples=2, batch_size=2): + # TODO: support in the future + img = np.asarray(Image.open(self.img_path)) + num_samples = num_samples + batch_size = batch_size + semantic_seg = pipeline( + task=Tasks.image_segmentation, + model=model_id, + batch_size=batch_size) + outputs = semantic_seg([self.img_path] * num_samples) + + self.assertEqual(semantic_seg.predict_op.batch_size, batch_size) + self.assertEqual(len(outputs), num_samples) + + for output in outputs: + self.assertListEqual( + list(img.shape)[:2], list(output['seg_pred'].shape)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b0(self): + model_id = 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b1(self): + model_id = 'damo/cv_segformer-b1_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b2(self): + model_id = 'damo/cv_segformer-b2_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b3(self): + model_id = 'damo/cv_segformer-b3_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b4(self): + model_id = 'damo/cv_segformer-b4_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_segformer_b5(self): + model_id = 'damo/cv_segformer-b5_image_semantic-segmentation_coco-stuff164k' + self._internal_test_(model_id) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/nlp/test_faq.py b/tests/pipelines/nlp/test_faq.py new file mode 100644 index 00000000..8bac55d4 --- /dev/null +++ b/tests/pipelines/nlp/test_faq.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForFaqRanking, SbertForFaqRetrieval +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FaqPipeline +from modelscope.preprocessors import FaqPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class FaqTest(unittest.TestCase): + model_id = '/Users/tanfan/Desktop/Workdir/Gitlab/maas/MaaS-lib/.faq_test_model' + param = { + 'query_set': ['明天星期几', '今天星期六', '今天星期六'], + 'support_set': [{ + 'text': '今天星期六', + 'label': 'label0' + }, { + 'text': '明天星期几', + 'label': 'label1' + }] + } + + # @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + # def test_run_with_direct_file_download(self): + # cache_path = self.model_id # snapshot_download(self.model_id) + # preprocessor = FaqPreprocessor(cache_path) + # model = SbertForFaq(cache_path) + # pipeline_ins = FaqPipeline(model, preprocessor=preprocessor) + # + # result = pipeline_ins(self.param) + # print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = FaqPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.faq, model=model, preprocessor=preprocessor) + result = pipeline_ins(self.param) + print(result) + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + # def test_run_with_model_name(self): + # pipeline_ins = pipeline(task=Tasks.faq, model=self.model_id) + # result = pipeline_ins(self.param) + # print(result) + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + # def test_run_with_default_model(self): + # pipeline_ins = pipeline(task=Tasks.faq) + # print(pipeline_ins(self.param)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_action_detection.py b/tests/pipelines/test_action_detection.py new file mode 100644 index 00000000..ae7e60b1 --- /dev/null +++ b/tests/pipelines/test_action_detection.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ActionDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.action_detection + self.model_id = 'damo/cv_ResNetC3D_action-detection_detection2d' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + action_detection_pipline = pipeline(self.task, model=self.model_id) + result = action_detection_pipline( + 'data/test/videos/action_detection_test_video.mp4') + print('action detection results:', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_action_recognition.py b/tests/pipelines/test_action_recognition.py new file mode 100644 index 00000000..292eb238 --- /dev/null +++ b/tests/pipelines/test_action_recognition.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ActionRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.action_recognition + self.model_id = 'damo/cv_TAdaConv_action-recognition' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + recognition_pipeline = pipeline(self.task, self.model_id) + result = recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'recognition output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + recognition_pipeline = pipeline(self.task) + result = recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'recognition output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pst(self): + pst_recognition_pipeline = pipeline( + self.task, model='damo/cv_pathshift_action-recognition') + result = pst_recognition_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + print('pst recognition results:', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_animal_recognition.py b/tests/pipelines/test_animal_recognition.py new file mode 100644 index 00000000..eb9f92e6 --- /dev/null +++ b/tests/pipelines/test_animal_recognition.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class AnimalRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.animal_recognition + self.model_id = 'damo/cv_resnest101_animal_recognition' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + animal_recognition = pipeline( + Tasks.animal_recognition, model=self.model_id) + result = animal_recognition('data/test/images/dogs.jpg') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_automatic_post_editing.py b/tests/pipelines/test_automatic_post_editing.py new file mode 100644 index 00000000..da09851c --- /dev/null +++ b/tests/pipelines/test_automatic_post_editing.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class AutomaticPostEditingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.translation + self.model_id = 'damo/nlp_automatic_post_editing_for_translation_en2de' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2de(self): + inputs = 'Simultaneously, the Legion took part to the pacification of Algeria, plagued by various tribal ' \ + 'rebellions and razzias.\005Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von ' \ + 'verschiedenen Stammesaufständen und Rasias heimgesucht wurde.' + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py new file mode 100644 index 00000000..b6532868 --- /dev/null +++ b/tests/pipelines/test_automatic_speech_recognition.py @@ -0,0 +1,403 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import unittest +from typing import Any, Dict, Union + +import numpy as np +import soundfile + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ColorCodes, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import download_and_untar, test_level + +logger = get_logger() + +WAV_FILE = 'data/test/audios/asr_example.wav' +URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' + +LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz' +LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz' + +TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz' +TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' + + +class AutomaticSpeechRecognitionTest(unittest.TestCase, + DemoCompatibilityCheck): + action_info = { + 'test_run_with_wav_pytorch': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_pcm_pytorch': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_wav_tf': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_pcm_tf': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_url_pytorch': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_url_tf': { + 'checking_item': OutputKeys.TEXT, + 'example': 'wav_example' + }, + 'test_run_with_wav_dataset_pytorch': { + 'checking_item': OutputKeys.TEXT, + 'example': 'dataset_example' + }, + 'test_run_with_wav_dataset_tf': { + 'checking_item': OutputKeys.TEXT, + 'example': 'dataset_example' + }, + 'dataset_example': { + 'Wrd': 49532, # the number of words + 'Snt': 5000, # the number of sentences + 'Corr': 47276, # the number of correct words + 'Ins': 49, # the number of insert words + 'Del': 152, # the number of delete words + 'Sub': 2207, # the number of substitution words + 'wrong_words': 2408, # the number of wrong words + 'wrong_sentences': 1598, # the number of wrong sentences + 'Err': 4.86, # WER/CER + 'S.Err': 31.96 # SER + }, + 'wav_example': { + 'text': '每一天都要快乐喔' + } + } + + all_models_info = [ + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': 'damo/speech_paraformer_asr_nat-aishell1-pytorch', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': 'damo/speech_paraformer_asr_nat-aishell2-pytorch', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_cn_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_cn_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example.wav' + }, + { + 'model_id': + 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_8K.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_en.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ru.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ru.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_es.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_es.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ko.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ko.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_ja.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_ja.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', + 'wav_path': 'data/test/audios/asr_example_id.wav' + }, + { + 'model_id': + 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', + 'wav_path': 'data/test/audios/asr_example_id.wav' + }, + ] + + def setUp(self) -> None: + self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch' + self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1' + # this temporary workspace dir will store waveform files + self.workspace = os.path.join(os.getcwd(), '.tmp') + self.task = Tasks.auto_speech_recognition + if not os.path.exists(self.workspace): + os.mkdir(self.workspace) + + def tearDown(self) -> None: + # remove workspace dir (.tmp) + shutil.rmtree(self.workspace, ignore_errors=True) + + def run_pipeline(self, + model_id: str, + audio_in: Union[str, bytes], + sr: int = None) -> Dict[str, Any]: + inference_16k_pipline = pipeline( + task=Tasks.auto_speech_recognition, model=model_id) + + rec_result = inference_16k_pipline(audio_in, audio_fs=sr) + + return rec_result + + def log_error(self, functions: str, result: Dict[str, Any]) -> None: + logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' + + ColorCodes.END) + logger.error( + ColorCodes.MAGENTA + functions + ' correct result example:' + + ColorCodes.YELLOW + + str(self.action_info[self.action_info[functions]['example']]) + + ColorCodes.END) + + raise ValueError('asr result is mismatched') + + def check_result(self, functions: str, result: Dict[str, Any]) -> None: + if result.__contains__(self.action_info[functions]['checking_item']): + logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' + + ColorCodes.END) + logger.info( + ColorCodes.YELLOW + + str(result[self.action_info[functions]['checking_item']]) + + ColorCodes.END) + else: + self.log_error(functions, result) + + def wav2bytes(self, wav_file): + audio, fs = soundfile.read(wav_file) + + # float32 -> int16 + audio = np.asarray(audio) + dtype = np.dtype('int16') + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) + + # int16(PCM_16) -> byte + audio = audio.tobytes() + return audio, fs + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_pcm(self): + """run with wav data + """ + + logger.info('Run ASR test with wav data (tensorflow)...') + + audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE)) + + rec_result = self.run_pipeline( + model_id=self.am_tf_model_id, audio_in=audio, sr=sr) + self.check_result('test_run_with_pcm_tf', rec_result) + + logger.info('Run ASR test with wav data (pytorch)...') + + rec_result = self.run_pipeline( + model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr) + self.check_result('test_run_with_pcm_pytorch', rec_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + """run with single waveform file + """ + + logger.info('Run ASR test with waveform file (tensorflow)...') + + wav_file_path = os.path.join(os.getcwd(), WAV_FILE) + + rec_result = self.run_pipeline( + model_id=self.am_tf_model_id, audio_in=wav_file_path) + self.check_result('test_run_with_wav_tf', rec_result) + + logger.info('Run ASR test with waveform file (pytorch)...') + + rec_result = self.run_pipeline( + model_id=self.am_pytorch_model_id, audio_in=wav_file_path) + self.check_result('test_run_with_wav_pytorch', rec_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_url(self): + """run with single url file + """ + + logger.info('Run ASR test with url file (tensorflow)...') + + rec_result = self.run_pipeline( + model_id=self.am_tf_model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url_tf', rec_result) + + logger.info('Run ASR test with url file (pytorch)...') + + rec_result = self.run_pipeline( + model_id=self.am_pytorch_model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url_pytorch', rec_result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_wav_dataset_pytorch(self): + """run with datasets, and audio format is waveform + datasets directory: + + wav + test # testsets + xx.wav + ... + dev # devsets + yy.wav + ... + train # trainsets + zz.wav + ... + transcript + data.text # hypothesis text + """ + + logger.info('Downloading waveform testsets file ...') + + dataset_path = download_and_untar( + os.path.join(self.workspace, LITTLE_TESTSETS_FILE), + LITTLE_TESTSETS_URL, self.workspace) + dataset_path = os.path.join(dataset_path, 'wav', 'test') + + logger.info('Run ASR test with waveform dataset (tensorflow)...') + + rec_result = self.run_pipeline( + model_id=self.am_tf_model_id, audio_in=dataset_path) + self.check_result('test_run_with_wav_dataset_tf', rec_result) + + logger.info('Run ASR test with waveform dataset (pytorch)...') + + rec_result = self.run_pipeline( + model_id=self.am_pytorch_model_id, audio_in=dataset_path) + self.check_result('test_run_with_wav_dataset_pytorch', rec_result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_all_models(self): + """run with all models + """ + + logger.info('Run ASR test with all models') + + for item in self.all_models_info: + model_id = item['model_id'] + wav_path = item['wav_path'] + rec_result = self.run_pipeline( + model_id=model_id, audio_in=wav_path) + if rec_result.__contains__(OutputKeys.TEXT): + logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' ' + + ColorCodes.YELLOW + + str(rec_result[OutputKeys.TEXT]) + + ColorCodes.END) + else: + logger.info(ColorCodes.MAGENTA + str(rec_result) + + ColorCodes.END) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py new file mode 100644 index 00000000..b60813c8 --- /dev/null +++ b/tests/pipelines/test_base.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import Any, Dict, Union + +import numpy as np +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import Pipeline, pipeline +from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info +from modelscope.utils.logger import get_logger + +logger = get_logger() + +Input = Union[str, 'PIL.Image', 'numpy.ndarray'] + + +class CustomPipelineTest(unittest.TestCase): + + def test_abstract(self): + + @PIPELINES.register_module() + class CustomPipeline1(Pipeline): + + def __init__(self, + config_file: str = None, + model=None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + + with self.assertRaises(TypeError): + CustomPipeline1() + + def test_custom(self): + dummy_task = 'dummy-task' + + @PIPELINES.register_module( + group_key=dummy_task, module_name='custom-image') + class CustomImagePipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model=None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + + def preprocess(self, input: Union[str, + 'PIL.Image']) -> Dict[str, Any]: + """ Provide default implementation based on preprocess_cfg and user can reimplement it + + """ + if not isinstance(input, Image.Image): + from modelscope.preprocessors import load_image + data_dict = {'img': load_image(input), 'url': input} + else: + data_dict = {'img': input} + return data_dict + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ Provide default implementation using self.model and user can reimplement it + """ + outputs = {} + if 'url' in inputs: + outputs['filename'] = inputs['url'] + img = inputs['img'] + new_image = img.resize((img.width // 2, img.height // 2)) + outputs[OutputKeys.OUTPUT_IMG] = np.array(new_image) + return outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + self.assertTrue('custom-image' in PIPELINES.modules[dummy_task]) + add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) + pipe = pipeline(task=dummy_task, pipeline_name='custom-image') + pipe2 = pipeline(dummy_task) + self.assertTrue(type(pipe) is type(pipe2)) + + img_url = 'data/test/images/dogs.jpg' + output = pipe(img_url) + self.assertEqual(output['filename'], img_url) + self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) + + outputs = pipe([img_url for i in range(4)]) + self.assertEqual(len(outputs), 4) + for out in outputs: + self.assertEqual(out['filename'], img_url) + self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_body_2d_keypoints.py b/tests/pipelines/test_body_2d_keypoints.py new file mode 100644 index 00000000..5d90cbf0 --- /dev/null +++ b/tests/pipelines/test_body_2d_keypoints.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 +from PIL import Image + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_keypoints +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class Body2DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.body_2d_keypoints + self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' + self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' + + def pipeline_inference(self, pipeline: Pipeline, pipeline_input): + output = pipeline(pipeline_input) + image = draw_keypoints(output, self.test_image) + cv2.imwrite('pose_keypoint.jpg', image) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_with_image_file(self): + body_2d_keypoints = pipeline(self.task, model=self.model_id) + self.pipeline_inference(body_2d_keypoints, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_with_image_input(self): + body_2d_keypoints = pipeline(self.task, model=self.model_id) + self.pipeline_inference(body_2d_keypoints, Image.open(self.test_image)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py new file mode 100644 index 00000000..6e671d2e --- /dev/null +++ b/tests/pipelines/test_body_3d_keypoints.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_canonical_body-3d-keypoints_video' + self.test_video = 'data/test/videos/Walking.54138969.mp4' + self.task = Tasks.body_3d_keypoints + + def pipeline_inference(self, pipeline: Pipeline, pipeline_input): + output = pipeline(pipeline_input, output_video='./result.mp4') + poses = np.array(output[OutputKeys.KEYPOINTS]) + print(f'result 3d points shape {poses.shape}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_with_video_file(self): + body_3d_keypoints = pipeline( + Tasks.body_3d_keypoints, model=self.model_id) + pipeline_input = self.test_video + self.pipeline_inference( + body_3d_keypoints, pipeline_input=pipeline_input) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_with_video_stream(self): + body_3d_keypoints = pipeline(Tasks.body_3d_keypoints) + cap = cv2.VideoCapture(self.test_video) + if not cap.isOpened(): + raise Exception('modelscope error: %s cannot be decoded by OpenCV.' + % (self.test_video)) + pipeline_input = self.test_video + self.pipeline_inference( + body_3d_keypoints, pipeline_input=pipeline_input) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_builder.py b/tests/pipelines/test_builder.py new file mode 100644 index 00000000..6caa2cb1 --- /dev/null +++ b/tests/pipelines/test_builder.py @@ -0,0 +1,86 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import unittest +from typing import Any, Dict, List, Union + +from modelscope.fileio import io +from modelscope.models.base import Model +from modelscope.pipelines import Pipeline, pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import (ConfigFields, Frameworks, ModelFile, + Tasks) +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + group_key=Tasks.image_classification, module_name='custom_single_model') +class CustomSingleModelPipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model: List[Union[str, Model]] = None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + assert isinstance(model, str), 'model is not str' + print(model) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return super().postprocess(inputs) + + +@PIPELINES.register_module( + group_key=Tasks.image_classification, module_name='model1_model2') +class CustomMultiModelPipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model: List[Union[str, Model]] = None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + assert isinstance(model, list), 'model is not list' + for m in model: + assert isinstance(m, str), 'submodel is not str' + print(m) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return super().postprocess(inputs) + + +class PipelineInterfaceTest(unittest.TestCase): + + def prepare_dir(self, dirname, pipeline_name): + if not os.path.exists(dirname): + os.makedirs(dirname) + cfg_file = os.path.join(dirname, ModelFile.CONFIGURATION) + cfg = { + ConfigFields.framework: Frameworks.torch, + ConfigFields.task: Tasks.image_classification, + ConfigFields.pipeline: { + 'type': pipeline_name, + } + } + io.dump(cfg, cfg_file) + + def setUp(self) -> None: + self.prepare_dir('/tmp/custom_single_model', 'custom_single_model') + self.prepare_dir('/tmp/model1', 'model1_model2') + self.prepare_dir('/tmp/model2', 'model1_model2') + + def test_single_model(self): + pipe = pipeline( + Tasks.image_classification, model='/tmp/custom_single_model') + assert isinstance(pipe, CustomSingleModelPipeline) + + def test_multi_model(self): + pipe = pipeline( + Tasks.image_classification, model=['/tmp/model1', '/tmp/model2']) + assert isinstance(pipe, CustomMultiModelPipeline) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_card_detection.py b/tests/pipelines/test_card_detection.py new file mode 100644 index 00000000..d913f494 --- /dev/null +++ b/tests/pipelines/test_card_detection.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_card_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class CardDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.card_detection + self.model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + + def show_result(self, img_path, detection_result): + img_list = draw_card_detection_result(img_path, detection_result) + for i, img in enumerate(img_list): + if i == 0: + cv2.imwrite('result.jpg', img_list[0]) + print( + f'Found {len(img_list)-1} cards, output written to {osp.abspath("result.jpg")}' + ) + else: + cv2.imwrite(f'card_{i}.jpg', img_list[i]) + save_path = osp.abspath(f'card_{i}.jpg') + print(f'detect card_{i}: {save_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/card_detection.jpg'] + + dataset = MsDataset.load(input_location, target='image') + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. + result = card_detection(dataset) + result = next(result) + self.show_result(input_location[0], result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + card_detection = pipeline(Tasks.card_detection, model=self.model_id) + img_path = 'data/test/images/card_detection.jpg' + + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + card_detection = pipeline(Tasks.card_detection) + img_path = 'data/test/images/card_detection.jpg' + result = card_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_cmdssl_video_embedding.py b/tests/pipelines/test_cmdssl_video_embedding.py new file mode 100644 index 00000000..5807c075 --- /dev/null +++ b/tests/pipelines/test_cmdssl_video_embedding.py @@ -0,0 +1,31 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# !/usr/bin/env python +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class CMDSSLVideoEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_embedding + self.model_id = 'damo/cv_r2p1d_video_embedding' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + videossl_pipeline = pipeline(task=self.task, model=self.model_id) + result = videossl_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'video embedding output: {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_conversational_text_to_sql.py b/tests/pipelines/test_conversational_text_to_sql.py new file mode 100644 index 00000000..17fffcaf --- /dev/null +++ b/tests/pipelines/test_conversational_text_to_sql.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import StarForTextToSql +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline +from modelscope.preprocessors import ConversationalTextToSqlPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.nlp.space_T_en.utils import \ + text2sql_tracking_and_print_results +from modelscope.utils.test_utils import test_level + + +class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.table_question_answering + self.model_id = 'damo/nlp_star_conversational-text-to-sql' + + model_id = 'damo/nlp_star_conversational-text-to-sql' + test_case = { + 'database_id': + 'employee_hire_evaluation', + 'local_db_path': + None, + 'utterance': [ + "I'd like to see Shop names.", 'Which of these are hiring?', + 'Which shop is hiring the highest number of employees? | do you want the name of the shop ? | Yes' + ] + } + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=cache_path, + database_id=self.test_case['database_id'], + db_content=True) + model = StarForTextToSql( + model_dir=cache_path, config=preprocessor.config) + + pipelines = [ + ConversationalTextToSqlPipeline( + model=model, preprocessor=preprocessor), + pipeline(task=self.task, model=model, preprocessor=preprocessor) + ] + text2sql_tracking_and_print_results(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=model.model_dir) + + pipelines = [ + ConversationalTextToSqlPipeline( + model=model, preprocessor=preprocessor), + pipeline(task=self.task, model=model, preprocessor=preprocessor) + ] + text2sql_tracking_and_print_results(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [pipeline(task=self.task, model=self.model_id)] + text2sql_tracking_and_print_results(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_crowd_counting.py b/tests/pipelines/test_crowd_counting.py new file mode 100644 index 00000000..4e15cfca --- /dev/null +++ b/tests/pipelines/test_crowd_counting.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import numpy_to_cv2img +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class CrowdCountingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.input_location = 'data/test/images/crowd_counting.jpg' + self.model_id = 'damo/cv_hrnet_crowd-counting_dcanet' + self.task = Tasks.crowd_counting + + def save_result(self, result): + print('scores:', result[OutputKeys.SCORES]) + vis_img = result[OutputKeys.OUTPUT_IMG] + vis_img = numpy_to_cv2img(vis_img) + cv2.imwrite('result.jpg', vis_img) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_crowd_counting(self): + crowd_counting = pipeline(task=self.task, model=self.model_id) + result = crowd_counting(self.input_location) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_crowd_counting_with_image(self): + crowd_counting = pipeline(task=self.task, model=self.model_id) + img = Image.open(self.input_location) + result = crowd_counting(img) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_crowd_counting_with_default_task(self): + crowd_counting = pipeline(self.task) + result = crowd_counting(self.input_location) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_csanmt_translation.py b/tests/pipelines/test_csanmt_translation.py new file mode 100644 index 00000000..83827813 --- /dev/null +++ b/tests/pipelines/test_csanmt_translation.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TranslationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.translation + self.model_id = 'damo/nlp_csanmt_translation_zh2en' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_zh2en(self): + inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2zh(self): + model_id = 'damo/nlp_csanmt_translation_en2zh' + inputs = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2fr(self): + model_id = 'damo/nlp_csanmt_translation_en2fr' + inputs = 'When I was in my 20s, I saw my very first psychotherapy client.' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_fr2en(self): + model_id = 'damo/nlp_csanmt_translation_fr2en' + inputs = "Quand j'avais la vingtaine, j'ai vu mes tout premiers clients comme psychothérapeute." + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' + pipeline_ins = pipeline(self.task) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_deberta_tasks.py b/tests/pipelines/test_deberta_tasks.py new file mode 100644 index 00000000..549d2cb3 --- /dev/null +++ b/tests/pipelines/test_deberta_tasks.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import DebertaV2ForMaskedLM +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FillMaskPipeline +from modelscope.preprocessors import NLPPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class DeBERTaV2TaskTest(unittest.TestCase): + model_id_deberta = 'damo/nlp_debertav2_fill-mask_chinese-lite' + + ori_text = '你师父差得动你,你师父可差不动我。' + test_input = '你师父差得动你,你师父可[MASK]不动我。' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id_deberta) + preprocessor = NLPPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = DebertaV2ForMaskedLM.from_pretrained(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + ori_text = self.ori_text + test_input = self.test_input + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + # sbert + print(self.model_id_deberta) + model = Model.from_pretrained(self.model_id_deberta) + preprocessor = NLPPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=model, preprocessor=preprocessor) + print( + f'\nori_text: {self.ori_text}\ninput: {self.test_input}\npipeline: ' + f'{pipeline_ins(self.test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=self.model_id_deberta) + ori_text = self.ori_text + test_input = self.test_input + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_dialog_intent_prediction.py b/tests/pipelines/test_dialog_intent_prediction.py new file mode 100644 index 00000000..2ee46388 --- /dev/null +++ b/tests/pipelines/test_dialog_intent_prediction.py @@ -0,0 +1,77 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDialogIntent +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import DialogIntentPredictionPipeline +from modelscope.preprocessors import DialogIntentPredictionPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class DialogIntentPredictionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.task_oriented_conversation + self.model_id = 'damo/nlp_space_dialog-intent-prediction' + + test_case = [ + 'How do I locate my card?', + 'I still have not received my new card, I ordered over a week ago.' + ] + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) + model = SpaceForDialogIntent( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + + pipelines = [ + DialogIntentPredictionPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.task_oriented_conversation, + model=model, + preprocessor=preprocessor) + ] + + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogIntentPredictionPreprocessor( + model_dir=model.model_dir) + + pipelines = [ + DialogIntentPredictionPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.task_oriented_conversation, + model=model, + preprocessor=preprocessor) + ] + + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [pipeline(task=self.task, model=self.model_id)] + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_dialog_modeling.py b/tests/pipelines/test_dialog_modeling.py new file mode 100644 index 00000000..6b6259ce --- /dev/null +++ b/tests/pipelines/test_dialog_modeling.py @@ -0,0 +1,157 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest +from typing import List + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDialogModeling +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import DialogModelingPipeline +from modelscope.preprocessors import DialogModelingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class DialogModelingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.task_oriented_conversation + self.model_id = 'damo/nlp_space_dialog-modeling' + + test_case = { + 'sng0073': { + 'goal': { + 'taxi': { + 'info': { + 'leaveat': '17:15', + 'destination': 'pizza hut fen ditton', + 'departure': "saint john's college" + }, + 'reqt': ['car', 'phone'], + 'fail_info': {} + } + }, + 'log': [{ + 'user': + "i would like a taxi from saint john 's college to pizza hut fen ditton .", + 'user_delex': + 'i would like a taxi from [value_departure] to [value_destination] .', + 'resp': + 'what time do you want to leave and what time do you want to arrive by ?', + 'sys': + 'what time do you want to leave and what time do you want to arrive by ?', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college", + 'cons_delex': '[taxi] destination departure', + 'sys_act': '[taxi] [request] leave arrive', + 'turn_num': 0, + 'turn_domain': '[taxi]' + }, { + 'user': 'i want to leave after 17:15 .', + 'user_delex': 'i want to leave after [value_leave] .', + 'resp': + 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', + 'sys': + 'booking completed ! your taxi will be blue honda contact number is 07218068540', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[taxi] [inform] car phone', + 'turn_num': 1, + 'turn_domain': '[taxi]' + }, { + 'user': 'thank you for all the help ! i appreciate it .', + 'user_delex': 'thank you for all the help ! i appreciate it .', + 'resp': + 'you are welcome . is there anything else i can help you with today ?', + 'sys': + 'you are welcome . is there anything else i can help you with today ?', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[general] [reqmore]', + 'turn_num': 2, + 'turn_domain': '[general]' + }, { + 'user': 'no , i am all set . have a nice day . bye .', + 'user_delex': 'no , i am all set . have a nice day . bye .', + 'resp': 'you too ! thank you', + 'sys': 'you too ! thank you', + 'pointer': '0,0,0,0,0,0', + 'match': '', + 'constraint': + "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", + 'cons_delex': '[taxi] destination departure leave', + 'sys_act': '[general] [bye]', + 'turn_num': 3, + 'turn_domain': '[general]' + }] + } + } + + def generate_and_print_dialog_response( + self, pipelines: List[DialogModelingPipeline]): + + result = {} + pipeline_len = len(pipelines) + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) + + result = pipelines[step % pipeline_len]({ + 'user_input': user, + 'history': result + }) + print('response : {}'.format(result[OutputKeys.OUTPUT])) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + + cache_path = snapshot_download(self.model_id) + + preprocessor = DialogModelingPreprocessor(model_dir=cache_path) + model = SpaceForDialogModeling( + model_dir=cache_path, + text_field=preprocessor.text_field, + config=preprocessor.config) + pipelines = [ + DialogModelingPipeline(model=model, preprocessor=preprocessor) + ] + self.generate_and_print_dialog_response(pipelines) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) + + pipelines = [ + DialogModelingPipeline(model=model, preprocessor=preprocessor) + ] + + self.generate_and_print_dialog_response(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [pipeline(task=self.task, model=self.model_id)] + self.generate_and_print_dialog_response(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipelines = [pipeline(task=self.task)] + self.generate_and_print_dialog_response(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_dialog_state_tracking.py b/tests/pipelines/test_dialog_state_tracking.py new file mode 100644 index 00000000..6cdd5ee7 --- /dev/null +++ b/tests/pipelines/test_dialog_state_tracking.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SpaceForDST +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import DialogStateTrackingPipeline +from modelscope.preprocessors import DialogStateTrackingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.nlp.space.utils_dst import \ + tracking_and_print_dialog_states +from modelscope.utils.test_utils import test_level + + +class DialogStateTrackingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.task_oriented_conversation + self.model_id = 'damo/nlp_space_dialog-state-tracking' + + test_case = [{ + 'User-1': + 'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, ' + 'is there anything like that?' + }, { + 'System-1': + 'There are over 1,000 trains like that. Where will you be departing from?', + 'Dialog_Act-1': { + 'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']], + 'Train-Request': [['Depart', '?']] + }, + 'User-2': 'I am departing from birmingham new street.' + }, { + 'System-2': 'Can you confirm your desired travel day?', + 'Dialog_Act-2': { + 'Train-Request': [['Day', '?']] + }, + 'User-3': 'I would like to leave on wednesday' + }, { + 'System-3': + 'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. ' + 'Will this work for you?', + 'Dialog_Act-3': { + 'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'], + ['Day', 'Wednesday'], + ['Depart', 'birmingham new street']] + }, + 'User-4': + 'That will, yes. Please make a booking for 5 people please.', + }, { + 'System-4': + 'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.', + 'Dialog_Act-4': { + 'Train-OfferBooked': [['Ref', 'A9NHSO9Y']] + }, + 'User-5': + 'Thanks so much. I would also need a place to say. ' + 'I am looking for something with 4 stars and has free wifi.' + }, { + 'System-5': + 'How about the cambridge belfry? ' + 'It has all the attributes you requested and a great name! ' + 'Maybe even a real belfry?', + 'Dialog_Act-5': { + 'Hotel-Recommend': [['Name', 'the cambridge belfry']] + }, + 'User-6': + 'That sounds great, could you make a booking for me please?', + }, { + 'System-6': + 'What day would you like your booking for?', + 'Dialog_Act-6': { + 'Booking-Request': [['Day', '?']] + }, + 'User-7': + 'Please book it for Wednesday for 5 people and 5 nights, please.', + }, { + 'System-7': 'Booking was successful. Reference number is : 5NAWGJDC.', + 'Dialog_Act-7': { + 'Booking-Book': [['Ref', '5NAWGJDC']] + }, + 'User-8': 'Thank you, goodbye', + }] + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + + model = SpaceForDST.from_pretrained(cache_path) + preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) + pipelines = [ + DialogStateTrackingPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.task_oriented_conversation, + model=model, + preprocessor=preprocessor) + ] + tracking_and_print_dialog_states(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + + preprocessor = DialogStateTrackingPreprocessor( + model_dir=model.model_dir) + pipelines = [ + DialogStateTrackingPipeline( + model=model, preprocessor=preprocessor), + pipeline(task=self.task, model=model, preprocessor=preprocessor) + ] + + tracking_and_print_dialog_states(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [pipeline(task=self.task, model=self.model_id)] + tracking_and_print_dialog_states(self.test_case, pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_document_segmentation.py b/tests/pipelines/test_document_segmentation.py new file mode 100644 index 00000000..b4406fef --- /dev/null +++ b/tests/pipelines/test_document_segmentation.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import Any, Dict + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class DocumentSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.document_segmentation + self.model_id = 'damo/nlp_bert_document-segmentation_chinese-base' + + model_id = 'damo/nlp_bert_document-segmentation_chinese-base' + eng_model_id = 'damo/nlp_bert_document-segmentation_english-base' + sentences = '近年来,随着端到端语音识别的流行,基于Transformer结构的语音识别系统逐渐成为了主流。然而,由于Transformer是一种自回归模型,需要逐个生成目标文字,计算复杂度随着目标文字数量线性增加,限制了其在工业生产中的应用。针对Transoformer模型自回归生成文字的低计算效率缺陷,学术界提出了非自回归模型来并行的输出目标文字。根据生成目标文字时,迭代轮数,非自回归模型分为:多轮迭代式与单轮迭代非自回归模型。其中实用的是基于单轮迭代的非自回归模型。对于单轮非自回归模型,现有工作往往聚焦于如何更加准确的预测目标文字个数,如CTC-enhanced采用CTC预测输出文字个数,尽管如此,考虑到现实应用中,语速、口音、静音以及噪声等因素的影响,如何准确的预测目标文字个数以及抽取目标文字对应的声学隐变量仍然是一个比较大的挑战;另外一方面,我们通过对比自回归模型与单轮非自回归模型在工业大数据上的错误类型(如下图所示,AR与vanilla NAR),发现,相比于自回归模型,非自回归模型,在预测目标文字个数方面差距较小,但是替换错误显著的增加,我们认为这是由于单轮非自回归模型中条件独立假设导致的语义信息丢失。于此同时,目前非自回归模型主要停留在学术验证阶段,还没有工业大数据上的相关实验与结论。' # noqa * + sentences_1 = '移动端语音唤醒模型,检测关键词为“小云小云”。模型主体为4层FSMN结构,使用CTC训练准则,参数量750K,适用于移动端设备运行。模型输入为Fbank特征,输出为基于char建模的中文全集token预测,测试工具根据每一帧的预测数据进行后处理得到输入音频的实时检测结果。模型训练采用“basetrain + finetune”的模式,basetrain过程使用大量内部移动端数据,在此基础上,使用1万条设备端录制安静场景“小云小云”数据进行微调,得到最终面向业务的模型。后续用户可在basetrain模型基础上,使用其他关键词数据进行微调,得到新的语音唤醒模型,但暂时未开放模型finetune功能。' # noqa * + eng_sentences = 'The Saint Alexander Nevsky Church was established in 1936 by Archbishop Vitaly (Maximenko) () on a tract of land donated by Yulia Martinovna Plavskaya.The initial chapel, dedicated to the memory of the great prince St. Alexander Nevsky (1220–1263), was blessed in May, 1936.The church building was subsequently expanded three times.In 1987, ground was cleared for the construction of the new church and on September 12, 1989, on the Feast Day of St. Alexander Nevsky, the cornerstone was laid and the relics of St. Herman of Alaska placed in the foundation.The imposing edifice, completed in 1997, is the work of Nikolaus Karsanov, architect and Protopresbyter Valery Lukianov, engineer.Funds were raised through donations.The Great blessing of the cathedral took place on October 18, 1997 with seven bishops, headed by Metropolitan Vitaly Ustinov, and 36 priests and deacons officiating, some 800 faithful attended the festivity.The old church was rededicated to Our Lady of Tikhvin.Metropolitan Hilarion (Kapral) announced, that cathedral will officially become the episcopal See of the Ruling Bishop of the Eastern American Diocese and the administrative center of the Diocese on September 12, 2014.At present the parish serves the spiritual needs of 300 members.The parochial school instructs over 90 boys and girls in religion, Russian language and history.The school meets every Saturday.The choir is directed by Andrew Burbelo.The sisterhood attends to the needs of the church and a church council acts in the administration of the community.The cathedral is decorated by frescoes in the Byzantine style.The iconography project was fulfilled by Father Andrew Erastov and his students from 1995 until 2001.' # noqa * + + def run_pipeline(self, model_id: str, documents: str) -> Dict[str, Any]: + p = pipeline(task=self.task, model=model_id) + result = p(documents=documents) + return result + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_document(self): + logger.info('Run document segmentation with one document ...') + + result = self.run_pipeline( + model_id=self.model_id, documents=self.sentences) + print(result[OutputKeys.TEXT]) + + result = self.run_pipeline( + model_id=self.eng_model_id, documents=self.eng_sentences) + print(result[OutputKeys.TEXT]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_documents(self): + logger.info('Run document segmentation with many documents ...') + + result = self.run_pipeline( + model_id=self.model_id, + documents=[self.sentences, self.sentences_1]) + + documents_list = result[OutputKeys.TEXT] + for document in documents_list: + print(document) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_domain_classification.py b/tests/pipelines/test_domain_classification.py new file mode 100644 index 00000000..8e5bfa7f --- /dev/null +++ b/tests/pipelines/test_domain_classification.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class DomainClassificationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_classification + self.model_id = 'damo/nlp_domain_classification_chinese' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_zh_domain(self): + inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \ + '从而改善周边土质使之达到接地要求。' + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_zh_style(self): + model_id = 'damo/nlp_style_classification_chinese' + inputs = '通过这种方式产生的离子吸收大地水分之后,可以通过潮解作用,将活性电解离子有效释放到周围土壤中,使接地极成为一个离子发生装置,' \ + '从而改善周边土质使之达到接地要求。' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en_style(self): + model_id = 'damo/nlp_style_classification_english' + inputs = 'High Power 11.1V 5200mAh Lipo Battery For RC Car Robot Airplanes ' \ + 'Helicopter RC Drone Parts 3s Lithium battery 11.1v Battery' + pipeline_ins = pipeline(self.task, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_2d_keypoints.py b/tests/pipelines/test_face_2d_keypoints.py new file mode 100644 index 00000000..7ccc8a59 --- /dev/null +++ b/tests/pipelines/test_face_2d_keypoints.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_106face_keypoints +from modelscope.utils.test_utils import test_level + + +class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_face_2d_keypoints(self): + img_path = 'data/test/images/face_detection.png' + model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' + + face_2d_keypoints_align = pipeline( + task=Tasks.face_2d_keypoints, model=model_id) + output = face_2d_keypoints_align(img_path) + + output_keypoints = output[OutputKeys.KEYPOINTS] + output_poses = output[OutputKeys.POSES] + output_boxes = output[OutputKeys.BOXES] + + draw_106face_keypoints( + img_path, + output_keypoints, + output_boxes, + scale=2, + save_path='face_keypoints.jpg') + + for idx in range(len(output_keypoints)): + self.assertEqual(output_keypoints[idx].shape[0], 106) + self.assertEqual(output_keypoints[idx].shape[1], 2) + self.assertEqual(output_poses[idx].shape[0], 3) + self.assertEqual(output_boxes[idx].shape[0], 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_detection.py b/tests/pipelines/test_face_detection.py new file mode 100644 index 00000000..db513a80 --- /dev/null +++ b/tests/pipelines/test_face_detection.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.face_detection + self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/face_detection2.jpeg'] + + dataset = MsDataset.load(input_location, target='image') + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. + result = face_detection(dataset) + result = next(result) + self.show_result(input_location[0], result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/face_detection2.jpeg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_detection = pipeline(Tasks.face_detection) + img_path = 'data/test/images/face_detection2.jpeg' + result = face_detection(img_path) + self.show_result(img_path, result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_emotion.py b/tests/pipelines/test_face_emotion.py new file mode 100644 index 00000000..96fe51a7 --- /dev/null +++ b/tests/pipelines/test_face_emotion.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class FaceEmotionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_face-emotion' + self.img = {'img_path': 'data/test/images/face_emotion.jpg'} + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + print(result) + + @unittest.skip('skip since the model is set to private for now') + def test_run_modelhub(self): + face_emotion = pipeline(Tasks.face_emotion, model=self.model) + self.pipeline_inference(face_emotion, self.img) + + @unittest.skip('skip since the model is set to private for now') + def test_run_modelhub_default_model(self): + face_emotion = pipeline(Tasks.face_emotion) + self.pipeline_inference(face_emotion, self.img) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_human_hand_detection.py b/tests/pipelines/test_face_human_hand_detection.py new file mode 100644 index 00000000..7aaa67e7 --- /dev/null +++ b/tests/pipelines/test_face_human_hand_detection.py @@ -0,0 +1,38 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class FaceHumanHandTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_nanodet_face-human-hand-detection' + self.input = { + 'input_path': 'data/test/images/face_human_hand_detection.jpg', + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + logger.info(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_human_hand_detection = pipeline( + Tasks.face_human_hand_detection, model=self.model_id) + self.pipeline_inference(face_human_hand_detection, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_human_hand_detection = pipeline(Tasks.face_human_hand_detection) + self.pipeline_inference(face_human_hand_detection, self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_image_generation.py b/tests/pipelines/test_face_image_generation.py new file mode 100644 index 00000000..21d8e835 --- /dev/null +++ b/tests/pipelines/test_face_image_generation.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class FaceGenerationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.face_image_generation + self.model_id = 'damo/cv_gan_face-image-generation' + + def pipeline_inference(self, pipeline: Pipeline, seed: int): + result = pipeline(seed) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + seed = 10 + face_generation = pipeline( + self.task, + model=self.model_id, + ) + self.pipeline_inference(face_generation, seed) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + seed = 10 + face_generation = pipeline(self.task) + self.pipeline_inference(face_generation, seed) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_face_recognition.py b/tests/pipelines/test_face_recognition.py new file mode 100644 index 00000000..d3451f5d --- /dev/null +++ b/tests/pipelines/test_face_recognition.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class FaceRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.face_recognition + self.model_id = 'damo/cv_ir101_facerecognition_cfglint' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_face_compare(self): + img1 = 'data/test/images/face_recognition_1.png' + img2 = 'data/test/images/face_recognition_2.png' + + face_recognition = pipeline( + Tasks.face_recognition, model=self.model_id) + emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING] + emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING] + sim = np.dot(emb1[0], emb2[0]) + print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_facial_expression_recognition.py b/tests/pipelines/test_facial_expression_recognition.py new file mode 100644 index 00000000..f5151bef --- /dev/null +++ b/tests/pipelines/test_facial_expression_recognition.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 +import numpy as np + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_facial_expression_result +from modelscope.utils.test_utils import test_level + + +class FacialExpressionRecognitionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_vgg19_facial-expression-recognition_fer' + + def show_result(self, img_path, facial_expression_result): + img = draw_facial_expression_result(img_path, facial_expression_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skip('skip since the model is set to private for now') + def test_run_modelhub(self): + fer = pipeline( + Tasks.facial_expression_recognition, model=self.model_id) + img_path = 'data/test/images/facial_expression_recognition.jpg' + result = fer(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_faq_question_answering.py b/tests/pipelines/test_faq_question_answering.py new file mode 100644 index 00000000..2f66f516 --- /dev/null +++ b/tests/pipelines/test_faq_question_answering.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import numpy as np + +from modelscope.hub.api import HubApi +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForFaqQuestionAnswering +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline +from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class FaqQuestionAnsweringTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.faq_question_answering + self.model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base' + + param = { + 'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'], + 'support_set': [{ + 'text': '卖品代金券怎么用', + 'label': '6527856' + }, { + 'text': '怎么使用优惠券', + 'label': '6527856' + }, { + 'text': '这个可以一起领吗', + 'label': '1000012000' + }, { + 'text': '付款时送的优惠券哪里领', + 'label': '1000012000' + }, { + 'text': '购物等级怎么长', + 'label': '13421097' + }, { + 'text': '购物等级二心', + 'label': '13421097' + }] + } + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_file_download(self): + cache_path = snapshot_download(self.model_id) + preprocessor = FaqQuestionAnsweringPreprocessor.from_pretrained( + cache_path) + model = SbertForFaqQuestionAnswering.from_pretrained(cache_path) + pipeline_ins = FaqQuestionAnsweringPipeline( + model, preprocessor=preprocessor) + result = pipeline_ins(self.param) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.faq_question_answering, + model=model, + preprocessor=preprocessor) + result = pipeline_ins(self.param) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.faq_question_answering, model=self.model_id) + result = pipeline_ins(self.param) + print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.faq_question_answering) + print(pipeline_ins(self.param, max_seq_length=20)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_sentence_embedding(self): + pipeline_ins = pipeline(task=Tasks.faq_question_answering) + sentence_vec = pipeline_ins.get_sentence_embedding( + ['今天星期六', '明天星期几明天星期几']) + print(np.shape(sentence_vec)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_feature_extraction.py b/tests/pipelines/test_feature_extraction.py new file mode 100644 index 00000000..39291e76 --- /dev/null +++ b/tests/pipelines/test_feature_extraction.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import numpy as np + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import FeatureExtractionModel +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FeatureExtractionPipeline +from modelscope.preprocessors import NLPPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class FeatureExtractionTaskModelTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.feature_extraction + self.model_id = 'damo/pert_feature-extraction_base-test' + + sentence1 = '测试embedding' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_file_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = NLPPreprocessor(cache_path, padding=False) + model = FeatureExtractionModel.from_pretrained(self.model_id) + pipeline1 = FeatureExtractionPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.feature_extraction, model=model, preprocessor=tokenizer) + result = pipeline1(input=self.sentence1) + + print(f'sentence1: {self.sentence1}\n' + f'pipeline1:{np.shape(result[OutputKeys.TEXT_EMBEDDING])}') + result = pipeline2(input=self.sentence1) + print(f'sentence1: {self.sentence1}\n' + f'pipeline1: {np.shape(result[OutputKeys.TEXT_EMBEDDING])}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = NLPPreprocessor(model.model_dir, padding=False) + pipeline_ins = pipeline( + task=Tasks.feature_extraction, model=model, preprocessor=tokenizer) + result = pipeline_ins(input=self.sentence1) + print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.feature_extraction, model=self.model_id) + result = pipeline_ins(input=self.sentence1) + print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.feature_extraction) + result = pipeline_ins(input=self.sentence1) + print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py new file mode 100644 index 00000000..64833026 --- /dev/null +++ b/tests/pipelines/test_fill_mask.py @@ -0,0 +1,176 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from regex import R + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForMaskedLM, VecoForMaskedLM +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FillMaskPipeline +from modelscope.preprocessors import NLPPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.fill_mask + self.model_id = 'damo/nlp_veco_fill-mask-large' + + model_id_sbert = { + 'zh': 'damo/nlp_structbert_fill-mask_chinese-large', + 'en': 'damo/nlp_structbert_fill-mask_english-large' + } + model_id_veco = 'damo/nlp_veco_fill-mask-large' + model_id_bert = 'damo/nlp_bert_fill-mask_chinese-base' + + ori_texts = { + 'zh': + '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。' + '你师父差得动你,你师父可差不动我。', + 'en': + 'Everything in what you call reality is really just a reflection of your ' + 'consciousness. Your whole universe is just a mirror reflection of your story.' + } + + test_inputs = { + 'zh': + '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你' + '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。', + 'en': + 'Everything in [MASK] you call reality is really [MASK] a reflection of your ' + '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.' + } + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + # sbert + for language in ['zh']: + model_dir = snapshot_download(self.model_id_sbert[language]) + preprocessor = NLPPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = SbertForMaskedLM.from_pretrained(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language] + print( + f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n' + ) + + # veco + model_dir = snapshot_download(self.model_id_veco) + preprocessor = NLPPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = VecoForMaskedLM.from_pretrained(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + for language in ['zh', 'en']: + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language].replace('[MASK]', '') + print( + f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n' + ) + + # bert + language = 'zh' + model_dir = snapshot_download(self.model_id_bert) + preprocessor = NLPPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = Model.from_pretrained(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language] + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + + # sbert + for language in ['zh']: + print(self.model_id_sbert[language]) + model = Model.from_pretrained(self.model_id_sbert[language]) + preprocessor = NLPPreprocessor( + model.model_dir, + first_sequence='sentence', + second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=model, preprocessor=preprocessor) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + f'fill_mask_sbert_{language}', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print( + f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' + f'{pipeline_ins(self.test_inputs[language])}\n') + + # veco + model = Model.from_pretrained(self.model_id_veco) + preprocessor = NLPPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + for language in ['zh', 'en']: + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language].replace('[MASK]', '') + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + f'fill_mask_veco_{language}', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print( + f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + # veco + pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_veco) + for language in ['zh', 'en']: + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language].replace('[MASK]', '') + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + # structBert + language = 'zh' + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=self.model_id_sbert[language]) + print( + f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' + f'{pipeline_ins(self.test_inputs[language])}\n') + + # Bert + language = 'zh' + pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_bert) + print( + f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' + f'{pipeline_ins(self.test_inputs[language])}\n') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.fill_mask) + language = 'en' + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language].replace('[MASK]', '') + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_fill_mask_ponet.py b/tests/pipelines/test_fill_mask_ponet.py new file mode 100644 index 00000000..707cc201 --- /dev/null +++ b/tests/pipelines/test_fill_mask_ponet.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.metainfo import Pipelines +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class FillMaskPonetTest(unittest.TestCase): + model_id_ponet = { + 'zh': 'damo/nlp_ponet_fill-mask_chinese-base', + 'en': 'damo/nlp_ponet_fill-mask_english-base' + } + + ori_texts = { + 'zh': + '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。' + '你师父差得动你,你师父可差不动我。', + 'en': + 'Everything in what you call reality is really just a reflection of your ' + 'consciousness. Your whole universe is just a mirror reflection of your story.' + } + + test_inputs = { + 'zh': + '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你' + '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。', + 'en': + 'Everything in [MASK] you call reality is really [MASK] a reflection of your ' + '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.' + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_ponet_model(self): + for language in ['zh', 'en']: + ori_text = self.ori_texts[language] + test_input = self.test_inputs[language] + + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=self.model_id_ponet[language]) + + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_general_image_classification.py b/tests/pipelines/test_general_image_classification.py new file mode 100644 index 00000000..d5357f02 --- /dev/null +++ b/tests/pipelines/test_general_image_classification.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class GeneralImageClassificationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_classification + self.model_id = 'damo/cv_vit-base_image-classification_Dailylife-labels' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_ImageNet(self): + general_image_classification = pipeline( + Tasks.image_classification, + model='damo/cv_vit-base_image-classification_ImageNet-labels') + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_Dailylife(self): + general_image_classification = pipeline( + Tasks.image_classification, + model='damo/cv_vit-base_image-classification_Dailylife-labels') + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_Dailylife_default(self): + general_image_classification = pipeline(Tasks.image_classification) + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_general_recognition.py b/tests/pipelines/test_general_recognition.py new file mode 100644 index 00000000..ba713bbe --- /dev/null +++ b/tests/pipelines/test_general_recognition.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class GeneralRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.general_recognition + self.model_id = 'damo/cv_resnest101_general_recognition' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + general_recognition = pipeline( + Tasks.general_recognition, + model='damo/cv_resnest101_general_recognition') + result = general_recognition('data/test/images/dogs.jpg') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_generative_multi_modal_embedding.py b/tests/pipelines/test_generative_multi_modal_embedding.py new file mode 100644 index 00000000..7061d736 --- /dev/null +++ b/tests/pipelines/test_generative_multi_modal_embedding.py @@ -0,0 +1,77 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import unittest + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class GEMMMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.generative_multi_modal_embedding + self.model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' + + test_input = { + 'image': 'data/test/images/generative_multimodal.jpg', + 'text': + 'interior design of modern living room with fireplace in a new house', + 'captioning': False + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + generative_multi_modal_embedding_pipeline = pipeline( + Tasks.generative_multi_modal_embedding, model=self.model_id) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=model) + output = generative_multi_modal_embedding_pipeline(self.test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_captioning(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'image': self.test_input['image'], 'captioning': True} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_only_image(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'image': self.test_input['image'], 'captioning': False} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_output_only_text(self): + generative_multi_modal_embedding_pipeline = pipeline( + task=Tasks.generative_multi_modal_embedding, model=self.model_id) + test_input = {'text': self.test_input['text']} + output = generative_multi_modal_embedding_pipeline(test_input) + print(output) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_gpt3_text_generation.py b/tests/pipelines/test_gpt3_text_generation.py new file mode 100644 index 00000000..674e95bb --- /dev/null +++ b/tests/pipelines/test_gpt3_text_generation.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextGPT3GenerationTest(unittest.TestCase): + + def setUp(self) -> None: + # please make sure this local path exists. + self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B' + self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B' + self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B' + self.model_dir_13B = snapshot_download(self.model_id_13B) + self.input = '好的' + + @unittest.skip('distributed gpt3 1.3B, skipped') + def test_gpt3_1_3B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B) + print(pipe(self.input)) + + @unittest.skip('distributed gpt3 2.7B, skipped') + def test_gpt3_2_7B(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B) + print(pipe(self.input)) + + @unittest.skip('distributed gpt3 13B, skipped') + def test_gpt3_13B(self): + """ The model can be downloaded from the link on + TODO: add gpt3 checkpoint link + After downloading, you should have a gpt3 model structure like this: + nlp_gpt3_text-generation_13B + |_ config.json + |_ configuration.json + |_ tokenizer.json + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + """ + pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B) + print(pipe(self.input)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_hand_2d_keypoints.py b/tests/pipelines/test_hand_2d_keypoints.py new file mode 100644 index 00000000..43b569d0 --- /dev/null +++ b/tests/pipelines/test_hand_2d_keypoints.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class Hand2DKeypointsPipelineTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_hand_2d_keypoints(self): + img_path = 'data/test/images/hand_keypoints.jpg' + model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' + + hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id) + results = hand_keypoint(img_path) + + self.assertIn(OutputKeys.KEYPOINTS, results.keys()) + self.assertIn(OutputKeys.BOXES, results.keys()) + self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) + self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3) + self.assertEqual(results[OutputKeys.BOXES].shape[1], 4) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_hand_2d_keypoints_with_default_model(self): + img_path = 'data/test/images/hand_keypoints.jpg' + + hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints) + results = hand_keypoint(img_path) + self.assertIn(OutputKeys.KEYPOINTS, results.keys()) + self.assertIn(OutputKeys.BOXES, results.keys()) + self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) + self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3) + self.assertEqual(results[OutputKeys.BOXES].shape[1], 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_hand_static.py b/tests/pipelines/test_hand_static.py new file mode 100644 index 00000000..37181899 --- /dev/null +++ b/tests/pipelines/test_hand_static.py @@ -0,0 +1,32 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class HandStaticTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_mobileface_hand-static' + self.input = {'img_path': 'data/test/images/hand_static.jpg'} + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + hand_static = pipeline(Tasks.hand_static, model=self.model) + self.pipeline_inference(hand_static, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + hand_static = pipeline(Tasks.hand_static) + self.pipeline_inference(hand_static, self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_hicossl_video_embedding.py b/tests/pipelines/test_hicossl_video_embedding.py new file mode 100644 index 00000000..8a7de1fa --- /dev/null +++ b/tests/pipelines/test_hicossl_video_embedding.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# !/usr/bin/env python +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class HICOSSLVideoEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_embedding + self.model_id = 'damo/cv_s3dg_video-embedding' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + videossl_pipeline = pipeline( + Tasks.video_embedding, model=self.model_id) + result = videossl_pipeline( + 'data/test/videos/action_recognition_test_video.mp4') + + print(f'video embedding output: {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_human_wholebody_keypoint.py b/tests/pipelines/test_human_wholebody_keypoint.py new file mode 100644 index 00000000..7c5946cc --- /dev/null +++ b/tests/pipelines/test_human_wholebody_keypoint.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_human_wholebody_keypoint(self): + img_path = 'data/test/images/keypoints_detect/img_test_wholebody.jpg' + model_id = 'damo/cv_hrnetw48_human-wholebody-keypoint_image' + + human_wholebody_keypoint_pipeline = pipeline( + task=Tasks.human_wholebody_keypoint, model=model_id) + output = human_wholebody_keypoint_pipeline(img_path) + + output_keypoints = output[OutputKeys.KEYPOINTS] + output_pose = output[OutputKeys.BOXES] + + human_wholebody_keypoint_pipeline.predict_op.show_result( + img_path, + output_keypoints, + output_pose, + scale=1, + save_path='human_wholebody_keypoint_ret.jpg') + + for keypoint in output_keypoints: + self.assertEqual(keypoint.shape[0], 133) + for box in output_pose: + self.assertEqual(box.shape[0], 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image2image_generation.py b/tests/pipelines/test_image2image_generation.py new file mode 100644 index 00000000..116cef76 --- /dev/null +++ b/tests/pipelines/test_image2image_generation.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from torchvision.utils import save_image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class Image2ImageGenerationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub(self): + r"""We provide two generation modes, i.e., Similar Image Generation and Interpolation. + You can pass the following parameters for different mode. + 1. Similar Image Generation Mode: + 2. Interpolation Mode: + """ + img2img_gen_pipeline = pipeline( + Tasks.image_to_image_generation, + model='damo/cv_latent_diffusion_image2image_generate') + + # Similar Image Generation mode + result1 = img2img_gen_pipeline('data/test/images/img2img_input.jpg') + # Interpolation Mode + result2 = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', + 'data/test/images/img2img_style.jpg')) + save_image( + result1[OutputKeys.OUTPUT_IMG].clamp(-1, 1), + 'result1.jpg', + range=(-1, 1), + normalize=True, + nrow=4) + save_image( + result2[OutputKeys.OUTPUT_IMG].clamp(-1, 1), + 'result2.jpg', + range=(-1, 1), + normalize=True, + nrow=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image2image_translation.py b/tests/pipelines/test_image2image_translation.py new file mode 100644 index 00000000..a1cdb957 --- /dev/null +++ b/tests/pipelines/test_image2image_translation.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class Image2ImageTranslationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub(self): + r"""We provide three translation modes, i.e., uncropping, colorization and combination. + You can pass the following parameters for different mode. + 1. Uncropping Mode: + result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 'left', 0, 'result.jpg')) + 2. Colorization Mode: + result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 1, 'result.jpg')) + 3. Combination Mode: + just like the following code. + """ + img2img_gen_pipeline = pipeline( + Tasks.image_to_image_translation, + model='damo/cv_latent_diffusion_image2image_translation') + result = img2img_gen_pipeline( + ('data/test/images/img2img_input_mask.png', + 'data/test/images/img2img_input_masked_img.png', 2, + 'result.jpg')) # combination mode + + print(f'output: {result}.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_body_reshaping.py b/tests/pipelines/test_image_body_reshaping.py new file mode 100644 index 00000000..e1955e94 --- /dev/null +++ b/tests/pipelines/test_image_body_reshaping.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_body_reshaping + self.model_id = 'damo/cv_flow-based-body-reshaping_damo' + self.test_image = 'data/test/images/image_body_reshaping.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result_bodyreshaping.png', + result[OutputKeys.OUTPUT_IMG]) + print( + f'Output written to {osp.abspath("result_body_reshaping.png")}' + ) + else: + raise Exception('Testing failed: invalid output') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=model_dir) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=self.model_id) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_body_reshaping = pipeline(Tasks.image_body_reshaping) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_color_enhance.py b/tests/pipelines/test_image_color_enhance.py new file mode 100644 index 00000000..7c3ae8c0 --- /dev/null +++ b/tests/pipelines/test_image_color_enhance.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageColorEnhanceTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_csrnet_image-color-enhance-models' + self.task = Tasks.image_color_enhancement + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + img_color_enhance = pipeline( + Tasks.image_color_enhancement, model=self.model_id) + self.pipeline_inference(img_color_enhance, + 'data/test/images/image_color_enhance.png') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + img_color_enhance = pipeline(Tasks.image_color_enhancement) + self.pipeline_inference(img_color_enhance, + 'data/test/images/image_color_enhance.png') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_colorization.py b/tests/pipelines/test_image_colorization.py new file mode 100644 index 00000000..547fce89 --- /dev/null +++ b/tests/pipelines/test_image_colorization.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageColorizationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_unet_image-colorization' + self.test_image = 'data/test/images/marilyn_monroe_4.jpg' + self.task = Tasks.image_colorization + + def pipeline_inference(self, pipeline: Pipeline, test_image: str): + result = pipeline(test_image) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_colorization = pipeline( + Tasks.image_colorization, model=self.model_id) + + self.pipeline_inference(image_colorization, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_colorization = pipeline(Tasks.image_colorization) + self.pipeline_inference(image_colorization, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_denoise.py b/tests/pipelines/test_image_denoise.py new file mode 100644 index 00000000..d95dd343 --- /dev/null +++ b/tests/pipelines/test_image_denoise.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.cv import ImageDenoisePipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_denoising + self.model_id = 'damo/cv_nafnet_image-denoise_sidd' + + demo_image_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/noisy-demo-0.png' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + pipeline = ImageDenoisePipeline(cache_path) + pipeline.group_key = self.task + denoise_img = pipeline( + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] + print('pipeline: the shape of output_img is {}x{}'.format(h, w)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipeline_ins = pipeline(task=Tasks.image_denoising, model=model) + denoise_img = pipeline_ins( + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] + print('pipeline: the shape of output_img is {}x{}'.format(h, w)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.image_denoising, model=self.model_id) + denoise_img = pipeline_ins( + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] + print('pipeline: the shape of output_img is {}x{}'.format(h, w)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.image_denoising) + denoise_img = pipeline_ins( + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] + print('pipeline: the shape of output_img is {}x{}'.format(h, w)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_inpainting.py b/tests/pipelines/test_image_inpainting.py new file mode 100644 index 00000000..a8b704b7 --- /dev/null +++ b/tests/pipelines/test_image_inpainting.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 +import torch +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageInpaintingTest(unittest.TestCase): + + def setUp(self) -> None: + self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' + self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' + self.model_id = 'damo/cv_fft_inpainting_lama' + self.input = { + 'img': self.input_location, + 'mask': self.input_mask_location + } + + def save_result(self, result): + vis_img = result[OutputKeys.OUTPUT_IMG] + cv2.imwrite('result.png', vis_img) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_inpainting(self): + inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') + def test_inpainting_with_refinement(self): + # if input image is HR, set refine=True is more better + inpainting = pipeline( + Tasks.image_inpainting, model=self.model_id, refine=True) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_inpainting_with_image(self): + inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) + img = Image.open(self.input_location).convert('RGB') + mask = Image.open(self.input_mask_location).convert('RGB') + result = inpainting({'img': img, 'mask': mask}) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_inpainting_with_default_task(self): + inpainting = pipeline(Tasks.image_inpainting) + result = inpainting(self.input) + if result: + self.save_result(result) + else: + raise ValueError('process error') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_instance_segmentation.py b/tests/pipelines/test_image_instance_segmentation.py new file mode 100644 index 00000000..2ba0724a --- /dev/null +++ b/tests/pipelines/test_image_instance_segmentation.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.cv.image_instance_segmentation import \ + CascadeMaskRCNNSwinModel +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.cv import ImageInstanceSegmentationPipeline +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageInstanceSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_segmentation + self.model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' + + image = 'data/test/images/image_instance_segmentation.jpg' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + config_path = os.path.join(model.model_dir, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) + pipeline_ins = pipeline( + task=Tasks.image_segmentation, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.image_segmentation, model=self.model_id) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.image_segmentation) + print(pipeline_ins(input=self.image)[OutputKeys.LABELS]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + preprocessor = build_preprocessor(cfg.preprocessor, Fields.cv) + model = CascadeMaskRCNNSwinModel(cache_path) + pipeline1 = ImageInstanceSegmentationPipeline( + model, preprocessor=preprocessor) + pipeline2 = pipeline( + Tasks.image_segmentation, model=model, preprocessor=preprocessor) + print(f'pipeline1:{pipeline1(input=self.image)[OutputKeys.LABELS]}') + print(f'pipeline2: {pipeline2(input=self.image)[OutputKeys.LABELS]}') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py new file mode 100644 index 00000000..a3edb705 --- /dev/null +++ b/tests/pipelines/test_image_matting.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageMattingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_unet_image-matting' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/image_matting.png'] + # alternatively: + # input_location = '/dir/to/images' + + dataset = MsDataset.load(input_location, target='image') + img_matting = pipeline(Tasks.portrait_matting, model=self.model_id) + # note that for dataset output, the inference-output is a Generator that can be iterated. + result = img_matting(dataset) + cv2.imwrite('result.png', next(result)[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + img_matting = pipeline(Tasks.portrait_matting, model=self.model_id) + + result = img_matting('data/test/images/image_matting.png') + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + img_matting = pipeline(Tasks.portrait_matting) + + result = img_matting('data/test/images/image_matting.png') + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_modelscope_dataset(self): + dataset = MsDataset.load( + 'fixtures_image_utils', + namespace='damotest', + split='test', + target='file') + img_matting = pipeline(Tasks.portrait_matting, model=self.model_id) + result = img_matting(dataset) + for i in range(2): + cv2.imwrite(f'result_{i}.png', next(result)[OutputKeys.OUTPUT_IMG]) + print( + f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}' + ) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_panoptic_segmentation.py b/tests/pipelines/test_image_panoptic_segmentation.py new file mode 100644 index 00000000..4f12e6af --- /dev/null +++ b/tests/pipelines/test_image_panoptic_segmentation.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import cv2 +import PIL + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import panoptic_seg_masks_to_image +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImagePanopticSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_segmentation + self.model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_panoptic_segmentation(self): + input_location = 'data/test/images/image_panoptic_segmentation.jpg' + pan_segmentor = pipeline(Tasks.image_segmentation, model=self.model_id) + result = pan_segmentor(input_location) + + draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('print test_image_panoptic_segmentation return success') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_panoptic_segmentation_from_PIL(self): + input_location = 'data/test/images/image_panoptic_segmentation.jpg' + pan_segmentor = pipeline(Tasks.image_segmentation, model=self.model_id) + PIL_array = PIL.Image.open(input_location) + result = pan_segmentor(PIL_array) + + draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('print test_image_panoptic_segmentation from PIL return success') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_portrait_enhancement.py b/tests/pipelines/test_image_portrait_enhancement.py new file mode 100644 index 00000000..1ca97253 --- /dev/null +++ b/tests/pipelines/test_image_portrait_enhancement.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImagePortraitEnhancementTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_portrait_enhancement + self.model_id = 'damo/cv_gpen_image-portrait-enhancement' + self.test_image = 'data/test/images/Solvay_conference_1927.png' + + def pipeline_inference(self, pipeline: Pipeline, test_image: str): + result = pipeline(test_image) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + else: + raise Exception('Testing failed: invalid output') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_enhancement = pipeline( + Tasks.image_portrait_enhancement, model=self.model_id) + self.pipeline_inference(face_enhancement, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_enhancement = pipeline(Tasks.image_portrait_enhancement) + self.pipeline_inference(face_enhancement, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_reid_person.py b/tests/pipelines/test_image_reid_person.py new file mode 100644 index 00000000..310cdd66 --- /dev/null +++ b/tests/pipelines/test_image_reid_person.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageReidPersonTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.input_location = 'data/test/images/image_reid_person.jpg' + self.model_id = 'damo/cv_passvitb_image-reid-person_market' + self.task = Tasks.image_reid_person + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_reid_person(self): + image_reid_person = pipeline( + Tasks.image_reid_person, model=self.model_id) + result = image_reid_person(self.input_location) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_image_reid_person_with_image(self): + image_reid_person = pipeline( + Tasks.image_reid_person, model=self.model_id) + img = Image.open(self.input_location) + result = image_reid_person(img) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_image_reid_person_with_default_model(self): + image_reid_person = pipeline(Tasks.image_reid_person) + result = image_reid_person(self.input_location) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_semantic_segmentation.py b/tests/pipelines/test_image_semantic_segmentation.py new file mode 100644 index 00000000..286d317a --- /dev/null +++ b/tests/pipelines/test_image_semantic_segmentation.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import cv2 +import PIL + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageSemanticSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = 'image-segmentation' + self.model_id = 'damo/cv_swinL_semantic-segmentation_cocopanmerge' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_semantic_segmentation_panmerge(self): + input_location = 'data/test/images/image_semantic_segmentation.jpg' + segmenter = pipeline(Tasks.image_segmentation, model=self.model_id) + result = segmenter(input_location) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_panmerge DONE') + + PIL_array = PIL.Image.open(input_location) + result = segmenter(PIL_array) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_panmerge_from_PIL DONE') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_semantic_segmentation_vitadapter(self): + input_location = 'data/test/images/image_semantic_segmentation.jpg' + segmenter = pipeline(Tasks.image_segmentation, model=self.model_id) + result = segmenter(input_location) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_vitadapter DONE') + + PIL_array = PIL.Image.open(input_location) + result = segmenter(PIL_array) + + draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) + cv2.imwrite('result.jpg', draw_img) + print('test_image_semantic_segmentation_vitadapter_from_PIL DONE') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_style_transfer.py b/tests/pipelines/test_image_style_transfer.py new file mode 100644 index 00000000..5f37f204 --- /dev/null +++ b/tests/pipelines/test_image_style_transfer.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_style_transfer + self.model_id = 'damo/cv_aams_style-transfer_damo' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + snapshot_path = snapshot_download(self.model_id) + print('snapshot_path: {}'.format(snapshot_path)) + image_style_transfer = pipeline( + Tasks.image_style_transfer, model=snapshot_path) + + result = image_style_transfer( + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) + cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_style_transfer = pipeline( + Tasks.image_style_transfer, model=self.model_id) + + result = image_style_transfer( + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) + cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG]) + print('style_transfer.test_run_modelhub done') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_style_transfer = pipeline(Tasks.image_style_transfer) + + result = image_style_transfer( + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) + cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG]) + print('style_transfer.test_run_modelhub_default_model done') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_super_resolution.py b/tests/pipelines/test_image_super_resolution.py new file mode 100644 index 00000000..d5cbebe8 --- /dev/null +++ b/tests/pipelines/test_image_super_resolution.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageSuperResolutionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_rrdb_image-super-resolution' + self.img = 'data/test/images/dogs.jpg' + self.task = Tasks.image_super_resolution + + def pipeline_inference(self, pipeline: Pipeline, img: str): + result = pipeline(img) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + super_resolution = pipeline( + Tasks.image_super_resolution, model=self.model_id) + + self.pipeline_inference(super_resolution, self.img) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + super_resolution = pipeline(Tasks.image_super_resolution) + self.pipeline_inference(super_resolution, self.img) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py new file mode 100644 index 00000000..f31d212b --- /dev/null +++ b/tests/pipelines/test_key_word_spotting.py @@ -0,0 +1,305 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import unittest +from typing import Any, Dict, List, Union + +import numpy as np +import soundfile + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ColorCodes, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import download_and_untar, test_level + +logger = get_logger() + +POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' +BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' +URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav' + +POS_TESTSETS_FILE = 'pos_testsets.tar.gz' +POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' + +NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' +NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' + + +class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): + action_info = { + 'test_run_with_wav': { + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], + 'checking_value': '小云小云', + 'example': { + 'wav_count': + 1, + 'kws_type': + 'wav', + 'kws_list': [{ + 'keyword': '小云小云', + 'offset': 5.76, + 'length': 9.132938, + 'confidence': 0.990368 + }] + } + }, + 'test_run_with_pcm': { + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], + 'checking_value': '小云小云', + 'example': { + 'wav_count': + 1, + 'kws_type': + 'pcm', + 'kws_list': [{ + 'keyword': '小云小云', + 'offset': 5.76, + 'length': 9.132938, + 'confidence': 0.990368 + }] + } + }, + 'test_run_with_wav_by_customized_keywords': { + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], + 'checking_value': '播放音乐', + 'example': { + 'wav_count': + 1, + 'kws_type': + 'wav', + 'kws_list': [{ + 'keyword': '播放音乐', + 'offset': 0.87, + 'length': 2.158313, + 'confidence': 0.646237 + }] + } + }, + 'test_run_with_url': { + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], + 'checking_value': '小云小云', + 'example': { + 'wav_count': + 1, + 'kws_type': + 'pcm', + 'kws_list': [{ + 'keyword': '小云小云', + 'offset': 0.69, + 'length': 1.67, + 'confidence': 0.996023 + }] + } + }, + 'test_run_with_pos_testsets': { + 'checking_item': ['recall'], + 'example': { + 'wav_count': 450, + 'kws_type': 'pos_testsets', + 'wav_time': 3013.75925, + 'keywords': ['小云小云'], + 'recall': 0.953333, + 'detected_count': 429, + 'rejected_count': 21, + 'rejected': ['yyy.wav', 'zzz.wav'] + } + }, + 'test_run_with_neg_testsets': { + 'checking_item': ['fa_rate'], + 'example': { + 'wav_count': + 751, + 'kws_type': + 'neg_testsets', + 'wav_time': + 3572.180813, + 'keywords': ['小云小云'], + 'fa_rate': + 0.001332, + 'fa_per_hour': + 1.007788, + 'detected_count': + 1, + 'rejected_count': + 750, + 'detected': [{ + '6.wav': { + 'confidence': '0.321170', + 'keyword': '小云小云' + } + }] + } + }, + 'test_run_with_roc': { + 'checking_item': ['keywords', 0], + 'checking_value': '小云小云', + 'example': { + 'kws_type': + 'roc', + 'keywords': ['小云小云'], + '小云小云': [{ + 'threshold': 0.0, + 'recall': 0.953333, + 'fa_per_hour': 1.007788 + }, { + 'threshold': 0.001, + 'recall': 0.953333, + 'fa_per_hour': 1.007788 + }, { + 'threshold': 0.999, + 'recall': 0.004444, + 'fa_per_hour': 0.0 + }] + } + } + } + + def setUp(self) -> None: + self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun' + self.workspace = os.path.join(os.getcwd(), '.tmp') + if not os.path.exists(self.workspace): + os.mkdir(self.workspace) + + def tearDown(self) -> None: + # remove workspace dir (.tmp) + shutil.rmtree(self.workspace, ignore_errors=True) + + def run_pipeline(self, + model_id: str, + audio_in: Union[List[str], str, bytes], + keywords: List[str] = None) -> Dict[str, Any]: + kwsbp_16k_pipline = pipeline( + task=Tasks.keyword_spotting, model=model_id) + + kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords) + + return kws_result + + def log_error(self, functions: str, result: Dict[str, Any]) -> None: + logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' + + ColorCodes.END) + logger.error(ColorCodes.MAGENTA + functions + + ' correct result example: ' + ColorCodes.YELLOW + + str(self.action_info[functions]['example']) + + ColorCodes.END) + + raise ValueError('kws result is mismatched') + + def check_result(self, functions: str, result: Dict[str, Any]) -> None: + result_item = result + check_list = self.action_info[functions]['checking_item'] + for check_item in check_list: + result_item = result_item[check_item] + if result_item is None or result_item == 'None': + self.log_error(functions, result) + + if self.action_info[functions].__contains__('checking_value'): + check_value = self.action_info[functions]['checking_value'] + if result_item != check_value: + self.log_error(functions, result) + + logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' + + ColorCodes.END) + if functions == 'test_run_with_roc': + find_keyword = result['keywords'][0] + keyword_list = result[find_keyword] + for item in iter(keyword_list): + threshold: float = item['threshold'] + recall: float = item['recall'] + fa_per_hour: float = item['fa_per_hour'] + logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold) + + ' recall:' + str(recall) + ' fa_per_hour:' + + str(fa_per_hour) + ColorCodes.END) + else: + logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) + + def wav2bytes(self, wav_file) -> bytes: + audio, fs = soundfile.read(wav_file) + + # float32 -> int16 + audio = np.asarray(audio) + dtype = np.dtype('int16') + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) + + # int16(PCM_16) -> byte + audio = audio.tobytes() + return audio + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=POS_WAV_FILE) + self.check_result('test_run_with_wav', kws_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_pcm(self): + audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE)) + + kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio) + self.check_result('test_run_with_pcm', kws_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav_by_customized_keywords(self): + keywords = '播放音乐' + + kws_result = self.run_pipeline( + model_id=self.model_id, + audio_in=BOFANGYINYUE_WAV_FILE, + keywords=keywords) + self.check_result('test_run_with_wav_by_customized_keywords', + kws_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_url(self): + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url', kws_result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_pos_testsets(self): + wav_file_path = download_and_untar( + os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, + self.workspace) + audio_list = [wav_file_path, None] + + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_pos_testsets', kws_result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_neg_testsets(self): + wav_file_path = download_and_untar( + os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, + self.workspace) + audio_list = [None, wav_file_path] + + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_neg_testsets', kws_result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_roc(self): + pos_file_path = download_and_untar( + os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, + self.workspace) + neg_file_path = download_and_untar( + os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, + self.workspace) + audio_list = [pos_file_path, neg_file_path] + + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_roc', kws_result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py new file mode 100644 index 00000000..69d6a953 --- /dev/null +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' +TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav' +TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/3ch_nihaomiya.wav' + + +class KWSFarfieldTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_normal(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws(os.path.join(os.getcwd(), TEST_SPEECH_FILE)) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_mono(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws(os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_url(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws(TEST_SPEECH_URL) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_input_bytes(self): + with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f: + data = f.read() + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws(data) + self.assertEqual(len(result['kws_list']), 5) + print(result['kws_list'][-1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_live_category.py b/tests/pipelines/test_live_category.py new file mode 100644 index 00000000..391ed283 --- /dev/null +++ b/tests/pipelines/test_live_category.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class LiveCategoryTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.live_category + self.model_id = 'damo/cv_resnet50_live-category' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + category_pipeline = pipeline(Tasks.live_category, self.model_id) + result = category_pipeline( + 'data/test/videos/live_category_test_video.mp4') + + print(f'live category output: {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_mglm_text_summarization.py b/tests/pipelines/test_mglm_text_summarization.py new file mode 100644 index 00000000..47abc741 --- /dev/null +++ b/tests/pipelines/test_mglm_text_summarization.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.preprocessors import MGLMSummarizationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class mGLMTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.output_dir = 'unittest_output' + os.makedirs(self.output_dir, exist_ok=True) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_mglm_with_name(self): + model = 'ZhipuAI/Multilingual-GLM-Summarization-zh' + preprocessor = MGLMSummarizationPreprocessor() + pipe = pipeline( + task=Tasks.text_summarization, + model=model, + preprocessor=preprocessor, + ) + result = pipe( + '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa + ) + print(result) + + model = 'ZhipuAI/Multilingual-GLM-Summarization-en' + preprocessor = MGLMSummarizationPreprocessor() + pipe = pipeline( + task=Tasks.text_summarization, + model=model, + preprocessor=preprocessor, + ) + result = pipe( + '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa + ) + print(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_mog_face_detection.py b/tests/pipelines/test_mog_face_detection.py new file mode 100644 index 00000000..5c6d97c2 --- /dev/null +++ b/tests/pipelines/test_mog_face_detection.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result +from modelscope.utils.test_utils import test_level + + +class MogFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet101_face-detection_cvpr22papermogface' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_no_lm_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/mog_face_detection.jpg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_movie_scene_segmentation.py b/tests/pipelines/test_movie_scene_segmentation.py new file mode 100644 index 00000000..affd5140 --- /dev/null +++ b/tests/pipelines/test_movie_scene_segmentation.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MovieSceneSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.movie_scene_segmentation + self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_movie_scene_segmentation(self): + input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' + movie_scene_segmentation_pipeline = pipeline( + Tasks.movie_scene_segmentation, model=self.model_id) + result = movie_scene_segmentation_pipeline(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_movie_scene_segmentation_with_default_task(self): + input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' + movie_scene_segmentation_pipeline = pipeline( + Tasks.movie_scene_segmentation) + result = movie_scene_segmentation_pipeline(input_location) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py new file mode 100644 index 00000000..21439ce2 --- /dev/null +++ b/tests/pipelines/test_mplug_tasks.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from PIL import Image + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_with_model(self): + model = Model.from_pretrained( + 'damo/mplug_image-captioning_coco_base_en') + pipeline_caption = pipeline( + task=Tasks.image_captioning, + model=model, + ) + image = Image.open('data/test/images/image_mplug_vqa.jpg') + result = pipeline_caption(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_image_captioning_with_name(self): + pipeline_caption = pipeline( + Tasks.image_captioning, + model='damo/mplug_image-captioning_coco_base_en') + image = Image.open('data/test/images/image_mplug_vqa.jpg') + result = pipeline_caption(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_question_answering_with_model(self): + model = Model.from_pretrained( + 'damo/mplug_visual-question-answering_coco_large_en') + pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) + image = Image.open('data/test/images/image_mplug_vqa.jpg') + text = 'What is the woman doing?' + input = {'image': image, 'text': text} + result = pipeline_vqa(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_visual_question_answering_with_name(self): + model = 'damo/mplug_visual-question-answering_coco_large_en' + pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) + image = Image.open('data/test/images/image_mplug_vqa.jpg') + text = 'What is the woman doing?' + input = {'image': image, 'text': text} + result = pipeline_vqa(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_text_retrieval_with_model(self): + model = Model.from_pretrained( + 'damo/mplug_image-text-retrieval_flickr30k_large_en') + pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) + image = Image.open('data/test/images/image-text-retrieval.jpg') + text = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' + input = {'image': image, 'text': text} + result = pipeline_retrieval(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_image_text_retrieval_with_name(self): + model = 'damo/mplug_image-text-retrieval_flickr30k_large_en' + pipeline_retrieval = pipeline(Tasks.image_text_retrieval, model=model) + image = Image.open('data/test/images/image-text-retrieval.jpg') + text = 'Two young guys with shaggy hair look at their hands while hanging out in the yard.' + input = {'image': image, 'text': text} + result = pipeline_retrieval(input) + print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_image_captioning_zh_base_with_name(self): + pipeline_caption = pipeline( + Tasks.image_captioning, + model='damo/mplug_image-captioning_coco_base_zh') + image = Image.open('data/test/images/image_mplug_vqa.jpg') + result = pipeline_caption(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_visual_question_answering_zh_base_with_name(self): + model = 'damo/mplug_visual-question-answering_coco_base_zh' + pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) + image = Image.open('data/test/images/image_mplug_vqa.jpg') + text = '这个女人在做什么?' + input = {'image': image, 'text': text} + result = pipeline_vqa(input) + print(result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_mtcnn_face_detection.py b/tests/pipelines/test_mtcnn_face_detection.py new file mode 100644 index 00000000..5afb5588 --- /dev/null +++ b/tests/pipelines/test_mtcnn_face_detection.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 +from PIL import Image + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.test_utils import test_level + + +class MtcnnFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_manual_face-detection_mtcnn' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/mtcnn_face_detection.jpg' + img = Image.open(img_path) + + result_1 = face_detection(img_path) + self.show_result(img_path, result_1) + + result_2 = face_detection(img) + self.show_result(img_path, result_2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multi_modal_embedding.py b/tests/pipelines/test_multi_modal_embedding.py new file mode 100644 index 00000000..7eddc690 --- /dev/null +++ b/tests/pipelines/test_multi_modal_embedding.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import torch + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.multi_modal_embedding + self.model_id = 'damo/multi-modal_clip-vit-base-patch16_zh' + + test_input = {'text': '皮卡丘'} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + pipeline_multi_modal_embedding = pipeline( + Tasks.multi_modal_embedding, model=self.model_id) + text_embedding = pipeline_multi_modal_embedding.forward( + self.test_input)[OutputKeys.TEXT_EMBEDDING] + print('l1-norm: {}'.format( + torch.norm(text_embedding, p=1, dim=-1).item())) + print('l2-norm: {}'.format(torch.norm(text_embedding, + dim=-1).item())) # should be 1.0 + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipeline_multi_modal_embedding = pipeline( + task=Tasks.multi_modal_embedding, model=model) + text_embedding = pipeline_multi_modal_embedding.forward( + self.test_input)[OutputKeys.TEXT_EMBEDDING] + print('l1-norm: {}'.format( + torch.norm(text_embedding, p=1, dim=-1).item())) + print('l2-norm: {}'.format(torch.norm(text_embedding, + dim=-1).item())) # should be 1.0 + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_multi_modal_embedding = pipeline( + task=Tasks.multi_modal_embedding) + text_embedding = pipeline_multi_modal_embedding.forward( + self.test_input)[OutputKeys.TEXT_EMBEDDING] + print('l1-norm: {}'.format( + torch.norm(text_embedding, p=1, dim=-1).item())) + print('l2-norm: {}'.format(torch.norm(text_embedding, + dim=-1).item())) # should be 1.0 + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multi_modal_similarity.py b/tests/pipelines/test_multi_modal_similarity.py new file mode 100644 index 00000000..a54fbcf0 --- /dev/null +++ b/tests/pipelines/test_multi_modal_similarity.py @@ -0,0 +1,48 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import unittest + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MultiModalSimilarityTest(unittest.TestCase): + model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' + test_img = 'data/test/images/multimodal_similarity.jpg' + test_str1 = '一个上了年纪的女人在城镇中骑着自行车一个黄色出租车正要从她身边驶过' + test_str2 = '穿着蓝色连衣裙的那个女人正冲着行来的车辆伸出她的手' + + def infer_pipeline(self, multi_modal_similarity_pipeline): + test_input1 = {'img': self.test_img, 'text': self.test_str1} + test_input2 = {'img': self.test_img, 'text': self.test_str2} + output1 = multi_modal_similarity_pipeline(test_input1) + output2 = multi_modal_similarity_pipeline(test_input2) + print('image: {}, text: {}, similarity: {}'.format( + self.test_img, self.test_str1, output1['scores'])) + print('image: {}, text: {}, similarity: {}'.format( + self.test_img, self.test_str2, output2['scores'])) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + multi_modal_similarity_pipeline = pipeline( + Tasks.multi_modal_similarity, model=self.model_id) + self.infer_pipeline(multi_modal_similarity_pipeline) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + multi_modal_similarity_pipeline = pipeline( + task=Tasks.multi_modal_similarity) + self.infer_pipeline(multi_modal_similarity_pipeline) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + multi_modal_similarity_pipeline = pipeline( + task=Tasks.multi_modal_similarity, model=model) + self.infer_pipeline(multi_modal_similarity_pipeline) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multi_stage_diffusion.py b/tests/pipelines/test_multi_stage_diffusion.py new file mode 100644 index 00000000..f4e63ce0 --- /dev/null +++ b/tests/pipelines/test_multi_stage_diffusion.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np +import torch + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MultiStageDiffusionTest(unittest.TestCase): + model_id = 'damo/cv_diffusion_text-to-image-synthesis' + test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'} + + @unittest.skip( + 'skip test since the pretrained model is not publicly available') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=model) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skip( + 'skip test since the pretrained model is not publicly available') + def test_run_with_model_name(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=self.model_id) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multilingual_named_entity_recognition.py b/tests/pipelines/test_multilingual_named_entity_recognition.py new file mode 100644 index 00000000..cb2b32d6 --- /dev/null +++ b/tests/pipelines/test_multilingual_named_entity_recognition.py @@ -0,0 +1,112 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition) +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import (NamedEntityRecognitionThaiPipeline, + NamedEntityRecognitionVietPipeline) +from modelscope.preprocessors import NERPreprocessorThai, NERPreprocessorViet +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MultilingualNamedEntityRecognitionTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.named_entity_recognition + self.model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + + thai_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + thai_sentence = 'เครื่องชั่งดิจิตอลแบบตั้งพื้น150kg.' + + viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' + viet_sentence = 'Nón vành dễ thương cho bé gái' + + multilingual_model_id = 'damo/nlp_raner_named-entity-recognition_multilingual-large-generic' + ml_stc = 'সমস্ত বেতন নিলামের সাধারণ ব্যবহারিক উদাহরণ বিভিন্ন পেনি নিলাম / বিডিং ফি নিলাম ওয়েবসাইটে পাওয়া যাবে।' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_thai(self): + cache_path = snapshot_download(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionThaiPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'thai_sentence: {self.thai_sentence}\n' + f'pipeline1:{pipeline1(input=self.thai_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.thai_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_thai(self): + model = Model.from_pretrained(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_thai(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_with_model_name_multilingual(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=self.multilingual_model_id) + print(pipeline_ins(input=self.ml_stc)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_viet(self): + cache_path = snapshot_download(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionVietPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'viet_sentence: {self.viet_sentence}\n' + f'pipeline1:{pipeline1(input=self.viet_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.viet_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_viet(self): + model = Model.from_pretrained(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_viet(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.viet_tcrf_model_id) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multilingual_word_segmentation.py b/tests/pipelines/test_multilingual_word_segmentation.py new file mode 100644 index 00000000..25b4b241 --- /dev/null +++ b/tests/pipelines/test_multilingual_word_segmentation.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import TransformerCRFForWordSegmentation +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import WordSegmentationThaiPipeline +from modelscope.preprocessors import WordSegmentationPreprocessorThai +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.test_utils import test_level + + +class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.word_segmentation + self.model_id = 'damo/nlp_xlmr_word-segmentation_thai' + + sentence = 'รถคันเก่าก็ยังเก็บเอาไว้ยังไม่ได้ขาย' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(cache_path) + model = TransformerCRFForWordSegmentation.from_pretrained(cache_path) + pipeline1 = WordSegmentationThaiPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py new file mode 100644 index 00000000..0df44f5b --- /dev/null +++ b/tests/pipelines/test_named_entity_recognition.py @@ -0,0 +1,119 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition) +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline +from modelscope.preprocessors import TokenClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.named_entity_recognition + self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' + + english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom' + chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic' + tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' + lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' + sentence = '这与温岭市新河镇的一个神秘的传说有关。' + sentence_en = 'pizza shovel' + sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download(self): + cache_path = snapshot_download(self.tcrf_model_id) + tokenizer = TokenClassificationPreprocessor(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_lcrf_by_direct_model_download(self): + cache_path = snapshot_download(self.lcrf_model_id) + tokenizer = TokenClassificationPreprocessor(cache_path) + model = LSTMCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub(self): + model = Model.from_pretrained(self.tcrf_model_id) + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_lcrf_with_model_from_modelhub(self): + model = Model.from_pretrained(self.lcrf_model_id) + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.tcrf_model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_lcrf_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.lcrf_model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_lcrf_with_chinese_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.chinese_model_id) + print(pipeline_ins(input=self.sentence_zh)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_english_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.english_model_id) + print(pipeline_ins(input=self.sentence_en)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.named_entity_recognition) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py new file mode 100644 index 00000000..9e9fefea --- /dev/null +++ b/tests/pipelines/test_nli.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextClassificationPipeline +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class NLITest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.nli + self.model_id = 'damo/nlp_structbert_nli_chinese-base' + + sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' + sentence2 = '四川商务职业学院商务管理在哪个校区?' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_file_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = SequenceClassificationPreprocessor(cache_path) + model = Model.from_pretrained(cache_path) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer) + print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' + f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}') + print( + f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' + f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = SequenceClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.nli, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + 'sbert_nli', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.nli) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py new file mode 100644 index 00000000..64766c77 --- /dev/null +++ b/tests/pipelines/test_object_detection.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.human_detection + self.model_id = 'damo/cv_resnet18_human-detection' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_object_detection(self): + input_location = 'data/test/images/image_detection.jpg' + model_id = 'damo/cv_vit_object-detection_coco' + object_detect = pipeline(Tasks.image_object_detection, model=model_id) + result = object_detect(input_location) + print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_object_detection_with_default_task(self): + input_location = 'data/test/images/image_detection.jpg' + object_detect = pipeline(Tasks.image_object_detection) + result = object_detect(input_location) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_human_detection(self): + input_location = 'data/test/images/image_detection.jpg' + model_id = 'damo/cv_resnet18_human-detection' + human_detect = pipeline(Tasks.human_detection, model=model_id) + result = human_detect(input_location) + print(result) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_human_detection_with_default_task(self): + input_location = 'data/test/images/image_detection.jpg' + human_detect = pipeline(Tasks.human_detection) + result = human_detect(input_location) + print(result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_object_detection_auto_pipeline(self): + model_id = 'damo/cv_yolox_image-object-detection-auto' + test_image = 'data/test/images/auto_demo.jpg' + + image_object_detection_auto = pipeline( + Tasks.image_object_detection, model=model_id) + + result = image_object_detection_auto(test_image) + image_object_detection_auto.show_result(test_image, result, + 'auto_demo_ret.jpg') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_ocr_detection.py b/tests/pipelines/test_ocr_detection.py new file mode 100644 index 00000000..e0591496 --- /dev/null +++ b/tests/pipelines/test_ocr_detection.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class OCRDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet18_ocr-detection-line-level_damo' + self.test_image = 'data/test/images/ocr_detection.jpg' + self.task = Tasks.ocr_detection + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + print('ocr detection results: ') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + ocr_detection = pipeline(Tasks.ocr_detection, model=self.model_id) + self.pipeline_inference(ocr_detection, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + ocr_detection = pipeline(Tasks.ocr_detection) + self.pipeline_inference(ocr_detection, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_ocr_recognition.py b/tests/pipelines/test_ocr_recognition.py new file mode 100644 index 00000000..8d48dd7a --- /dev/null +++ b/tests/pipelines/test_ocr_recognition.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import PIL + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class OCRRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_convnextTiny_ocr-recognition-general_damo' + self.test_image = 'data/test/images/ocr_recognition.jpg' + self.task = Tasks.ocr_recognition + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + print('ocr recognition results: ', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id) + self.pipeline_inference(ocr_recognition, self.test_image) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub_PILinput(self): + ocr_recognition = pipeline(Tasks.ocr_recognition, model=self.model_id) + imagePIL = PIL.Image.open(self.test_image) + self.pipeline_inference(ocr_recognition, imagePIL) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + ocr_recognition = pipeline(Tasks.ocr_recognition) + self.pipeline_inference(ocr_recognition, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py new file mode 100644 index 00000000..6be70468 --- /dev/null +++ b/tests/pipelines/test_ofa_tasks.py @@ -0,0 +1,269 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest +from os import path as osp + +import cv2 +from PIL import Image + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import created_boxed_image +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.output_dir = 'unittest_output' + os.makedirs(self.output_dir, exist_ok=True) + + def save_img(self, image_in, box, image_out): + cv2.imwrite( + osp.join(self.output_dir, image_out), + created_boxed_image(image_in, box)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_with_model(self): + model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en') + img_captioning = pipeline( + task=Tasks.image_captioning, + model=model, + ) + image = 'data/test/images/image_captioning.png' + result = img_captioning(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_image_captioning_with_name(self): + img_captioning = pipeline( + Tasks.image_captioning, + model='damo/ofa_image-caption_coco_large_en') + result = img_captioning('data/test/images/image_captioning.png') + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_ocr_recognize_with_name(self): + ocr_recognize = pipeline( + Tasks.ocr_recognition, + model='damo/ofa_ocr-recognition_scene_base_zh') + result = ocr_recognize('data/test/images/image_ocr_recognition.jpg') + print(result[OutputKeys.TEXT]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_classification_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_image-classification_imagenet_large_en') + ofa_pipe = pipeline(Tasks.image_classification, model=model) + image = 'data/test/images/image_classification.png' + result = ofa_pipe(image) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_image_classification_with_name(self): + ofa_pipe = pipeline( + Tasks.image_classification, + model='damo/ofa_image-classification_imagenet_large_en') + image = 'data/test/images/image_classification.png' + result = ofa_pipe(image) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_summarization_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_summarization_gigaword_large_en') + ofa_pipe = pipeline(Tasks.text_summarization, model=model) + text = 'five-time world champion michelle kwan withdrew' + \ + 'from the #### us figure skating championships on wednesday ,' + \ + ' but will petition us skating officials for the chance to ' + \ + 'compete at the #### turin olympics .' + input = {'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_summarization_with_name(self): + ofa_pipe = pipeline( + Tasks.text_summarization, + model='damo/ofa_summarization_gigaword_large_en') + text = 'five-time world champion michelle kwan withdrew' + \ + 'from the #### us figure skating championships on wednesday ,' + \ + ' but will petition us skating officials for the chance to ' +\ + 'compete at the #### turin olympics .' + input = {'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_text_classification_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_text-classification_mnli_large_en') + ofa_pipe = pipeline(Tasks.text_classification, model=model) + text = 'One of our number will carry out your instructions minutely.' + text2 = 'A member of my team will execute your orders with immense precision.' + result = ofa_pipe((text, text2)) + result = ofa_pipe({'text': text, 'text2': text2}) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_text_classification_with_name(self): + ofa_pipe = pipeline( + Tasks.text_classification, + model='damo/ofa_text-classification_mnli_large_en') + text = 'One of our number will carry out your instructions minutely.' + text2 = 'A member of my team will execute your orders with immense precision.' + result = ofa_pipe((text, text2)) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_entailment_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-entailment_snli-ve_large_en') + ofa_pipe = pipeline(Tasks.visual_entailment, model=model) + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_visual_entailment_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_entailment, + model='damo/ofa_visual-entailment_snli-ve_large_en') + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-grounding_refcoco_large_en') + ofa_pipe = pipeline(Tasks.visual_grounding, model=model) + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + image_name = image.split('/')[-2] + self.save_img( + image, + result[OutputKeys.BOXES][0], # just one box + osp.join('large_en_model_' + image_name + '.png')) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_grounding, + model='damo/ofa_visual-grounding_refcoco_large_en') + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + image_name = image.split('/')[-2] + self.save_img(image, result[OutputKeys.BOXES][0], + osp.join('large_en_name_' + image_name + '.png')) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_visual_grounding_zh_with_name(self): + model = 'damo/ofa_visual-grounding_refcoco_large_zh' + ofa_pipe = pipeline(Tasks.visual_grounding, model=model) + image = 'data/test/images/visual_grounding.png' + text = '一个圆头的蓝色宝可梦' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + image_name = image.split('/')[-1] + self.save_img(image, result[OutputKeys.BOXES][0], + osp.join('large_zh_name_' + image_name)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_question_answering_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-question-answering_pretrain_large_en') + ofa_pipe = pipeline(Tasks.visual_question_answering, model=model) + image = 'data/test/images/visual_question_answering.png' + text = 'what is grown on the plant?' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_visual_question_answering_with_name(self): + model = 'damo/ofa_visual-question-answering_pretrain_large_en' + ofa_pipe = pipeline(Tasks.visual_question_answering, model=model) + image = 'data/test/images/visual_question_answering.png' + text = 'what is grown on the plant?' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_distilled_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_image-caption_coco_distilled_en') + img_captioning = pipeline( + task=Tasks.image_captioning, + model=model, + ) + image_path = 'data/test/images/image_captioning.png' + image = Image.open(image_path) + result = img_captioning(image) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_visual_entailment_distilled_model_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_entailment, + model='damo/ofa_visual-entailment_snli-ve_distilled_v2_en') + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_distilled_model_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-grounding_refcoco_distilled_en') + ofa_pipe = pipeline(Tasks.visual_grounding, model=model) + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_text_to_image_synthesis_with_name(self): + model = 'damo/ofa_text-to-image-synthesis_coco_large_en' + ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 + example = {'text': 'a bear in the water.'} + result = ofa_pipe(example) + result[OutputKeys.OUTPUT_IMG].save('result.png') + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_text_to_image_synthesis_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_text-to-image-synthesis_coco_large_en') + ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model) + ofa_pipe.model.generator.beam_size = 2 + example = {'text': 'a bear in the water.'} + result = ofa_pipe(example) + result[OutputKeys.OUTPUT_IMG].save('result.png') + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_part_of_speech.py b/tests/pipelines/test_part_of_speech.py new file mode 100644 index 00000000..038a90f0 --- /dev/null +++ b/tests/pipelines/test_part_of_speech.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import shutil +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import TokenClassificationModel +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TokenClassificationPipeline +from modelscope.preprocessors import TokenClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class PartOfSpeechTest(unittest.TestCase): + model_id = 'damo/nlp_structbert_part-of-speech_chinese-lite' + sentence = '今天天气不错,适合出去游玩' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = TokenClassificationPreprocessor(cache_path) + model = TokenClassificationModel.from_pretrained(cache_path) + pipeline1 = TokenClassificationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.part_of_speech, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.part_of_speech, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline(task=Tasks.part_of_speech, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.part_of_speech) + print(pipeline_ins(input=self.sentence)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py new file mode 100644 index 00000000..b8549f4f --- /dev/null +++ b/tests/pipelines/test_person_image_cartoon.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageCartoonTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' + self.model_id_3d = 'damo/cv_unet_person-image-cartoon-3d_compound-models' + self.model_id_handdrawn = 'damo/cv_unet_person-image-cartoon-handdrawn_compound-models' + self.model_id_sketch = 'damo/cv_unet_person-image-cartoon-sketch_compound-models' + self.model_id_artstyle = 'damo/cv_unet_person-image-cartoon-artstyle_compound-models' + self.task = Tasks.image_portrait_stylization + self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_3d(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_3d) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_handdrawn(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_handdrawn) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_sketch(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_sketch) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_artstyle(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_artstyle) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + img_cartoon = pipeline(Tasks.image_portrait_stylization) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_plug_text_generation.py b/tests/pipelines/test_plug_text_generation.py new file mode 100644 index 00000000..90b48efa --- /dev/null +++ b/tests/pipelines/test_plug_text_generation.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks + + +class TextPlugGenerationTest(unittest.TestCase): + + def setUp(self) -> None: + # please make sure this local path exists. + self.model_id = 'damo/nlp_plug_text-generation_27B' + self.model_dir = snapshot_download(self.model_id) + self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"' + + @unittest.skip('distributed plug, skipped') + def test_plug(self): + """ The model can be downloaded from the link on + https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. + After downloading, you should have a plug model structure like this: + nlp_plug_text-generation_27B + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + """ + # download model binaries to /model + pipe = pipeline(Tasks.text_generation, model=self.model_id) + print( + f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}' + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_product_retrieval_embedding.py b/tests/pipelines/test_product_retrieval_embedding.py new file mode 100644 index 00000000..2483d53a --- /dev/null +++ b/tests/pipelines/test_product_retrieval_embedding.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ProductRetrievalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.product_retrieval_embedding + self.model_id = 'damo/cv_resnet50_product-bag-embedding-models' + + img_input = 'data/test/images/product_embed_bag.jpg' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + product_embed = pipeline(Tasks.product_retrieval_embedding, + self.model_id) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + product_embed = pipeline( + task=Tasks.product_retrieval_embedding, model=model) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + product_embed = pipeline(task=Tasks.product_retrieval_embedding) + result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] + print('abs sum value is: {}'.format(np.sum(np.abs(result)))) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_product_segmentation.py b/tests/pipelines/test_product_segmentation.py new file mode 100644 index 00000000..8f41c13c --- /dev/null +++ b/tests/pipelines/test_product_segmentation.py @@ -0,0 +1,43 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ProductSegmentationTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_F3Net_product-segmentation' + self.input = { + 'input_path': 'data/test/images/product_segmentation.jpg' + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + cv2.imwrite('test_product_segmentation_mask.jpg', + result[OutputKeys.MASKS]) + logger.info('test done') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + product_segmentation = pipeline( + Tasks.product_segmentation, model=self.model_id) + self.pipeline_inference(product_segmentation, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + product_segmentation = pipeline(Tasks.product_segmentation) + self.pipeline_inference(product_segmentation, self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_realtime_object_detection.py b/tests/pipelines/test_realtime_object_detection.py new file mode 100644 index 00000000..e04f6b5c --- /dev/null +++ b/tests/pipelines/test_realtime_object_detection.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import realtime_object_detection_bbox_vis +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class RealtimeObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_cspnet_image-object-detection_yolox' + self.model_nano_id = 'damo/cv_cspnet_image-object-detection_yolox_nano_coco' + self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' + self.task = Tasks.image_object_detection + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + realtime_object_detection = pipeline( + Tasks.image_object_detection, model=self.model_id) + + image = cv2.imread(self.test_image) + result = realtime_object_detection(image) + if result: + bboxes = result[OutputKeys.BOXES].astype(int) + image = realtime_object_detection_bbox_vis(image, bboxes) + cv2.imwrite('rt_obj_out.jpg', image) + else: + raise ValueError('process error') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_nano(self): + realtime_object_detection = pipeline( + Tasks.image_object_detection, model=self.model_nano_id) + + image = cv2.imread(self.test_image) + result = realtime_object_detection(image) + if result: + bboxes = result[OutputKeys.BOXES].astype(int) + image = realtime_object_detection_bbox_vis(image, bboxes) + cv2.imwrite('rtnano_obj_out.jpg', image) + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_realtime_video_object_detection.py b/tests/pipelines/test_realtime_video_object_detection.py new file mode 100644 index 00000000..d65313a3 --- /dev/null +++ b/tests/pipelines/test_realtime_video_object_detection.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 +import numpy as np + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import show_video_object_detection_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class RealtimeVideoObjectDetectionTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_cspnet_video-object-detection_streamyolo' + self.test_video = 'data/test/videos/test_realtime_vod.mp4' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + realtime_video_object_detection = pipeline( + Tasks.video_object_detection, model=self.model_id) + result = realtime_video_object_detection(self.test_video) + if result: + logger.info('Video output to test_vod_results.avi') + show_video_object_detection_result(self.test_video, + result[OutputKeys.BOXES], + result[OutputKeys.LABELS], + 'test_vod_results.avi') + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_referring_video_object_segmentation.py b/tests/pipelines/test_referring_video_object_segmentation.py new file mode 100644 index 00000000..4d8206b3 --- /dev/null +++ b/tests/pipelines/test_referring_video_object_segmentation.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ReferringVideoObjectSegmentationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.referring_video_object_segmentation + self.model_id = 'damo/cv_swin-t_referring_video-object-segmentation' + + @unittest.skip('skip since the model is set to private for now') + def test_referring_video_object_segmentation(self): + input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' + text_queries = [ + 'guy in black performing tricks on a bike', + 'a black bike used to perform tricks' + ] + start_pt, end_pt = 4, 14 + input_tuple = (input_location, text_queries, start_pt, end_pt) + pp = pipeline( + Tasks.referring_video_object_segmentation, model=self.model_id) + result = pp(input_tuple) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skip('skip since the model is set to private for now') + def test_referring_video_object_segmentation_with_default_task(self): + input_location = 'data/test/videos/referring_video_object_segmentation_test_video.mp4' + text_queries = [ + 'guy in black performing tricks on a bike', + 'a black bike used to perform tricks' + ] + start_pt, end_pt = 4, 14 + input_tuple = (input_location, text_queries, start_pt, end_pt) + pp = pipeline(Tasks.referring_video_object_segmentation) + result = pp(input_tuple) + if result: + print(result) + else: + raise ValueError('process error') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_relation_extraction.py b/tests/pipelines/test_relation_extraction.py new file mode 100644 index 00000000..561eaf21 --- /dev/null +++ b/tests/pipelines/test_relation_extraction.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import InformationExtractionModel +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import InformationExtractionPipeline +from modelscope.preprocessors import RelationExtractionPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.relation_extraction + self.model_id = 'damo/nlp_bert_relation-extraction_chinese-base' + + sentence = '高捷,祖籍江苏,本科毕业于东南大学' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = RelationExtractionPreprocessor(cache_path) + model = InformationExtractionModel.from_pretrained(cache_path) + pipeline1 = InformationExtractionPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.relation_extraction, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = RelationExtractionPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.relation_extraction, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.relation_extraction, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.relation_extraction) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_retina_face_detection.py b/tests/pipelines/test_retina_face_detection.py new file mode 100644 index 00000000..343e1c91 --- /dev/null +++ b/tests/pipelines/test_retina_face_detection.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_result +from modelscope.utils.test_utils import test_level + + +class RetinaFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet50_face-detection_retinaface' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/retina_face_detection.jpg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py new file mode 100644 index 00000000..bcb904e6 --- /dev/null +++ b/tests/pipelines/test_salient_detection.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.semantic_segmentation + self.model_id = 'damo/cv_u2net_salient-detection' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_salient_detection(self): + input_location = 'data/test/images/image_salient_detection.jpg' + model_id = 'damo/cv_u2net_salient-detection' + salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id) + result = salient_detect(input_location) + import cv2 + cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_sentence_embedding.py b/tests/pipelines/test_sentence_embedding.py new file mode 100644 index 00000000..e96724a8 --- /dev/null +++ b/tests/pipelines/test_sentence_embedding.py @@ -0,0 +1,82 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import shutil +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import BertForSentenceEmbedding +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import SentenceEmbeddingPipeline +from modelscope.preprocessors import SentenceEmbeddingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class SentenceEmbeddingTest(unittest.TestCase): + model_id = 'damo/nlp_corom_sentence-embedding_english-base' + inputs = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [ + "On average, students take about 18 to 24 months to complete a master's degree.", + 'On the other hand, some students prefer to go at a slower pace and choose to take ', + 'several years to complete their studies.', + 'It can take anywhere from two semesters' + ] + } + + inputs2 = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [ + "On average, students take about 18 to 24 months to complete a master's degree." + ] + } + + inputs3 = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [] + } + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = SentenceEmbeddingPreprocessor(cache_path) + model = BertForSentenceEmbedding.from_pretrained(cache_path) + pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.sentence_embedding, model=model, preprocessor=tokenizer) + print(f'inputs: {self.inputs}\n' + f'pipeline1:{pipeline1(input=self.inputs)}') + print() + print(f'pipeline2: {pipeline2(input=self.inputs)}') + print() + print(f'inputs: {self.inputs2}\n' + f'pipeline1:{pipeline1(input=self.inputs2)}') + print() + print(f'pipeline2: {pipeline2(input=self.inputs2)}') + print(f'inputs: {self.inputs3}\n' + f'pipeline1:{pipeline1(input=self.inputs3)}') + print() + print(f'pipeline2: {pipeline2(input=self.inputs3)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = SentenceEmbeddingPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.sentence_embedding, model=self.model_id) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.sentence_embedding) + print(pipeline_ins(input=self.inputs)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py new file mode 100644 index 00000000..904caea3 --- /dev/null +++ b/tests/pipelines/test_sentence_similarity.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForSequenceClassification +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextClassificationPipeline +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.sentence_similarity + self.model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + + sentence1 = '今天气温比昨天高么?' + sentence2 = '今天湿度比昨天高么?' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run(self): + cache_path = snapshot_download(self.model_id) + tokenizer = SequenceClassificationPreprocessor(cache_path) + model = SbertForSequenceClassification.from_pretrained(cache_path) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.sentence_similarity, model=model, preprocessor=tokenizer) + print('test1') + print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' + f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}') + print() + print( + f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' + f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = SequenceClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=self.model_id) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + 'sbert_sen_sim', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.sentence_similarity) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py new file mode 100644 index 00000000..5c8d4e93 --- /dev/null +++ b/tests/pipelines/test_sentiment_classification.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp.task_models.sequence_classification import \ + SequenceClassificationModel +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextClassificationPipeline +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SentimentClassificationTaskModelTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_classification + self.model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' + + sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_file_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = SequenceClassificationPreprocessor(cache_path) + model = SequenceClassificationModel.from_pretrained( + self.model_id, num_labels=2) + pipeline1 = TextClassificationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.text_classification, model=model, preprocessor=tokenizer) + print(f'sentence1: {self.sentence1}\n' + f'pipeline1:{pipeline1(input=self.sentence1)}') + print(f'sentence1: {self.sentence1}\n' + f'pipeline1: {pipeline2(input=self.sentence1)}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = SequenceClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text_classification, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence1)) + self.assertTrue( + isinstance(pipeline_ins.model, SequenceClassificationModel)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.text_classification, model=self.model_id) + print(pipeline_ins(input=self.sentence1)) + self.assertTrue( + isinstance(pipeline_ins.model, SequenceClassificationModel)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_classification) + print(pipeline_ins(input=self.sentence1)) + self.assertTrue( + isinstance(pipeline_ins.model, SequenceClassificationModel)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_shop_segmentation.py b/tests/pipelines/test_shop_segmentation.py new file mode 100644 index 00000000..58c56dd7 --- /dev/null +++ b/tests/pipelines/test_shop_segmentation.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ShopSegmentationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_shop_segmentation(self): + input_location = 'data/test/images/shop_segmentation.jpg' + model_id = 'damo/cv_vitb16_segmentation_shop-seg' + shop_seg = pipeline(Tasks.shop_segmentation, model=model_id) + result = shop_seg(input_location) + import cv2 + # result[OutputKeys.MASKS] is segment map result,other keys are not used + cv2.imwrite(input_location + '_shopseg.jpg', result[OutputKeys.MASKS]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_skin_retouching.py b/tests/pipelines/test_skin_retouching.py new file mode 100644 index 00000000..db8d89ed --- /dev/null +++ b/tests/pipelines/test_skin_retouching.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SkinRetouchingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.skin_retouching + self.model_id = 'damo/cv_unet_skin-retouching' + self.test_image = 'data/test/images/skin_retouching.png' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + cv2.imwrite('result_skinretouching.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result_skinretouching.png")}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + skin_retouching = pipeline(Tasks.skin_retouching, model=model_dir) + self.pipeline_inference(skin_retouching, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + skin_retouching = pipeline(Tasks.skin_retouching, model=self.model_id) + self.pipeline_inference(skin_retouching, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + skin_retouching = pipeline(Tasks.skin_retouching) + self.pipeline_inference(skin_retouching, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py new file mode 100644 index 00000000..2916d31a --- /dev/null +++ b/tests/pipelines/test_speech_signal_process.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os.path +import unittest + +from modelscope.metainfo import Pipelines +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + +NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' +FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' +NEAREND_MIC_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/nearend_mic.wav' +FAREND_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/farend_speech.wav' + +NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' +NOISE_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ + 'test/audios/speech_with_noise.wav' + + +class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + pass + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_aec(self): + model_id = 'damo/speech_dfsmn_aec_psm_16k' + input = { + 'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE), + 'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE) + } + aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) + output_path = os.path.abspath('output.wav') + aec(input, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_aec_url(self): + model_id = 'damo/speech_dfsmn_aec_psm_16k' + input = { + 'nearend_mic': NEAREND_MIC_URL, + 'farend_speech': FAREND_SPEECH_URL + } + aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) + output_path = os.path.abspath('output.wav') + aec(input, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_aec_bytes(self): + model_id = 'damo/speech_dfsmn_aec_psm_16k' + input = {} + with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: + input['nearend_mic'] = f.read() + with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: + input['farend_speech'] = f.read() + aec = pipeline( + Tasks.acoustic_echo_cancellation, + model=model_id, + pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) + output_path = os.path.abspath('output.wav') + aec(input, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_aec_tuple_bytes(self): + model_id = 'damo/speech_dfsmn_aec_psm_16k' + with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: + nearend_bytes = f.read() + with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: + farend_bytes = f.read() + inputs = (nearend_bytes, farend_bytes) + aec = pipeline( + Tasks.acoustic_echo_cancellation, + model=model_id, + pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) + output_path = os.path.abspath('output.wav') + aec(inputs, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ans(self): + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), + output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_ans_url(self): + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + ans(NOISE_SPEECH_URL, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ans_bytes(self): + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline( + Tasks.acoustic_noise_suppression, + model=model_id, + pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) + output_path = os.path.abspath('output.wav') + with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f: + data = f.read() + ans(data, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py new file mode 100644 index 00000000..825d8f23 --- /dev/null +++ b/tests/pipelines/test_table_question_answering.py @@ -0,0 +1,199 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest +from threading import Thread +from typing import List + +import json +from transformers import BertTokenizer + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline +from modelscope.preprocessors import TableQuestionAnsweringPreprocessor +from modelscope.preprocessors.nlp.space_T_cn.fields.database import Database +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +def tableqa_tracking_and_print_results_with_history( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + '有哪些风险类型?', + '风险类型有多少种?', + '珠江流域的小(2)型水库的库容总量是多少?', + '那平均值是多少?', + '那水库的名称呢?', + '换成中型的呢?', + '枣庄营业厅的电话', + '那地址呢?', + '枣庄营业厅的电话和地址', + ] + } + for p in pipelines: + historical_queries = None + for question in test_case['utterance']: + output_dict = p({ + 'question': question, + 'history_sql': historical_queries + })[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + historical_queries = output_dict[OutputKeys.HISTORY] + + +def tableqa_tracking_and_print_results_without_history( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + '有哪些风险类型?', + '风险类型有多少种?', + '珠江流域的小(2)型水库的库容总量是多少?', + '枣庄营业厅的电话', + '枣庄营业厅的电话和地址', + ] + } + for p in pipelines: + for question in test_case['utterance']: + output_dict = p({'question': question})[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + + +def tableqa_tracking_and_print_results_with_tableid( + pipelines: List[TableQuestionAnsweringPipeline]): + test_case = { + 'utterance': [ + ['有哪些风险类型?', 'fund'], + ['风险类型有多少种?', 'reservoir'], + ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], + ['那平均值是多少?', 'reservoir'], + ['那水库的名称呢?', 'reservoir'], + ['换成中型的呢?', 'reservoir'], + ['枣庄营业厅的电话', 'business'], + ['那地址呢?', 'business'], + ['枣庄营业厅的电话和地址', 'business'], + ], + } + for p in pipelines: + historical_queries = None + for question, table_id in test_case['utterance']: + output_dict = p({ + 'question': question, + 'table_id': table_id, + 'history_sql': historical_queries + })[OutputKeys.OUTPUT] + print('question', question) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict, ensure_ascii=False)) + print() + historical_queries = output_dict[OutputKeys.HISTORY] + + +class TableQuestionAnswering(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.table_question_answering + self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path) + pipelines = [ + pipeline( + Tasks.table_question_answering, + model=cache_path, + preprocessor=preprocessor) + ] + tableqa_tracking_and_print_results_with_history(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download_with_multithreads(self): + cache_path = snapshot_download(self.model_id) + pl = pipeline(Tasks.table_question_answering, model=cache_path) + + def print_func(pl, i): + result = pl({ + 'question': '上个月收益从低到高排前七的基金的名称和风险等级是什么', + 'table_id': 'fund', + 'history_sql': None + }) + print(i, result[OutputKeys.OUTPUT][OutputKeys.SQL_QUERY], + result[OutputKeys.OUTPUT][OutputKeys.QUERT_RESULT], + json.dumps(result)) + + procs = [] + for i in range(5): + proc = Thread(target=print_func, args=(pl, i)) + procs.append(proc) + proc.start() + for proc in procs: + proc.join() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + self.tokenizer = BertTokenizer( + os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) + db = Database( + tokenizer=self.tokenizer, + table_file_path=[ + os.path.join(model.model_dir, 'databases', fname) + for fname in os.listdir( + os.path.join(model.model_dir, 'databases')) + ], + syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), + is_use_sqlite=False) + preprocessor = TableQuestionAnsweringPreprocessor( + model_dir=model.model_dir, db=db) + pipelines = [ + pipeline( + Tasks.table_question_answering, + model=model, + preprocessor=preprocessor, + db=db) + ] + tableqa_tracking_and_print_results_with_tableid(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub_with_other_classes(self): + model = Model.from_pretrained(self.model_id) + self.tokenizer = BertTokenizer( + os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) + db = Database( + tokenizer=self.tokenizer, + table_file_path=[ + os.path.join(model.model_dir, 'databases', fname) + for fname in os.listdir( + os.path.join(model.model_dir, 'databases')) + ], + syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), + is_use_sqlite=True) + preprocessor = TableQuestionAnsweringPreprocessor( + model_dir=model.model_dir, db=db) + pipelines = [ + pipeline( + Tasks.table_question_answering, + model=model, + preprocessor=preprocessor, + db=db) + ] + tableqa_tracking_and_print_results_without_history(pipelines) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text2text_generation.py b/tests/pipelines/test_text2text_generation.py new file mode 100644 index 00000000..d90263c4 --- /dev/null +++ b/tests/pipelines/test_text2text_generation.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import T5ForConditionalGeneration +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import Text2TextGenerationPipeline +from modelscope.preprocessors import Text2TextGenerationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id_generate = 'damo/t5-cn-base-test' + self.input_generate = '中国的首都位于。' + self.model_id_translate = 'damo/t5-translate-base-test' + self.input_translate = 'My name is Wolfgang and I live in Berlin' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_T5(self): + cache_path = snapshot_download(self.model_id_generate) + model = T5ForConditionalGeneration.from_pretrained(cache_path) + preprocessor = Text2TextGenerationPreprocessor(cache_path) + pipeline1 = Text2TextGenerationPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text2text_generation, model=model, preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(self.input_generate)}\npipeline2: {pipeline2(self.input_generate)}' + ) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pipeline_with_model_instance(self): + model = Model.from_pretrained(self.model_id_translate) + preprocessor = Text2TextGenerationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text2text_generation, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(self.input_translate)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pipeline_with_model_id(self): + pipeline_ins = pipeline( + task=Tasks.text2text_generation, model=self.model_id_translate) + print(pipeline_ins(self.input_translate)) + + @unittest.skip( + 'only for test cases, there is no default official model yet') + def test_run_pipeline_without_model_id(self): + pipeline_ins = pipeline(task=Tasks.text2text_generation) + print(pipeline_ins(self.input_generate)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py new file mode 100644 index 00000000..5b38e116 --- /dev/null +++ b/tests/pipelines/test_text_classification.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextClassificationPipeline +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck): + sentence1 = 'i like this wonderful place' + + def setUp(self) -> None: + self.model_id = 'damo/bert-base-sst2' + self.task = Tasks.text_classification + + def predict(self, pipeline_ins: TextClassificationPipeline): + from easynlp.appzoo import load_dataset + + set = load_dataset('glue', 'sst2') + data = set['test']['sentence'][:3] + + results = pipeline_ins(data[0]) + print(results) + results = pipeline_ins(data[1]) + print(results) + + print(data) + + def printDataset(self, dataset: MsDataset): + for i, r in enumerate(dataset): + if i > 10: + break + print(r) + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skip('nlp model does not support tensor input, skipped') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = SequenceClassificationPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.text_classification, + model=model, + preprocessor=preprocessor) + print(f'sentence1: {self.sentence1}\n' + f'pipeline1:{pipeline_ins(input=self.sentence1)}') + + # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skip('nlp model does not support tensor input, skipped') + def test_run_with_model_name(self): + text_classification = pipeline( + task=Tasks.text_classification, model=self.model_id) + result = text_classification( + MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test', + target='premise')) + self.printDataset(result) + + # @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skip('nlp model does not support tensor input, skipped') + def test_run_with_default_model(self): + text_classification = pipeline(task=Tasks.text_classification) + result = text_classification( + MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test', + target='premise')) + self.printDataset(result) + + # @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skip('nlp model does not support tensor input, skipped') + def test_run_with_modelscope_dataset(self): + text_classification = pipeline(task=Tasks.text_classification) + # loaded from modelscope dataset + dataset = MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test', + target='premise') + result = text_classification(dataset) + self.printDataset(result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_driven_segmentation.py b/tests/pipelines/test_text_driven_segmentation.py new file mode 100644 index 00000000..a67729ff --- /dev/null +++ b/tests/pipelines/test_text_driven_segmentation.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextDrivenSegmentationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_text_driven_segmentation(self): + input_location = 'data/test/images/text_driven_segmentation.jpg' + test_input = { + 'image': input_location, + 'text': 'bear', + } + model_id = 'damo/cv_vitl16_segmentation_text-driven-seg' + shop_seg = pipeline(Tasks.text_driven_segmentation, model=model_id) + result = shop_seg(test_input) + import cv2 + # result[OutputKeys.MASKS] is segment map result,other keys are not used + cv2.imwrite(input_location + '_lseg.jpg', result[OutputKeys.MASKS]) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.test_demo() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_error_correction.py b/tests/pipelines/test_text_error_correction.py new file mode 100644 index 00000000..a714d3d0 --- /dev/null +++ b/tests/pipelines/test_text_error_correction.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import BartForTextErrorCorrection +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextErrorCorrectionPipeline +from modelscope.preprocessors import TextErrorCorrectionPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_error_correction + self.model_id = 'damo/nlp_bart_text-error-correction_chinese' + + input = '随着中国经济突飞猛近,建造工业与日俱增' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_download(self): + cache_path = snapshot_download(self.model_id) + model = BartForTextErrorCorrection(cache_path) + preprocessor = TextErrorCorrectionPreprocessor(cache_path) + pipeline1 = TextErrorCorrectionPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text_error_correction, + model=model, + preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' + ) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = TextErrorCorrectionPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text_error_correction, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(self.input)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.text_error_correction, model=self.model_id) + print(pipeline_ins(self.input)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_error_correction) + print(pipeline_ins(self.input)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py new file mode 100644 index 00000000..ddb77eeb --- /dev/null +++ b/tests/pipelines/test_text_generation.py @@ -0,0 +1,205 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextGenerationPipeline +from modelscope.preprocessors import TextGenerationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large' + self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base' + self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base' + self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' + self.palm_input_zh = """ + 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: + 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 + """ + self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮' + self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'" + self.palm_input_en = """ + The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started + her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , + 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges + of paedophilia against nine children because he has dementia . Today , newly-released documents + revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years . + And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her + pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 . + """ + + self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' + self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' + self.gpt3_poetry_large_model_id = 'damo/nlp_gpt3_poetry-generation_chinese-large' + self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,' + self.gpt3_poetry_input = '天生我材必有用,' + + def run_pipeline_with_model_instance(self, model_id, input): + model = Model.from_pretrained(model_id) + preprocessor = TextGenerationPreprocessor( + model.model_dir, + model.tokenizer, + first_sequence='sentence', + second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.text_generation, model=model, preprocessor=preprocessor) + print(pipeline_ins(input)) + + def run_pipeline_with_model_id(self, model_id, input): + pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id) + print(pipeline_ins(input)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_palm_zh_base_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_base, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_palm_en_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_en, + self.palm_input_en) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_gpt_base_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_base_model_id, + self.gpt3_input) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_gpt_large_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_large_model_id, + self.gpt3_input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_large_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_weather, + self.palm_input_weather) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_base_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_base, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather, + self.palm_input_weather) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_en_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_en, + self.palm_input_en) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_name(self): + self.run_pipeline_with_model_id(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_base_with_model_instance(self): + self.run_pipeline_with_model_instance(self.gpt3_base_model_id, + self.gpt3_input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.gpt3_large_model_id, + self.gpt3_input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_poetry_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.gpt3_poetry_large_model_id, + self.gpt3_poetry_input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_palm(self): + for model_id, input in ((self.palm_model_id_zh_base, + self.palm_input_zh), (self.palm_model_id_en, + self.palm_input_en)): + cache_path = snapshot_download(model_id) + model = PalmForTextGeneration.from_pretrained(cache_path) + preprocessor = TextGenerationPreprocessor( + cache_path, + model.tokenizer, + first_sequence='sentence', + second_sequence=None) + pipeline1 = TextGenerationPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text_generation, model=model, preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}' + ) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_gpt3(self): + cache_path = snapshot_download(self.gpt3_base_model_id) + model = GPT3ForTextGeneration(cache_path) + preprocessor = TextGenerationPreprocessor( + cache_path, + model.tokenizer, + first_sequence='sentence', + second_sequence=None) + pipeline1 = TextGenerationPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text_generation, model=model, preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}' + ) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_generation) + print(pipeline_ins(self.palm_input_zh)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_bloom(self): + pipe = pipeline( + task=Tasks.text_generation, model='langboat/bloom-1b4-zh') + print(pipe('中国的首都是')) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_gpt_neo(self): + pipe = pipeline( + task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base') + print( + pipe( + '我是', + do_sample=True, + top_k=5, + top_p=1, + max_length=20, + repetition_penalty=0.5)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_ranking.py b/tests/pipelines/test_text_ranking.py new file mode 100644 index 00000000..0b43e8b4 --- /dev/null +++ b/tests/pipelines/test_text_ranking.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import shutil +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import BertForTextRanking +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextRankingPipeline +from modelscope.preprocessors import TextRankingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextRankingTest(unittest.TestCase): + models = [ + 'damo/nlp_corom_passage-ranking_english-base', + 'damo/nlp_rom_passage-ranking_chinese-base' + ] + + inputs = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [ + "On average, students take about 18 to 24 months to complete a master's degree.", + 'On the other hand, some students prefer to go at a slower pace and choose to take ' + 'several years to complete their studies.', + 'It can take anywhere from two semesters' + ] + } + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + for model_id in self.models: + cache_path = snapshot_download(model_id) + tokenizer = TextRankingPreprocessor(cache_path) + model = BertForTextRanking.from_pretrained(cache_path) + pipeline1 = TextRankingPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.text_ranking, model=model, preprocessor=tokenizer) + print(f'sentence: {self.inputs}\n' + f'pipeline1:{pipeline1(input=self.inputs)}') + print() + print(f'pipeline2: {pipeline2(input=self.inputs)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + for model_id in self.models: + model = Model.from_pretrained(model_id) + tokenizer = TextRankingPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text_ranking, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + for model_id in self.models: + pipeline_ins = pipeline(task=Tasks.text_ranking, model=model_id) + print(pipeline_ins(input=self.inputs)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_ranking) + print(pipeline_ins(input=self.inputs)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_to_image_synthesis.py b/tests/pipelines/test_text_to_image_synthesis.py new file mode 100644 index 00000000..0da6768a --- /dev/null +++ b/tests/pipelines/test_text_to_image_synthesis.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TextToImageSynthesisTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_to_image_synthesis + self.model_id = 'damo/cv_diffusion_text-to-image-synthesis_tiny' + + test_text = { + 'text': '宇航员', + 'generator_ddim_timesteps': 2, + 'upsampler_256_ddim_timesteps': 2, + 'upsampler_1024_ddim_timesteps': 2, + 'debug': True + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=model) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_name(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=self.model_id) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py new file mode 100644 index 00000000..50807e23 --- /dev/null +++ b/tests/pipelines/test_text_to_speech.py @@ -0,0 +1,78 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +# NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. +# A segmentation fault may be raise by pytorch cpp library +# if 'import tensorflow' in front of 'import torch'. +# Puting a 'import torch' here can bypass this incompatibility. +import torch +from scipy.io.wavfile import write + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +import tensorflow as tf # isort:skip + +logger = get_logger() + + +class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_to_speech + self.zhcn_text = '今天北京天气怎么样' + self.en_text = 'How is the weather in Beijing?' + self.zhcn_voices = [ + 'zhitian_emo', 'zhizhe_emo', 'zhiyan_emo', 'zhibei_emo', 'zhcn' + ] + self.zhcn_models = [ + 'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k', + 'damo/speech_sambert-hifigan_tts_zh-cn_16k' + ] + self.en_voices = ['luca', 'luna', 'andy', 'annie', 'engb', 'enus'] + self.en_models = [ + 'damo/speech_sambert-hifigan_tts_luca_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_luna_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_andy_en-us_16k', + 'damo/speech_sambert-hifigan_tts_annie_en-us_16k', + 'damo/speech_sambert-hifigan_tts_en-gb_16k', + 'damo/speech_sambert-hifigan_tts_en-us_16k' + ] + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_pipeline(self): + for i in range(len(self.zhcn_voices)): + logger.info('test %s' % self.zhcn_voices[i]) + sambert_hifigan_tts = pipeline( + task=self.task, model=self.zhcn_models[i]) + self.assertTrue(sambert_hifigan_tts is not None) + output = sambert_hifigan_tts(input=self.zhcn_text) + self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) + pcm = output[OutputKeys.OUTPUT_PCM] + write('output_%s.wav' % self.zhcn_voices[i], 16000, pcm) + for i in range(len(self.en_voices)): + logger.info('test %s' % self.en_voices[i]) + sambert_hifigan_tts = pipeline( + task=self.task, model=self.en_models[i]) + self.assertTrue(sambert_hifigan_tts is not None) + output = sambert_hifigan_tts(input=self.en_text) + self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM]) + pcm = output[OutputKeys.OUTPUT_PCM] + write('output_%s.wav' % self.en_voices[i], 16000, pcm) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_tinynas_classification.py b/tests/pipelines/test_tinynas_classification.py new file mode 100644 index 00000000..ebc6b722 --- /dev/null +++ b/tests/pipelines/test_tinynas_classification.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TinyNASClassificationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_classification + self.model_id = 'damo/cv_tinynas_classification' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + tinynas_classification = pipeline( + Tasks.image_classification, model='damo/cv_tinynas_classification') + result = tinynas_classification('data/test/images/image_wolf.jpeg') + print(result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py new file mode 100644 index 00000000..c92b5568 --- /dev/null +++ b/tests/pipelines/test_tinynas_detection.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_object_detection + self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_airdet(self): + tinynas_object_detection = pipeline( + Tasks.image_object_detection, model='damo/cv_tinynas_detection') + result = tinynas_object_detection( + 'data/test/images/image_detection.jpg') + print('airdet', result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_damoyolo(self): + tinynas_object_detection = pipeline( + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo') + result = tinynas_object_detection( + 'data/test/images/image_detection.jpg') + print('damoyolo', result) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_object_detection_auto_pipeline(self): + test_image = 'data/test/images/image_detection.jpg' + tinynas_object_detection = pipeline( + Tasks.image_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo') + result = tinynas_object_detection(test_image) + tinynas_object_detection.show_result(test_image, result, + 'demo_ret.jpg') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_translation_quality_estimation.py b/tests/pipelines/test_translation_quality_estimation.py new file mode 100644 index 00000000..315fa72b --- /dev/null +++ b/tests/pipelines/test_translation_quality_estimation.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TranslationQualityEstimationTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.sentence_similarity + self.model_id = 'damo/nlp_translation_quality_estimation_multilingual' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2zh(self): + inputs = { + 'source_text': 'Love is a losing game', + 'target_text': '宝贝,人和人一场游戏' + } + pipeline_ins = pipeline(self.task, model=self.model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_ulfd_face_detection.py b/tests/pipelines/test_ulfd_face_detection.py new file mode 100644 index 00000000..0ffa688c --- /dev/null +++ b/tests/pipelines/test_ulfd_face_detection.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 +import numpy as np + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import draw_face_detection_no_lm_result +from modelscope.utils.test_utils import test_level + + +class UlfdFaceDetectionTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_manual_face-detection_ulfd' + + def show_result(self, img_path, detection_result): + img = draw_face_detection_no_lm_result(img_path, detection_result) + cv2.imwrite('result.png', img) + print(f'output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + face_detection = pipeline(Tasks.face_detection, model=self.model_id) + img_path = 'data/test/images/ulfd_face_detection.jpg' + + result = face_detection(img_path) + self.show_result(img_path, result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_unifold.py b/tests/pipelines/test_unifold.py new file mode 100644 index 00000000..47bb7874 --- /dev/null +++ b/tests/pipelines/test_unifold.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.protein_structure + self.model_id = 'DPTech/uni-fold-monomer' + self.model_id_multimer = 'DPTech/uni-fold-multimer' + + self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' + self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ + 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + mono_pipeline_ins = pipeline(task=self.task, model=model_dir) + _ = mono_pipeline_ins(self.protein) + + model_dir1 = snapshot_download(self.model_id_multimer) + multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) + _ = multi_pipeline_ins(self.protein_multimer) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_category.py b/tests/pipelines/test_video_category.py new file mode 100644 index 00000000..660196b8 --- /dev/null +++ b/tests/pipelines/test_video_category.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VideoCategoryTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_category + self.model_id = 'damo/cv_resnet50_video-category' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + category_pipeline = pipeline(Tasks.video_category, self.model_id) + result = category_pipeline( + 'data/test/videos/video_category_test_video.mp4') + + print(f'video category output: {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_inpainting.py b/tests/pipelines/test_video_inpainting.py new file mode 100644 index 00000000..8364b1b3 --- /dev/null +++ b/tests/pipelines/test_video_inpainting.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class VideoInpaintingTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_video-inpainting' + self.mask_dir = 'data/test/videos/mask_dir' + self.video_in = 'data/test/videos/video_inpainting_test.mp4' + self.video_out = 'out.mp4' + self.input = { + 'video_input_path': self.video_in, + 'video_output_path': self.video_out, + 'mask_path': self.mask_dir + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + video_inpainting = pipeline(Tasks.video_inpainting, model=self.model) + self.pipeline_inference(video_inpainting, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_inpainting = pipeline(Tasks.video_inpainting) + self.pipeline_inference(video_inpainting, self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_multi_modal_embedding.py b/tests/pipelines/test_video_multi_modal_embedding.py new file mode 100644 index 00000000..afe5940d --- /dev/null +++ b/tests/pipelines/test_video_multi_modal_embedding.py @@ -0,0 +1,50 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class VideoMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_multi_modal_embedding + self.model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53' + + video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4' + caption = 'a person is connecting something to system' + _input = {'video': video_path, 'text': caption} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + pipeline_video_multi_modal_embedding = pipeline( + Tasks.video_multi_modal_embedding, model=self.model_id) + output = pipeline_video_multi_modal_embedding(self._input) + logger.info('text feature: {}'.format( + output['text_embedding'][0][0][0])) + logger.info('video feature: {}'.format( + output['video_embedding'][0][0][0])) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_video_multi_modal_embedding = pipeline( + task=Tasks.video_multi_modal_embedding) + output = pipeline_video_multi_modal_embedding(self._input) + logger.info('text feature: {}'.format( + output['text_embedding'][0][0][0])) + logger.info('video feature: {}'.format( + output['video_embedding'][0][0][0])) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_single_object_tracking.py b/tests/pipelines/test_video_single_object_tracking.py new file mode 100644 index 00000000..7f3a9226 --- /dev/null +++ b/tests/pipelines/test_video_single_object_tracking.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import show_video_tracking_result +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SingleObjectTracking(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_single_object_tracking + self.model_id = 'damo/cv_vitb_video-single-object-tracking_ostrack' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_end2end(self): + video_single_object_tracking = pipeline( + Tasks.video_single_object_tracking, model=self.model_id) + video_path = 'data/test/videos/dog.avi' + init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2] + result = video_single_object_tracking((video_path, init_bbox)) + print('result is : ', result[OutputKeys.BOXES]) + show_video_tracking_result(video_path, result[OutputKeys.BOXES], + './tracking_result.avi') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_single_object_tracking = pipeline( + Tasks.video_single_object_tracking) + video_path = 'data/test/videos/dog.avi' + init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2] + result = video_single_object_tracking((video_path, init_bbox)) + print('result is : ', result[OutputKeys.BOXES]) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_summarization.py b/tests/pipelines/test_video_summarization.py new file mode 100644 index 00000000..1f965c53 --- /dev/null +++ b/tests/pipelines/test_video_summarization.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VideoSummarizationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_summarization + self.model_id = 'damo/cv_googlenet_pgl-video-summarization' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + video_path = 'data/test/videos/video_category_test_video.mp4' + summarization_pipeline = pipeline( + Tasks.video_summarization, model=self.model_id) + result = summarization_pipeline(video_path) + + print(f'video summarization output: \n{result}.') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_path = 'data/test/videos/video_category_test_video.mp4' + summarization_pipeline = pipeline(Tasks.video_summarization) + result = summarization_pipeline(video_path) + + print(f'video summarization output:\n {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_virtual_try_on.py b/tests/pipelines/test_virtual_try_on.py new file mode 100644 index 00000000..5c18dcc4 --- /dev/null +++ b/tests/pipelines/test_virtual_try_on.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import cv2 +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VirtualTryonTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.virtual_try_on + self.model_id = 'damo/cv_daflow_virtual-try-on_base' + + masked_model = Image.open('data/test/images/virtual_tryon_model.jpg') + pose = Image.open('data/test/images/virtual_tryon_pose.jpg') + cloth = Image.open('data/test/images/virtual_tryon_cloth.jpg') + input_imgs = (masked_model, pose, cloth) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_virtual_try_on = pipeline( + task=Tasks.virtual_try_on, model=self.model_id) + img = pipeline_virtual_try_on(self.input_imgs)[OutputKeys.OUTPUT_IMG] + cv2.imwrite('demo.jpg', img[:, :, ::-1]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_name_default_model(self): + pipeline_virtual_tryon = pipeline(task=Tasks.virtual_try_on) + img = pipeline_virtual_tryon(self.input_imgs)[OutputKeys.OUTPUT_IMG] + cv2.imwrite('demo.jpg', img[:, :, ::-1]) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_word_segmentation.py b/tests/pipelines/test_word_segmentation.py new file mode 100644 index 00000000..6969c0e6 --- /dev/null +++ b/tests/pipelines/test_word_segmentation.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForTokenClassification +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import WordSegmentationPipeline +from modelscope.preprocessors import TokenClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.word_segmentation + self.model_id = 'damo/nlp_structbert_word-segmentation_chinese-base' + + sentence = '今天天气不错,适合出去游玩' + sentence_eng = 'I am a program.' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = TokenClassificationPreprocessor(cache_path) + model = SbertForTokenClassification.from_pretrained(cache_path) + pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=self.model_id) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + 'sbert_ws_zh', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print(pipeline_ins(input=self.sentence)) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + 'sbert_ws_en', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print(pipeline_ins(input=self.sentence_eng)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.word_segmentation) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py new file mode 100644 index 00000000..00789707 --- /dev/null +++ b/tests/pipelines/test_zero_shot_classification.py @@ -0,0 +1,86 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForSequenceClassification +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import ZeroShotClassificationPipeline +from modelscope.preprocessors import ZeroShotClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.zero_shot_classification + self.model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base' + + sentence = '全新突破 解放军运20版空中加油机曝光' + labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] + labels_str = '文化, 体育, 娱乐, 财经, 家居, 汽车, 教育, 科技, 军事' + template = '这篇文章的标题是{}' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_file_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = ZeroShotClassificationPreprocessor(cache_path) + model = SbertForSequenceClassification.from_pretrained(cache_path) + pipeline1 = ZeroShotClassificationPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.zero_shot_classification, + model=model, + preprocessor=tokenizer) + + print( + f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' + ) + print( + f'sentence: {self.sentence}\n' + f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels_str,hypothesis_template=self.template)}' + ) + print( + f'sentence: {self.sentence}\n' + f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' + ) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = ZeroShotClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.zero_shot_classification, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.zero_shot_classification, model=self.model_id) + with self.regress_tool.monitor_module_single_forward( + pipeline_ins.model, + 'sbert_zero_shot', + compare_fn=IgnoreKeyFn('.*intermediate_act_fn')): + print( + pipeline_ins( + input=self.sentence, candidate_labels=self.labels)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.zero_shot_classification) + print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/__init__.py b/tests/preprocessors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/preprocessors/test_common.py b/tests/preprocessors/test_common.py new file mode 100644 index 00000000..714b8588 --- /dev/null +++ b/tests/preprocessors/test_common.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import torch + +from modelscope.preprocessors import (PREPROCESSORS, Compose, Filter, + Preprocessor, ToTensor) + + +class ComposeTest(unittest.TestCase): + + def test_compose(self): + + @PREPROCESSORS.register_module() + class Tmp1(Preprocessor): + + def __call__(self, input): + input['tmp1'] = 'tmp1' + return input + + @PREPROCESSORS.register_module() + class Tmp2(Preprocessor): + + def __call__(self, input): + input['tmp2'] = 'tmp2' + return input + + pipeline = [ + dict(type='Tmp1'), + dict(type='Tmp2'), + ] + trans = Compose(pipeline) + + input = {} + output = trans(input) + self.assertEqual(output['tmp1'], 'tmp1') + self.assertEqual(output['tmp2'], 'tmp2') + + +class ToTensorTest(unittest.TestCase): + + def test_totensor(self): + to_tensor_op = ToTensor(keys=['img']) + inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'} + inputs = to_tensor_op(inputs) + self.assertIsInstance(inputs['img'], torch.Tensor) + self.assertEqual(inputs['label'], 1) + self.assertEqual(inputs['path'], 'test.jpg') + + +class FilterTest(unittest.TestCase): + + def test_filter(self): + filter_op = Filter(reserved_keys=['img', 'label']) + inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'} + inputs = filter_op(inputs) + self.assertIn('img', inputs) + self.assertIn('label', inputs) + self.assertNotIn('path', inputs) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/test_image.py b/tests/preprocessors/test_image.py new file mode 100644 index 00000000..a912b4b1 --- /dev/null +++ b/tests/preprocessors/test_image.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from PIL import Image + +from modelscope.preprocessors import load_image + + +class ImagePreprocessorTest(unittest.TestCase): + + def test_load(self): + img = load_image('data/test/images/image_matting.png') + self.assertTrue(isinstance(img, Image.Image)) + self.assertEqual(img.size, (948, 533)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/test_nlp.py b/tests/preprocessors/test_nlp.py new file mode 100644 index 00000000..f9f4d93f --- /dev/null +++ b/tests/preprocessors/test_nlp.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.preprocessors import build_preprocessor, nlp +from modelscope.utils.constant import Fields, InputFields +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class NLPPreprocessorTest(unittest.TestCase): + + def test_tokenize(self): + cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased') + preprocessor = build_preprocessor(cfg, Fields.nlp) + input = { + InputFields.text: + 'Do not meddle in the affairs of wizards, ' + 'for they are subtle and quick to anger.' + } + output = preprocessor(input) + self.assertTrue(InputFields.text in output) + self.assertEqual(output['input_ids'], [ + 101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116, + 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102 + ]) + self.assertEqual( + output['token_type_ids'], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + self.assertEqual( + output['attention_mask'], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + + def test_token_classification_tokenize(self): + with self.subTest(tokenizer_type='bert'): + cfg = dict( + type='token-cls-tokenizer', + model_dir='bert-base-cased', + label2id={ + 'O': 0, + 'B': 1, + 'I': 2 + }) + preprocessor = build_preprocessor(cfg, Fields.nlp) + input = 'Do not meddle in the affairs of wizards, ' \ + 'for they are subtle and quick to anger.' + output = preprocessor(input) + self.assertTrue(InputFields.text in output) + self.assertEqual(output['input_ids'].tolist()[0], [ + 101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, + 1116, 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, + 119, 102 + ]) + self.assertEqual(output['attention_mask'].tolist()[0], [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1 + ]) + self.assertEqual(output['label_mask'].tolist()[0], [ + False, True, True, True, False, True, True, True, True, True, + False, True, True, True, True, True, True, True, True, True, + True, False + ]) + self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6), + (7, 13), (14, 16), + (17, 20), (21, 28), + (29, 31), (32, 39), + (39, 40), (41, 44), + (45, 49), (50, 53), + (54, 60), (61, 64), + (65, 70), (71, 73), + (74, 79), (79, 80)]) + + with self.subTest(tokenizer_type='roberta'): + cfg = dict( + type='token-cls-tokenizer', + model_dir='xlm-roberta-base', + label2id={ + 'O': 0, + 'B': 1, + 'I': 2 + }) + preprocessor = build_preprocessor(cfg, Fields.nlp) + input = 'Do not meddle in the affairs of wizards, ' \ + 'for they are subtle and quick to anger.' + output = preprocessor(input) + self.assertTrue(InputFields.text in output) + self.assertEqual(output['input_ids'].tolist()[0], [ + 0, 984, 959, 128, 19298, 23, 70, 103086, 7, 111, 6, 44239, + 99397, 4, 100, 1836, 621, 1614, 17991, 136, 63773, 47, 348, 56, + 5, 2 + ]) + self.assertEqual(output['attention_mask'].tolist()[0], [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1 + ]) + self.assertEqual(output['label_mask'].tolist()[0], [ + False, True, True, True, False, True, True, True, False, True, + True, False, False, False, True, True, True, True, False, True, + True, True, True, False, False, False + ]) + self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6), + (7, 13), (14, 16), + (17, 20), (21, 28), + (29, 31), (32, 40), + (41, 44), (45, 49), + (50, 53), (54, 60), + (61, 64), (65, 70), + (71, 73), (74, 80)]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run.py b/tests/run.py new file mode 100644 index 00000000..b286ecb5 --- /dev/null +++ b/tests/run.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse +import datetime +import multiprocessing +import os +import subprocess +import sys +import tempfile +import unittest +from fnmatch import fnmatch +from multiprocessing.managers import BaseManager +from pathlib import Path +from turtle import shape +from unittest import TestResult, TextTestResult + +import pandas +# NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. +# A segmentation fault may be raise by pytorch cpp library +# if 'import tensorflow' in front of 'import torch'. +# Puting a 'import torch' here can bypass this incompatibility. +import torch +import yaml + +from modelscope.utils.logger import get_logger +from modelscope.utils.model_tag import ModelTag, commit_model_ut_result +from modelscope.utils.test_utils import (get_case_model_info, set_test_level, + test_level) + +logger = get_logger() + + +def test_cases_result_to_df(result_list): + table_header = [ + 'Name', 'Result', 'Info', 'Start time', 'Stop time', + 'Time cost(seconds)' + ] + df = pandas.DataFrame( + result_list, columns=table_header).sort_values( + by=['Start time'], ascending=True) + return df + + +def statistics_test_result(df): + total_cases = df.shape[0] + # yapf: disable + success_cases = df.loc[df['Result'] == 'Success'].shape[0] + error_cases = df.loc[df['Result'] == 'Error'].shape[0] + failures_cases = df.loc[df['Result'] == 'Failures'].shape[0] + expected_failure_cases = df.loc[df['Result'] == 'ExpectedFailures'].shape[0] + unexpected_success_cases = df.loc[df['Result'] == 'UnexpectedSuccesses'].shape[0] + skipped_cases = df.loc[df['Result'] == 'Skipped'].shape[0] + # yapf: enable + + if failures_cases > 0 or \ + error_cases > 0 or \ + unexpected_success_cases > 0: + final_result = 'FAILED' + else: + final_result = 'SUCCESS' + result_msg = '%s (Runs=%s,success=%s,failures=%s,errors=%s,\ + skipped=%s,expected failures=%s,unexpected successes=%s)' % ( + final_result, total_cases, success_cases, failures_cases, error_cases, + skipped_cases, expected_failure_cases, unexpected_success_cases) + + model_cases = get_case_model_info() + for model_name, case_info in model_cases.items(): + cases = df.loc[df['Name'].str.contains('|'.join(list(case_info)))] + results = cases['Result'] + result = None + if any(results == 'Error') or any(results == 'Failures') or any( + results == 'UnexpectedSuccesses'): + result = ModelTag.MODEL_FAIL + elif any(results == 'Success'): + result = ModelTag.MODEL_PASS + elif all(results == 'Skipped'): + result = ModelTag.MODEL_SKIP + else: + print(f'invalid results for {model_name} \n{result}') + + if result is not None: + commit_model_ut_result(model_name, result) + print('Testing result summary.') + print(result_msg) + if final_result == 'FAILED': + sys.exit(1) + + +def gather_test_suites_in_files(test_dir, case_file_list, list_tests): + test_suite = unittest.TestSuite() + for case in case_file_list: + test_case = unittest.defaultTestLoader.discover( + start_dir=test_dir, pattern=case) + test_suite.addTest(test_case) + if hasattr(test_case, '__iter__'): + for subcase in test_case: + if list_tests: + print(subcase) + else: + if list_tests: + print(test_case) + return test_suite + + +def gather_test_suites_files(test_dir, pattern): + case_file_list = [] + for dirpath, dirnames, filenames in os.walk(test_dir): + for file in filenames: + if fnmatch(file, pattern): + case_file_list.append(file) + + return case_file_list + + +def collect_test_results(case_results): + result_list = [ + ] # each item is Case, Result, Start time, Stop time, Time cost + for case_result in case_results.successes: + result_list.append( + (case_result.test_full_name, 'Success', '', case_result.start_time, + case_result.stop_time, case_result.time_cost)) + for case_result in case_results.errors: + result_list.append( + (case_result[0].test_full_name, 'Error', case_result[1], + case_result[0].start_time, case_result[0].stop_time, + case_result[0].time_cost)) + for case_result in case_results.skipped: + result_list.append( + (case_result[0].test_full_name, 'Skipped', case_result[1], + case_result[0].start_time, case_result[0].stop_time, + case_result[0].time_cost)) + for case_result in case_results.expectedFailures: + result_list.append( + (case_result[0].test_full_name, 'ExpectedFailures', case_result[1], + case_result[0].start_time, case_result[0].stop_time, + case_result[0].time_cost)) + for case_result in case_results.failures: + result_list.append( + (case_result[0].test_full_name, 'Failures', case_result[1], + case_result[0].start_time, case_result[0].stop_time, + case_result[0].time_cost)) + for case_result in case_results.unexpectedSuccesses: + result_list.append((case_result.test_full_name, 'UnexpectedSuccesses', + '', case_result.start_time, case_result.stop_time, + case_result.time_cost)) + return result_list + + +def run_command_with_popen(cmd): + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=1, + encoding='utf8') as sub_process: + for line in iter(sub_process.stdout.readline, ''): + sys.stdout.write(line) + + +def save_test_result(df, args): + if args.result_dir is not None: + file_name = str(int(datetime.datetime.now().timestamp() * 1000)) + os.umask(0) + Path(args.result_dir).mkdir(mode=0o777, parents=True, exist_ok=True) + Path(os.path.join(args.result_dir, file_name)).touch( + mode=0o666, exist_ok=True) + df.to_pickle(os.path.join(args.result_dir, file_name)) + + +def run_command(cmd): + logger.info('Running command: %s' % ' '.join(cmd)) + response = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + try: + response.check_returncode() + logger.info(response.stdout.decode('utf8')) + except subprocess.CalledProcessError as error: + logger.error( + 'stdout: %s, stderr: %s' % + (response.stdout.decode('utf8'), error.stderr.decode('utf8'))) + + +def install_packages(pkgs): + cmd = [sys.executable, '-m', 'pip', 'install'] + for pkg in pkgs: + cmd.append(pkg) + + run_command(cmd) + + +def install_requirements(requirements): + for req in requirements: + cmd = [ + sys.executable, '-m', 'pip', 'install', '-r', + 'requirements/%s' % req, '-f', + 'https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html' + ] + run_command(cmd) + + +def run_case_in_env(env_name, env, test_suite_env_map, isolated_cases, + result_dir): + # install requirements and deps # run_config['envs'][env] + if 'requirements' in env: + install_requirements(env['requirements']) + if 'dependencies' in env: + install_packages(env['dependencies']) + + for test_suite_file in isolated_cases: # run case in subprocess + if test_suite_file in test_suite_env_map and test_suite_env_map[ + test_suite_file] == env_name: + cmd = [ + 'python', + 'tests/run.py', + '--pattern', + test_suite_file, + '--result_dir', + result_dir, + ] + run_command_with_popen(cmd) + else: + pass # case not in run list. + + # run remain cases in a process. + remain_suite_files = [] + for k, v in test_suite_env_map.items(): + if k not in isolated_cases and v == env_name: + remain_suite_files.append(k) + if len(remain_suite_files) == 0: + return + cmd = ['python', 'tests/run.py', '--result_dir', result_dir, '--suites'] + for suite in remain_suite_files: + cmd.append(suite) + run_command_with_popen(cmd) + + +def run_in_subprocess(args): + # only case args.isolated_cases run in subporcess, all other run in a subprocess + test_suite_files = gather_test_suites_files( + os.path.abspath(args.test_dir), args.pattern) + run_config = None + isolated_cases = [] + test_suite_env_map = {} + # put all the case in default env. + for test_suite_file in test_suite_files: + test_suite_env_map[test_suite_file] = 'default' + + if args.run_config is not None and Path(args.run_config).exists(): + with open(args.run_config) as f: + run_config = yaml.load(f, Loader=yaml.FullLoader) + if 'isolated' in run_config: + isolated_cases = run_config['isolated'] + + if 'envs' in run_config: + for env in run_config['envs']: + if env != 'default': + for test_suite in run_config['envs'][env]['tests']: + if test_suite in test_suite_env_map: + test_suite_env_map[test_suite] = env + + if args.subprocess: # run all case in subprocess + isolated_cases = test_suite_files + + with tempfile.TemporaryDirectory() as temp_result_dir: + for env in set(test_suite_env_map.values()): + run_case_in_env(env, run_config['envs'][env], test_suite_env_map, + isolated_cases, temp_result_dir) + + result_dfs = [] + result_path = Path(temp_result_dir) + for result in result_path.iterdir(): + if Path.is_file(result): + df = pandas.read_pickle(result) + result_dfs.append(df) + result_pd = pandas.concat( + result_dfs) # merge result of every test suite. + print_table_result(result_pd) + print_abnormal_case_info(result_pd) + statistics_test_result(result_pd) + + +def get_object_full_name(obj): + klass = obj.__class__ + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ + + +class TimeCostTextTestResult(TextTestResult): + """Record test case time used!""" + + def __init__(self, stream, descriptions, verbosity): + self.successes = [] + return super(TimeCostTextTestResult, + self).__init__(stream, descriptions, verbosity) + + def startTest(self, test): + test.start_time = datetime.datetime.now() + test.test_full_name = get_object_full_name( + test) + '.' + test._testMethodName + self.stream.writeln('Test case: %s start at: %s' % + (test.test_full_name, test.start_time)) + + return super(TimeCostTextTestResult, self).startTest(test) + + def stopTest(self, test): + TextTestResult.stopTest(self, test) + test.stop_time = datetime.datetime.now() + test.time_cost = (test.stop_time - test.start_time).total_seconds() + self.stream.writeln( + 'Test case: %s stop at: %s, cost time: %s(seconds)' % + (test.test_full_name, test.stop_time, test.time_cost)) + super(TimeCostTextTestResult, self).stopTest(test) + + def addSuccess(self, test): + self.successes.append(test) + super(TextTestResult, self).addSuccess(test) + + +class TimeCostTextTestRunner(unittest.runner.TextTestRunner): + resultclass = TimeCostTextTestResult + + def run(self, test): + return super(TimeCostTextTestRunner, self).run(test) + + def _makeResult(self): + result = super(TimeCostTextTestRunner, self)._makeResult() + return result + + +def gather_test_cases(test_dir, pattern, list_tests): + case_list = [] + for dirpath, dirnames, filenames in os.walk(test_dir): + for file in filenames: + if fnmatch(file, pattern): + case_list.append(file) + + test_suite = unittest.TestSuite() + + for case in case_list: + test_case = unittest.defaultTestLoader.discover( + start_dir=test_dir, pattern=case) + test_suite.addTest(test_case) + if hasattr(test_case, '__iter__'): + for subcase in test_case: + if list_tests: + print(subcase) + else: + if list_tests: + print(test_case) + return test_suite + + +def print_abnormal_case_info(df): + df = df.loc[(df['Result'] == 'Error') | (df['Result'] == 'Failures')] + for _, row in df.iterrows(): + print('Case %s run result: %s, msg:\n%s' % + (row['Name'], row['Result'], row['Info'])) + + +def print_table_result(df): + df = df.loc[df['Result'] != 'Skipped'] + df = df.drop('Info', axis=1) + formatters = { + 'Name': '{{:<{}s}}'.format(df['Name'].str.len().max()).format, + 'Result': '{{:<{}s}}'.format(df['Result'].str.len().max()).format, + } + with pandas.option_context('display.max_rows', None, 'display.max_columns', + None, 'display.width', None): + print(df.to_string(justify='left', formatters=formatters, index=False)) + + +def main(args): + runner = TimeCostTextTestRunner() + if args.suites is not None and len(args.suites) > 0: + logger.info('Running: %s' % ' '.join(args.suites)) + test_suite = gather_test_suites_in_files(args.test_dir, args.suites, + args.list_tests) + else: + test_suite = gather_test_cases( + os.path.abspath(args.test_dir), args.pattern, args.list_tests) + if not args.list_tests: + result = runner.run(test_suite) + result = collect_test_results(result) + df = test_cases_result_to_df(result) + if args.result_dir is not None: + save_test_result(df, args) + else: + print_table_result(df) + print_abnormal_case_info(df) + statistics_test_result(df) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('test runner') + parser.add_argument( + '--list_tests', action='store_true', help='list all tests') + parser.add_argument( + '--pattern', default='test_*.py', help='test file pattern') + parser.add_argument( + '--test_dir', default='tests', help='directory to be tested') + parser.add_argument( + '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0') + parser.add_argument( + '--disable_profile', action='store_true', help='disable profiling') + parser.add_argument( + '--run_config', + default=None, + help='specified case run config file(yaml file)') + parser.add_argument( + '--subprocess', + action='store_true', + help='run all test suite in subprocess') + parser.add_argument( + '--result_dir', + default=None, + help='Save result to directory, internal use only') + parser.add_argument( + '--suites', + nargs='*', + help='Run specified test suites(test suite files list split by space)') + args = parser.parse_args() + set_test_level(args.level) + os.environ['REGRESSION_BASELINE'] = '1' + logger.info(f'TEST LEVEL: {test_level()}') + if not args.disable_profile: + from utils import profiler + logger.info('enable profile ...') + profiler.enable() + if args.run_config is not None or args.subprocess: + run_in_subprocess(args) + else: + main(args) diff --git a/tests/run_config.yaml b/tests/run_config.yaml new file mode 100644 index 00000000..d51e2606 --- /dev/null +++ b/tests/run_config.yaml @@ -0,0 +1,34 @@ +# isolate cases in env, we can install different dependencies in each env. +isolated: # test cases that may require excessive anmount of GPU memory, which will be executed in dedicagted process. + - test_text_to_speech.py + - test_multi_modal_embedding.py + - test_ofa_tasks.py + - test_video_summarization.py + - test_dialog_modeling.py + - test_csanmt_translation.py + - test_image_super_resolution.py + - test_easycv_trainer.py + - test_segformer.py + - test_segmentation_pipeline.py + - test_movie_scene_segmentation.py + - test_image_inpainting.py + +envs: + default: # default env, case not in other env will in default, pytorch. + dependencies: # requirement packages,pip install before test case run. + - numpy>=1.20 + tensorflow1x: # cases excuted tensorflow1.x framework. + requirements: # requirements files run before test case run. + - tensorflow1x.txt + dependencies: # requirement packages,pip install before test case run. + - numpy==1.18.5 + tests: + - test_text_to_speech.py + - test_csanmt_translation.py + - test_translation_trainer.py + - test_ocr_detection.py + - test_automatic_speech_recognition.py + - test_image_matting.py + - test_person_image_cartoon.py + - test_skin_retouching.py + - test_image_style_transfer.py diff --git a/tests/taskdataset/test_veco_dataset.py b/tests/taskdataset/test_veco_dataset.py new file mode 100644 index 00000000..76da1681 --- /dev/null +++ b/tests/taskdataset/test_veco_dataset.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.msdatasets.task_datasets.veco_dataset import VecoDataset +from modelscope.utils.test_utils import test_level + + +class TestVecoDataset(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_veco_dataset_train(self): + from datasets import Dataset + d0 = Dataset.from_dict({'a': [0, 1, 2]}) + d1 = Dataset.from_dict({'a': [10, 11, 12, 13, 14]}) + d2 = Dataset.from_dict({'a': [21, 22, 23, 24, 25, 26, 27]}) + dataset = VecoDataset([d0, d1, d2], mode='train') + self.assertEqual(len(dataset), 15) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_veco_dataset_eval(self): + from datasets import Dataset + d0 = Dataset.from_dict({'a': [0, 1, 2]}) + d1 = Dataset.from_dict({'a': [10, 11, 12, 13, 14]}) + d2 = Dataset.from_dict({'a': [21, 22, 23, 24, 25, 26, 27]}) + dataset = VecoDataset([d0, d1, d2], mode='eval') + self.assertEqual(len(dataset), 3) + dataset.switch_dataset(1) + self.assertEqual(len(dataset), 5) + dataset.switch_dataset(2) + self.assertEqual(len(dataset), 7) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/__init__.py b/tests/trainers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/audio/__init__.py b/tests/trainers/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/audio/test_ans_trainer.py b/tests/trainers/audio/test_ans_trainer.py new file mode 100644 index 00000000..d897e6a9 --- /dev/null +++ b/tests/trainers/audio/test_ans_trainer.py @@ -0,0 +1,64 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import unittest +from functools import partial + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.audio.audio_utils import to_segment +from modelscope.utils.hub import read_config +from modelscope.utils.test_utils import test_level + +SEGMENT_LENGTH_TEST = 640 + + +class TestANSTrainer(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/speech_frcrn_ans_cirm_16k' + cfg = read_config(self.model_id) + cfg.train.max_epochs = 2 + cfg.train.dataloader.batch_size_per_gpu = 1 + self.cfg_file = os.path.join(self.tmp_dir, 'train_config.json') + cfg.dump(self.cfg_file) + + hf_ds = MsDataset.load( + 'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset() + mapped_ds = hf_ds.map( + partial(to_segment, segment_length=SEGMENT_LENGTH_TEST), + remove_columns=['duration'], + batched=True, + batch_size=2) + self.dataset = MsDataset.from_hf_dataset(mapped_ds) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + max_epochs=2, + train_iters_per_epoch=2, + val_iters_per_epoch=1, + cfg_file=self.cfg_file, + work_dir=self.tmp_dir) + + trainer = build_trainer( + Trainers.speech_frcrn_ans_cirm_16k, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i + 1}.pth', results_files) diff --git a/tests/trainers/audio/test_kws_farfield_trainer.py b/tests/trainers/audio/test_kws_farfield_trainer.py new file mode 100644 index 00000000..70b68a11 --- /dev/null +++ b/tests/trainers/audio/test_kws_farfield_trainer.py @@ -0,0 +1,83 @@ +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + +POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav' +NEG_FILE = 'data/test/audios/speech_with_noise.wav' +NOISE_FILE = 'data/test/audios/speech_with_noise.wav' +INTERF_FILE = 'data/test/audios/speech_with_noise.wav' +REF_FILE = 'data/test/audios/farend_speech.wav' +NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav' + + +class TestKwsFarfieldTrainer(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory().name + print(f'tmp dir: {self.tmp_dir}') + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + + train_pos_list = self.create_list('pos.list', POS_FILE) + train_neg_list = self.create_list('neg.list', NEG_FILE) + train_noise1_list = self.create_list('noise.list', NOISE_FILE) + train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE) + train_interf_list = self.create_list('interf.list', INTERF_FILE) + train_ref_list = self.create_list('ref.list', REF_FILE) + + base_dict = dict( + train_pos_list=train_pos_list, + train_neg_list=train_neg_list, + train_noise1_list=train_noise1_list) + fintune_dict = dict( + train_pos_list=train_pos_list, + train_neg_list=train_neg_list, + train_noise1_list=train_noise1_list, + train_noise2_type='1', + train_noise1_ratio='0.2', + train_noise2_list=train_noise2_list, + train_interf_list=train_interf_list, + train_ref_list=train_ref_list) + self.custom_conf = dict( + basetrain_easy=base_dict, + basetrain_normal=base_dict, + basetrain_hard=base_dict, + finetune_easy=fintune_dict, + finetune_normal=fintune_dict, + finetune_hard=fintune_dict) + + def create_list(self, list_name, audio_file): + pos_list_file = os.path.join(self.tmp_dir, list_name) + with open(pos_list_file, 'w') as f: + for i in range(10): + f.write(f'{os.path.join(os.getcwd(), audio_file)}\n') + train_pos_list = f'{pos_list_file}, 1.0' + return train_pos_list + + def tearDown(self) -> None: + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal(self): + kwargs = dict( + model=self.model_id, + work_dir=self.tmp_dir, + workers=2, + max_epochs=2, + train_iters_per_epoch=2, + val_iters_per_epoch=1, + custom_conf=self.custom_conf) + + trainer = build_trainer( + Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files, + f'work_dir:{self.tmp_dir}') diff --git a/tests/trainers/easycv/__init__.py b/tests/trainers/easycv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/easycv/test_easycv_trainer.py b/tests/trainers/easycv/test_easycv_trainer.py new file mode 100644 index 00000000..4bd63c55 --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer.py @@ -0,0 +1,237 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import json +import torch + +from modelscope.metainfo import Models, Pipelines, Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import LogKeys, ModeKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import DistributedTestCase, test_level +from modelscope.utils.torch_utils import is_master + + +def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4): + import easycv + config_path = os.path.join( + os.path.dirname(easycv.__file__), + 'configs/detection/yolox/yolox_s_8xb16_300e_coco.py') + + cfg = Config.from_file(config_path) + + cfg.log_config.update( + dict(hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ])) # not support TensorboardLoggerHookV2 + + ms_cfg_file = os.path.join(work_dir, 'ms_yolox_s_8xb16_300e_coco.json') + from easycv.utils.ms_utils import to_ms_config + + if is_master(): + to_ms_config( + cfg, + dump=True, + task=Tasks.image_object_detection, + ms_model_name=Models.yolox, + pipeline_name=Pipelines.easycv_detection, + save_path=ms_cfg_file) + + trainer_name = Trainers.easycv + train_dataset = MsDataset.load( + dataset_name='small_coco_for_test', namespace='EasyCV', split='train') + eval_dataset = MsDataset.load( + dataset_name='small_coco_for_test', + namespace='EasyCV', + split='validation') + + cfg_options = { + 'train.max_epochs': + 2, + 'train.dataloader.batch_size_per_gpu': + imgs_per_gpu, + 'evaluation.dataloader.batch_size_per_gpu': + 2, + 'train.hooks': [ + { + 'type': 'CheckpointHook', + 'interval': 1 + }, + { + 'type': 'EvaluationHook', + 'interval': 1 + }, + { + 'type': 'TextLoggerHook', + 'interval': log_interval + }, + ] + } + kwargs = dict( + cfg_file=ms_cfg_file, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=work_dir, + cfg_options=cfg_options, + launcher='pytorch' if dist else None) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestSingleGpu(unittest.TestCase): + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_single_gpu(self): + train_func(self.tmp_dir) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + + with open(json_files[0], 'r') as f: + lines = [i.strip() for i in f.readlines()] + + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 3, + LogKeys.LR: 0.00013 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 3, + LogKeys.LR: 0.00157 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10 + }, json.loads(lines[3])) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + for i in [0, 2]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + self.assertIn(LogKeys.MEMORY, lines[i]) + self.assertIn('total_loss', lines[i]) + for i in [1, 3]: + self.assertIn( + 'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP', + lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i]) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class EasyCVTrainerTestMultiGpus(DistributedTestCase): + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_multi_gpus(self): + self.start( + train_func, + num_gpus=2, + work_dir=self.tmp_dir, + dist=True, + log_interval=2, + imgs_per_gpu=5) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + + with open(json_files[0], 'r') as f: + lines = [i.strip() for i in f.readlines()] + + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 2, + LogKeys.LR: 0.0002 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 5 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 2, + LogKeys.LR: 0.0018 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 5 + }, json.loads(lines[3])) + + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + for i in [0, 2]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + self.assertIn(LogKeys.MEMORY, lines[i]) + self.assertIn('total_loss', lines[i]) + for i in [1, 3]: + self.assertIn( + 'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP', + lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i]) + self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py new file mode 100644 index 00000000..e4f0c57e --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): + model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + + def _train(self, tmp_dir): + cfg_options = {'train.max_epochs': 2} + + trainer_name = Trainers.easycv + + train_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + eval_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skip( + 'skip since face_2d_keypoints_dataset is set to private for now') + def test_trainer_single_gpu(self): + temp_file_dir = tempfile.TemporaryDirectory() + tmp_dir = temp_file_dir.name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + self._train(tmp_dir) + + results_files = os.listdir(tmp_dir) + json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + temp_file_dir.cleanup() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py b/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py new file mode 100644 index 00000000..270ecbc4 --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestHand2dKeypoints(unittest.TestCase): + model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def _train(self): + cfg_options = {'train.max_epochs': 20} + + trainer_name = Trainers.easycv + + train_dataset = MsDataset.load( + dataset_name='cv_hand_2d_keypoints_coco_wholebody', + namespace='chenhyer', + split='subtrain', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + eval_dataset = MsDataset.load( + dataset_name='cv_hand_2d_keypoints_coco_wholebody', + namespace='chenhyer', + split='subtrain', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_single_gpu(self): + self._train() + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_10.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_20.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/easycv/test_segformer.py b/tests/trainers/easycv/test_segformer.py new file mode 100644 index 00000000..90a66635 --- /dev/null +++ b/tests/trainers/easycv/test_segformer.py @@ -0,0 +1,72 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestSegformer(unittest.TestCase): + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def _train(self): + + cfg_options = { + 'train.max_epochs': 2, + 'model.decode_head.norm_cfg.type': 'BN' + } + + trainer_name = Trainers.easycv + train_dataset = MsDataset.load( + dataset_name='small_coco_stuff164k', + namespace='EasyCV', + split='train') + eval_dataset = MsDataset.load( + dataset_name='small_coco_stuff164k', + namespace='EasyCV', + split='validation') + kwargs = dict( + model= + 'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k', + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_single_gpu_segformer(self): + self._train() + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/__init__.py b/tests/trainers/hooks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/hooks/compression/__init__.py b/tests/trainers/hooks/compression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/hooks/compression/test_sparsity_hook.py b/tests/trainers/hooks/compression/test_sparsity_hook.py new file mode 100644 index 00000000..4af4dcdb --- /dev/null +++ b/tests/trainers/hooks/compression/test_sparsity_hook.py @@ -0,0 +1,113 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 10) + self.bn = nn.BatchNorm1d(10) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class SparsityHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_sparsity_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'SparsityHook', + 'pruning_method': 'pst', + 'config': { + 'weight_rank': 1, + 'mask_rank': 1, + 'final_sparsity': 0.9, + 'frequency': 1, + }, + }], + }, + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=5, + device='cpu', + ) + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + trainer.register_hook_from_cfg(trainer.cfg.train.hooks) + trainer.train_dataloader = train_dataloader + trainer.data_loader = train_dataloader + trainer.invoke_hook(TrainerStages.before_run) + for i in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer.invoke_hook(TrainerStages.after_run) + + self.assertEqual( + torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/logger/__init__.py b/tests/trainers/hooks/logger/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/hooks/logger/test_tensorboard_hook.py b/tests/trainers/hooks/logger/test_tensorboard_hook.py new file mode 100644 index 00000000..67b1aa63 --- /dev/null +++ b/tests/trainers/hooks/logger/test_tensorboard_hook.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import LogKeys, ModelFile +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class TensorboardHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_tensorboard_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01 + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + }, + 'hooks': [{ + 'type': 'TensorboardHook', + 'interval': 2 + }] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + train_dataset=dummy_dataset, + max_epochs=2) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + tb_out_dir = os.path.join(self.tmp_dir, 'tensorboard_output') + + events_files = glob.glob( + os.path.join(tb_out_dir, 'events.out.tfevents.*')) + self.assertEqual(len(events_files), 1) + + from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + ea = EventAccumulator(events_files[0]) + ea.Reload() + self.assertEqual(len(ea.Scalars(LogKeys.LOSS)), 10) + self.assertEqual(len(ea.Scalars(LogKeys.LR)), 10) + for i in range(5): + self.assertAlmostEqual( + ea.Scalars(LogKeys.LR)[i].value, 0.01, delta=0.001) + for i in range(5, 10): + self.assertAlmostEqual( + ea.Scalars(LogKeys.LR)[i].value, 0.01, delta=0.0001) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/test_checkpoint_hook.py b/tests/trainers/hooks/test_checkpoint_hook.py new file mode 100644 index 00000000..e7f2d33c --- /dev/null +++ b/tests/trainers/hooks/test_checkpoint_hook.py @@ -0,0 +1,220 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn + +from modelscope.metainfo import Trainers +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import LogKeys, ModelFile +from modelscope.utils.registry import default_group +from modelscope.utils.test_utils import create_dummy_test_dataset + +SRC_DIR = os.path.dirname(__file__) + + +def create_dummy_metric(): + _global_iter = 0 + + @METRICS.register_module( + group_key=default_group, module_name='DummyMetric', force=True) + class DummyMetric: + + _fake_acc_by_epoch = {1: 0.1, 2: 0.5, 3: 0.2} + + def add(*args, **kwargs): + pass + + def evaluate(self): + global _global_iter + _global_iter += 1 + return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} + + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + self.model_dir = SRC_DIR + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class CheckpointHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + create_dummy_metric() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_checkpoint_hook(self): + global _global_iter + _global_iter = 0 + + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01, + 'options': { + 'grad_clip': { + 'max_norm': 2.0 + } + } + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + 'options': { + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_iters': 2 + } + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1 + }] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + train_dataset=dummy_dataset, + max_epochs=2) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(SRC_DIR) + self.assertIn(copy_src_files[0], output_files) + self.assertIn(copy_src_files[-1], output_files) + + +class BestCkptSaverHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + create_dummy_metric() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_best_checkpoint_hook(self): + global _global_iter + _global_iter = 0 + + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01 + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2 + }, + 'hooks': [{ + 'type': 'BestCkptSaverHook', + 'metric_key': MetricKeys.ACCURACY, + 'rule': 'min' + }, { + 'type': 'EvaluationHook', + 'interval': 1, + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': ['DummyMetric'] + } + } + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + max_epochs=3) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', + results_files) + + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(SRC_DIR) + self.assertIn(copy_src_files[0], output_files) + self.assertIn(copy_src_files[-1], output_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/test_evaluation_hook.py b/tests/trainers/hooks/test_evaluation_hook.py new file mode 100644 index 00000000..2c71e790 --- /dev/null +++ b/tests/trainers/hooks/test_evaluation_hook.py @@ -0,0 +1,117 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn + +from modelscope.metainfo import Trainers +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.registry import default_group +from modelscope.utils.test_utils import create_dummy_test_dataset + + +def create_dummy_metric(): + + @METRICS.register_module( + group_key=default_group, module_name='DummyMetric', force=True) + class DummyMetric: + + def add(*args, **kwargs): + pass + + def evaluate(self): + return {MetricKeys.ACCURACY: 0.5} + + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class EvaluationHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + create_dummy_metric() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_evaluation_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01, + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1, + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': ['DummyMetric'] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + max_epochs=1) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + self.assertDictEqual(trainer.metric_values, {'accuracy': 0.5}) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/test_lr_scheduler_hook.py b/tests/trainers/hooks/test_lr_scheduler_hook.py new file mode 100644 index 00000000..7a1ff220 --- /dev/null +++ b/tests/trainers/hooks/test_lr_scheduler_hook.py @@ -0,0 +1,304 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + +from modelscope.metainfo import Trainers +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages +from modelscope.utils.registry import default_group +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) + + +def create_dummy_metric(): + _global_iter = 0 + + @METRICS.register_module( + group_key=default_group, module_name='DummyMetric', force=True) + class DummyMetric: + + _fake_acc_by_epoch = {1: 0.1, 2: 0.1, 3: 0.1, 4: 0.1, 5: 0.3} + + def add(*args, **kwargs): + pass + + def evaluate(self): + global _global_iter + _global_iter += 1 + return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class LrSchedulerHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + create_dummy_metric() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_lr_scheduler_hook(self): + global _global_iter + _global_iter = 0 + + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + } + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=5, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + + trainer.invoke_hook(TrainerStages.before_run) + log_lrs = [] + optim_lrs = [] + for _ in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + log_lrs.append(trainer.log_buffer.output[LogKeys.LR]) + optim_lrs.append(optimizer.param_groups[0]['lr']) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer._epoch += 1 + trainer.invoke_hook(TrainerStages.after_run) + + iters = 5 + target_lrs = [0.01] * iters * 2 + [0.001] * iters * 2 + [0.0001 + ] * iters * 1 + self.assertListEqual(log_lrs, target_lrs) + self.assertListEqual(optim_lrs, target_lrs) + + def test_warmup_lr_scheduler_hook(self): + global _global_iter + _global_iter = 0 + + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01 + }, + 'lr_scheduler': { + 'type': 'MultiStepLR', + 'milestones': [4, 6], + 'options': { + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_iters': 3 + } + } + } + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + max_epochs=7, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + + trainer.invoke_hook(TrainerStages.before_run) + log_lrs = [] + optim_lrs = [] + for _ in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + log_lrs.append(round(trainer.log_buffer.output[LogKeys.LR], 5)) + optim_lrs.append( + round(trainer.optimizer.param_groups[0]['lr'], 5)) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer.invoke_hook(TrainerStages.after_run) + + iters = 5 + target_lrs = [0.001] * iters * 1 + [0.004] * iters * 1 + [ + 0.007 + ] * iters * 1 + [0.01] * iters * 1 + [0.001] * iters * 2 + [ + 0.0001 + ] * iters * 1 + + self.assertListEqual(log_lrs, target_lrs) + self.assertListEqual(optim_lrs, target_lrs) + + +class PlateauLrSchedulerHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + create_dummy_metric() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_plateau_lr_scheduler_hook(self): + global _global_iter + _global_iter = 0 + + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'lr_scheduler': { + 'type': 'ReduceLROnPlateau', + 'mode': 'max', + 'factor': 0.1, + 'patience': 2, + }, + 'lr_scheduler_hook': { + 'type': 'PlateauLrSchedulerHook', + 'metric_key': MetricKeys.ACCURACY + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': ['DummyMetric'] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + optimizers=(optimizer, None), + max_epochs=5, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.train_dataloader = train_dataloader + trainer.data_loader = train_dataloader + trainer.register_optimizers_hook() + trainer.register_hook_from_cfg(trainer.cfg.train.hooks) + + trainer.invoke_hook(TrainerStages.before_run) + log_lrs = [] + optim_lrs = [] + for _ in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + log_lrs.append(trainer.log_buffer.output[LogKeys.LR]) + optim_lrs.append(optimizer.param_groups[0]['lr']) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer._epoch += 1 + trainer.invoke_hook(TrainerStages.after_run) + + iters = 5 + target_lrs = [0.01] * iters * 4 + [0.001] * iters * 1 + self.assertListEqual(log_lrs, target_lrs) + self.assertListEqual(optim_lrs, target_lrs) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/test_optimizer_hook.py b/tests/trainers/hooks/test_optimizer_hook.py new file mode 100644 index 00000000..84c783b5 --- /dev/null +++ b/tests/trainers/hooks/test_optimizer_hook.py @@ -0,0 +1,181 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + self.bn = nn.BatchNorm1d(2) + + def forward(self, feat, labels): + x = self.linear(feat) + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class OptimizerHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_optimizer_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + } + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[1, 2]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=2, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + + trainer.invoke_hook(TrainerStages.before_run) + + for _ in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + self.assertEqual( + len(trainer.optimizer.param_groups[0]['params']), 4) + for i in range(4): + self.assertTrue(trainer.optimizer.param_groups[0]['params'] + [i].requires_grad) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer._epoch += 1 + trainer.invoke_hook(TrainerStages.after_run) + + +class TorchAMPOptimizerHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipIf(not torch.cuda.is_available(), + 'skip this test when cuda is not available') + def test_amp_optimizer_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + } + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel().cuda() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[1, 2]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=2, + use_fp16=True) + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + + trainer.invoke_hook(TrainerStages.before_run) + + for _ in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + for k, v in data_batch.items(): + data_batch[k] = v.cuda() + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + self.assertEqual(trainer.train_outputs['logits'].dtype, + torch.float16) + + # test if `after_train_iter`, whether the model is reset to fp32 + trainer.train_step(trainer.model, data_batch) + self.assertEqual(trainer.train_outputs['logits'].dtype, + torch.float32) + + self.assertEqual( + len(trainer.optimizer.param_groups[0]['params']), 4) + for i in range(4): + self.assertTrue(trainer.optimizer.param_groups[0]['params'] + [i].requires_grad) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + trainer._epoch += 1 + trainer.invoke_hook(TrainerStages.after_run) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/hooks/test_timer_hook.py b/tests/trainers/hooks/test_timer_hook.py new file mode 100644 index 00000000..9fb79c77 --- /dev/null +++ b/tests/trainers/hooks/test_timer_hook.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import MultiStepLR + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset + +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class IterTimerHookTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_iter_time_hook(self): + json_cfg = { + 'task': 'image_classification', + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'IterTimerHook', + }] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4]) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + train_dataset=dummy_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=5, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + train_dataloader = trainer._build_dataloader_with_dataset( + trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) + trainer.register_optimizers_hook() + trainer.register_hook_from_cfg(trainer.cfg.train.hooks) + trainer.train_dataloader = train_dataloader + trainer.data_loader = train_dataloader + trainer.invoke_hook(TrainerStages.before_run) + for i in range(trainer._epoch, trainer._max_epochs): + trainer.invoke_hook(TrainerStages.before_train_epoch) + for _, data_batch in enumerate(train_dataloader): + trainer.invoke_hook(TrainerStages.before_train_iter) + trainer.train_step(trainer.model, data_batch) + trainer.invoke_hook(TrainerStages.after_train_iter) + + self.assertIn(LogKeys.DATA_LOAD_TIME, + trainer.log_buffer.val_history) + self.assertIn(LogKeys.ITER_TIME, + trainer.log_buffer.val_history) + self.assertIn(LogKeys.LOSS, trainer.log_buffer.val_history) + + trainer.invoke_hook(TrainerStages.after_train_epoch) + + target_len = 5 + self.assertEqual( + len(trainer.log_buffer.val_history[LogKeys.DATA_LOAD_TIME]), + target_len) + self.assertEqual( + len(trainer.log_buffer.val_history[LogKeys.ITER_TIME]), + target_len) + self.assertEqual( + len(trainer.log_buffer.val_history[LogKeys.LOSS]), target_len) + + self.assertEqual( + len(trainer.log_buffer.n_history[LogKeys.DATA_LOAD_TIME]), + target_len) + self.assertEqual( + len(trainer.log_buffer.n_history[LogKeys.ITER_TIME]), + target_len) + self.assertEqual( + len(trainer.log_buffer.n_history[LogKeys.LOSS]), target_len) + + trainer.invoke_hook(TrainerStages.after_run) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/lrscheduler/__init__.py b/tests/trainers/lrscheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/lrscheduler/warmup/__init__.py b/tests/trainers/lrscheduler/warmup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/lrscheduler/warmup/test_warmup_base.py b/tests/trainers/lrscheduler/warmup/test_warmup_base.py new file mode 100644 index 00000000..45c9fe2c --- /dev/null +++ b/tests/trainers/lrscheduler/warmup/test_warmup_base.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import torch +from torch import nn +from torch.optim.lr_scheduler import MultiStepLR + + +class WarmupTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_constant_warmup(self): + from modelscope.trainers.lrscheduler.warmup import ConstantWarmup + + net = nn.Linear(2, 2) + base_lr = 0.02 + warmup_iters = 3 + warmup_ratio = 0.2 + optimizer = torch.optim.SGD(net.parameters(), lr=base_lr, momentum=0.9) + lr_scheduler = MultiStepLR(optimizer, milestones=[7, 9]) + lr_scheduler_with_warmup = ConstantWarmup( + lr_scheduler, warmup_iters=warmup_iters, warmup_ratio=warmup_ratio) + + res = [] + for _ in range(10): + lr_scheduler_with_warmup.step() + for _, group in enumerate(optimizer.param_groups): + res.append(group['lr']) + + base_lrs = [0.02, 0.02, 0.02, 0.002, 0.002, 0.0002, 0.0002] + self.assertListEqual(res, [0.004, 0.004, 0.02] + base_lrs) + + def test_linear_warmup(self): + from modelscope.trainers.lrscheduler.warmup import LinearWarmup + + net = nn.Linear(2, 2) + base_lr = 0.02 + warmup_iters = 3 + warmup_ratio = 0.1 + optimizer = torch.optim.SGD(net.parameters(), lr=base_lr, momentum=0.9) + lr_scheduler = MultiStepLR(optimizer, milestones=[7, 9]) + lr_scheduler_with_warmup = LinearWarmup( + lr_scheduler, warmup_iters=warmup_iters, warmup_ratio=warmup_ratio) + + res = [] + for _ in range(10): + lr_scheduler_with_warmup.step() + for _, group in enumerate(optimizer.param_groups): + res.append(round(group['lr'], 5)) + + base_lrs = [0.02, 0.02, 0.02, 0.002, 0.002, 0.0002, 0.0002] + self.assertListEqual(res, [0.0080, 0.0140, 0.02] + base_lrs) + + def test_exp_warmup(self): + from modelscope.trainers.lrscheduler.warmup import ExponentialWarmup + + net = nn.Linear(2, 2) + base_lr = 0.02 + warmup_iters = 3 + warmup_ratio = 0.1 + optimizer = torch.optim.SGD(net.parameters(), lr=base_lr, momentum=0.9) + lr_scheduler = MultiStepLR(optimizer, milestones=[7, 9]) + lr_scheduler_with_warmup = ExponentialWarmup( + lr_scheduler, warmup_iters=warmup_iters, warmup_ratio=warmup_ratio) + + res = [] + for _ in range(10): + lr_scheduler_with_warmup.step() + for _, group in enumerate(optimizer.param_groups): + res.append(round(group['lr'], 5)) + + base_lrs = [0.02, 0.02, 0.02, 0.002, 0.002, 0.0002, 0.0002] + self.assertListEqual(res, [0.00431, 0.00928, 0.02] + base_lrs) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_card_detection_scrfd_trainer.py b/tests/trainers/test_card_detection_scrfd_trainer.py new file mode 100644 index 00000000..af87000b --- /dev/null +++ b/tests/trainers/test_card_detection_scrfd_trainer.py @@ -0,0 +1,151 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_carddetection_scrfd34gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. + ms_ds_syncards = MsDataset.load( + 'SyntheticCards_mini', namespace='shaoxuan') + + data_path = ms_ds_syncards.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id) + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestCardDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.card_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestCardDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_clip_trainer.py b/tests/trainers/test_clip_trainer.py new file mode 100644 index 00000000..e460f1ac --- /dev/null +++ b/tests/trainers/test_clip_trainer.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import unittest + +import json + +from modelscope.metainfo import Metrics, Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestClipTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.finetune_cfg = \ + {'framework': 'pytorch', + 'task': 'multi-modal-embedding', + 'pipeline': {'type': 'multi-modal-embedding'}, + 'pretrained_model': {'model_name': 'damo/multi-modal_clip-vit-base-patch16_zh'}, + 'dataset': {'column_map': {'img': 'image', 'text': 'query'}}, + 'train': {'work_dir': './workspace/ckpts/clip', + # 'launcher': 'pytorch', + 'max_epochs': 1, + 'use_fp16': True, + 'dataloader': {'batch_size_per_gpu': 8, + 'workers_per_gpu': 0, + 'shuffle': True, + 'drop_last': True}, + 'lr_scheduler': {'name': 'cosine', + 'warmup_proportion': 0.01}, + 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, + 'optimizer': {'type': 'AdamW'}, + 'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01}, + 'optimizer_hook': {'type': 'TorchAMPOptimizerHook', + 'cumulative_iters': 1, + 'loss_keys': 'loss'}, + 'loss_cfg': {'aggregate': True}, + 'hooks': [{'type': 'BestCkptSaverHook', + 'metric_key': 'inbatch_t2i_recall_at_1', + 'interval': 100}, + {'type': 'TextLoggerHook', 'interval': 1}, + {'type': 'IterTimerHook'}, + {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}, + {'type': 'ClipClampLogitScaleHook'}]}, + 'evaluation': {'dataloader': {'batch_size_per_gpu': 8, + 'workers_per_gpu': 0, + 'shuffle': True, + 'drop_last': True}, + 'metrics': [{'type': 'inbatch_recall'}]}, + 'preprocessor': []} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_std(self): + WORKSPACE = './workspace/ckpts/clip' + os.makedirs(WORKSPACE, exist_ok=True) + config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + with open(config_file, 'w') as writer: + json.dump(self.finetune_cfg, writer) + + pretrained_model = 'damo/multi-modal_clip-vit-base-patch16_zh' + args = dict( + model=pretrained_model, + work_dir=WORKSPACE, + train_dataset=MsDataset.load( + 'muge', namespace='modelscope', split='train[:200]'), + eval_dataset=MsDataset.load( + 'muge', namespace='modelscope', split='validation[:100]'), + metrics=[Metrics.inbatch_recall], + cfg_file=config_file) + trainer = build_trainer( + name=Trainers.clip_multi_modal_embedding, default_args=args) + trainer.train() + + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, + os.listdir(os.path.join(WORKSPACE, 'output'))) + shutil.rmtree(WORKSPACE) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_dialog_intent_trainer.py b/tests/trainers/test_dialog_intent_trainer.py new file mode 100644 index 00000000..207387ac --- /dev/null +++ b/tests/trainers/test_dialog_intent_trainer.py @@ -0,0 +1,103 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import unittest + +import json + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestDialogIntentTrainer(unittest.TestCase): + + def setUp(self): + self.save_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.save_dir): + os.mkdir(self.save_dir) + + def tearDown(self): + shutil.rmtree(self.save_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + model_id = 'damo/nlp_space_pretrained-dialog-model' + data_banking = MsDataset.load('banking77') + self.data_dir = data_banking._hf_ds.config_kwargs['split_config'][ + 'train'] + self.model_dir = snapshot_download(model_id) + self.debugging = True + kwargs = dict( + model_dir=self.model_dir, + cfg_name='intent_train_config.json', + cfg_modify_fn=self.cfg_modify_fn) + trainer = build_trainer( + name=Trainers.dialog_intent_trainer, default_args=kwargs) + trainer.train() + + def cfg_modify_fn(self, cfg): + config = { + 'num_intent': 77, + 'BPETextField': { + 'vocab_path': '', + 'data_name': 'banking77', + 'data_root': self.data_dir, + 'understand': True, + 'generation': False, + 'max_len': 256 + }, + 'Dataset': { + 'data_dir': self.data_dir, + 'with_contrastive': False, + 'trigger_role': 'user', + 'trigger_data': 'banking' + }, + 'Trainer': { + 'can_norm': True, + 'seed': 11, + 'gpu': 1, + 'save_dir': self.save_dir, + 'batch_size_label': 128, + 'batch_size_nolabel': 0, + 'log_steps': 20 + }, + 'Model': { + 'init_checkpoint': self.model_dir, + 'model': 'IntentUnifiedTransformer', + 'example': False, + 'num_intent': 77, + 'with_rdrop': True, + 'num_turn_embeddings': 21, + 'dropout': 0.25, + 'kl_ratio': 5.0, + 'embed_dropout': 0.25, + 'attn_dropout': 0.25, + 'ff_dropout': 0.25, + 'with_pool': False, + 'warmup_steps': -1 + } + } + cfg.BPETextField.vocab_path = os.path.join(self.model_dir, + ModelFile.VOCAB_FILE) + cfg.num_intent = 77 + cfg.Trainer.update(config['Trainer']) + cfg.BPETextField.update(config['BPETextField']) + cfg.Dataset.update(config['Dataset']) + cfg.Model.update(config['Model']) + if self.debugging: + cfg.Trainer.save_checkpoint = False + cfg.Trainer.num_epochs = 5 + cfg.Trainer.batch_size_label = 64 + return cfg + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_dialog_modeling_trainer.py b/tests/trainers/test_dialog_modeling_trainer.py new file mode 100644 index 00000000..be03db30 --- /dev/null +++ b/tests/trainers/test_dialog_modeling_trainer.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Preprocessors, Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.test_utils import test_level + + +class TestDialogModelingTrainer(unittest.TestCase): + + model_id = 'damo/nlp_space_pretrained-dialog-model' + output_dir = './dialog_fintune_result' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + # download data set + data_multiwoz = MsDataset.load( + 'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + data_dir = os.path.join( + data_multiwoz._hf_ds.config_kwargs['split_config']['train'], + 'data') + + # download model + model_dir = snapshot_download(self.model_id) + + # dialog finetune config + def cfg_modify_fn(cfg): + config = { + 'seed': 10, + 'gpu': 4, + 'use_data_distributed': False, + 'valid_metric_name': '-loss', + 'num_epochs': 60, + 'save_dir': self.output_dir, + 'token_loss': True, + 'batch_size': 32, + 'log_steps': 10, + 'valid_steps': 0, + 'save_checkpoint': True, + 'save_summary': False, + 'shuffle': True, + 'sort_pool_size': 0 + } + + cfg.Trainer = config + cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1 + return cfg + + # trainer config + kwargs = dict( + model_dir=model_dir, + cfg_name='gen_train_config.json', + data_dir=data_dir, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.dialog_modeling_trainer, default_args=kwargs) + trainer.train() + checkpoint_path = os.path.join(self.output_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + assert os.path.exists(checkpoint_path) + trainer.evaluate(checkpoint_path=checkpoint_path) diff --git a/tests/trainers/test_face_detection_scrfd_trainer.py b/tests/trainers/test_face_detection_scrfd_trainer.py new file mode 100644 index 00000000..97b0eca7 --- /dev/null +++ b/tests/trainers/test_face_detection_scrfd_trainer.py @@ -0,0 +1,150 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' + # mini dataset only for unit test, remove '_mini' for full dataset. + ms_ds_widerface = MsDataset.load('WIDER_FACE_mini', namespace='shaoxuan') + + data_path = ms_ds_widerface.config_kwargs['split_config'] + train_dir = data_path['train'] + val_dir = data_path['validation'] + train_root = train_dir + '/' + os.listdir(train_dir)[0] + '/' + val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' + max_epochs = 1 # run epochs in unit test + + cache_path = snapshot_download(model_id) + + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + return train_root, val_root, max_epochs, cache_path, tmp_dir + + +def train_func(**kwargs): + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + + +class TestFaceDetectionScrfdTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + print(('SingleGPU Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def _cfg_modify_fn(self, cfg): + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 # batch size + return cfg + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_from_scratch(self): + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + cfg_modify_fn=self._cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.face_detection_scrfd, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TestFaceDetectionScrfdTrainerMultiGpus(DistributedTestCase): + + def setUp(self): + print(('MultiGPUs Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + self.train_root, self.val_root, self.max_epochs, self.cache_path, self.tmp_dir = _setup( + ) + cfg_file_path = os.path.join(self.cache_path, 'mmcv_scrfd.py') + cfg = Config.from_file(cfg_file_path) + cfg.checkpoint_config.interval = 1 + cfg.log_config.interval = 10 + cfg.evaluation.interval = 1 + cfg.data.workers_per_gpu = 3 + cfg.data.samples_per_gpu = 4 + cfg.dump(cfg_file_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus_finetune(self): + pretrain_epoch = 640 + self.max_epochs += pretrain_epoch + kwargs = dict( + cfg_file=os.path.join(self.cache_path, 'mmcv_scrfd.py'), + work_dir=self.tmp_dir, + train_root=self.train_root, + val_root=self.val_root, + total_epochs=self.max_epochs, + resume_from=os.path.join(self.cache_path, + ModelFile.TORCH_MODEL_BIN_FILE), + launcher='pytorch') + self.start(train_func, num_gpus=2, **kwargs) + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + for i in range(pretrain_epoch, self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_mplug.py b/tests/trainers/test_finetune_mplug.py new file mode 100644 index 00000000..46664114 --- /dev/null +++ b/tests/trainers/test_finetune_mplug.py @@ -0,0 +1,144 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.multi_modal import MPlugForAllTasks +from modelscope.msdatasets import MsDataset +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestFinetuneMPlug(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + datadict = MsDataset.load('coco_captions_small_slice') + self.train_dataset = MsDataset( + datadict['train'].remap_columns({ + 'image:FILE': 'image', + 'answer:Value': 'answer' + }).map(lambda _: {'question': 'what the picture describes?'})) + self.test_dataset = MsDataset( + datadict['test'].remap_columns({ + 'image:FILE': 'image', + 'answer:Value': 'answer' + }).map(lambda _: {'question': 'what the picture describes?'})) + self.max_epochs = 2 + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_caption(self): + kwargs = dict( + model='damo/mplug_image-captioning_coco_base_en', + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_caption_with_model_and_args(self): + cache_path = snapshot_download( + 'damo/mplug_image-captioning_coco_base_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_vqa(self): + kwargs = dict( + model='damo/mplug_visual-question-answering_coco_large_en', + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_vqa_with_model_and_args(self): + cache_path = snapshot_download( + 'damo/mplug_visual-question-answering_coco_large_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_retrieval(self): + kwargs = dict( + model='damo/mplug_image-text-retrieval_flickr30k_large_en', + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_retrieval_with_model_and_args(self): + cache_path = snapshot_download( + 'damo/mplug_image-text-retrieval_flickr30k_large_en') + model = MPlugForAllTasks.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.mplug, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py new file mode 100644 index 00000000..061d37d3 --- /dev/null +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -0,0 +1,535 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Preprocessors, Trainers +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import NlpTrainerArguments, build_trainer +from modelscope.trainers.hooks import Hook +from modelscope.trainers.nlp_trainer import (EpochBasedTrainer, + NlpEpochBasedTrainer) +from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ + calculate_fisher +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.data_utils import to_device +from modelscope.utils.regress_test_utils import (MsRegressTool, + compare_arguments_nested) +from modelscope.utils.test_utils import test_level + + +class TestFinetuneSequenceClassification(unittest.TestCase): + epoch_num = 1 + + sentence1 = '今天气温比昨天高么?' + sentence2 = '今天湿度比昨天高么?' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.regress_tool = MsRegressTool(baseline=False) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skip + def test_trainer_cfg_class(self): + dataset = MsDataset.load('clue', subset_name='tnews') + train_dataset = dataset['train'] + validation_dataset = dataset['validation'] + cfg_modify_fn = NlpTrainerArguments( + task=Tasks.text_classification, + preprocessor_type=Preprocessors.sen_cls_tokenizer, + train_first_sequence='sentence', + train_label='label', + labels=[ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', + '12', '13', '14' + ], + max_epochs=5, + optimizer_args={ + 'lr': 3e-5, + }, + lr_scheduler_args={ + 'total_iters': int(len(train_dataset) / 32) * 5, + }, + checkpoint_saving_type='BestCkptSaverHook', + metric_key='accuracy', + train_batch_size_per_gpu=32, + checkpoint_interval=1, + train_workers_per_gpu=0, + checkpoint_by_epoch=False, + evaluation_interval=1, + evaluation_by_epoch=False, + eval_workers_per_gpu=0, + metrics=['seq-cls-metric'], + ) + + kwargs = dict( + model='damo/nlp_structbert_backbone_base_std', + train_dataset=train_dataset, + eval_dataset=validation_dataset, + work_dir=self.tmp_dir, + seed=42, + cfg_modify_fn=cfg_modify_fn) + + os.environ['LOCAL_RANK'] = '0' + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + + @unittest.skip( + 'Skip testing trainer repeatable, because it\'s unstable in daily UT') + def test_trainer_repeatable(self): + import torch # noqa + + def compare_fn(value1, value2, key, type): + # Ignore the differences between optimizers of two torch versions + if type != 'optimizer': + return None + + match = (value1['type'] == value2['type']) + shared_defaults = set(value1['defaults'].keys()).intersection( + set(value2['defaults'].keys())) + match = all([ + compare_arguments_nested(f'Optimizer defaults {key} not match', + value1['defaults'][key], + value2['defaults'][key]) + for key in shared_defaults + ]) and match + match = (len(value1['state_dict']['param_groups']) == len( + value2['state_dict']['param_groups'])) and match + for group1, group2 in zip(value1['state_dict']['param_groups'], + value2['state_dict']['param_groups']): + shared_keys = set(group1.keys()).intersection( + set(group2.keys())) + match = all([ + compare_arguments_nested( + f'Optimizer param_groups {key} not match', group1[key], + group2[key]) for key in shared_keys + ]) and match + return match + + def cfg_modify_fn(cfg): + cfg.task = 'nli' + cfg['preprocessor'] = {'type': 'nli-tokenizer'} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'labels': [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', + '11', '12', '13', '14' + ], + 'first_sequence': + 'sentence', + 'label': + 'label', + } + } + cfg.train.max_epochs = 5 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': + int(len(dataset['train']) / 32) * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 100 + }] + return cfg + + dataset = MsDataset.load('clue', subset_name='tnews') + + kwargs = dict( + model='damo/nlp_structbert_backbone_base_std', + train_dataset=dataset['train'], + eval_dataset=dataset['validation'], + work_dir=self.tmp_dir, + seed=42, + cfg_modify_fn=cfg_modify_fn) + + os.environ['LOCAL_RANK'] = '0' + trainer: EpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + + with self.regress_tool.monitor_ms_train( + trainer, 'sbert-base-tnews', level='strict', + compare_fn=compare_fn): + trainer.train() + + def finetune(self, + model_id, + train_dataset, + eval_dataset, + name=Trainers.nlp_base_trainer, + cfg_modify_fn=None, + **kwargs): + kwargs = dict( + model=model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn, + **kwargs) + + os.environ['LOCAL_RANK'] = '0' + trainer = build_trainer(name=name, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.epoch_num): + self.assertIn(f'epoch_{i + 1}.pth', results_files) + + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(trainer.model_dir) + + print(f'copy_src_files are {copy_src_files}') + print(f'output_files are {output_files}') + for item in copy_src_files: + if not item.startswith('.'): + self.assertIn(item, output_files) + + def pipeline_sentence_similarity(self, model_dir): + model = Model.from_pretrained(model_dir) + pipeline_ins = pipeline(task=Tasks.sentence_similarity, model=model) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skip + def test_finetune_afqmc(self): + """This unittest is used to reproduce the clue:afqmc dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + def cfg_modify_fn(cfg): + cfg.task = Tasks.sentence_similarity + cfg['preprocessor'] = {'type': Preprocessors.sen_sim_tokenizer} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'labels': ['0', '1'], + 'first_sequence': 'sentence1', + 'second_sequence': 'sentence2', + 'label': 'label', + } + } + cfg.train.max_epochs = self.epoch_num + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': + int(len(dataset['train']) / 32) * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 100 + }] + return cfg + + dataset = MsDataset.load('clue', subset_name='afqmc') + self.finetune( + model_id='damo/nlp_structbert_backbone_base_std', + train_dataset=dataset['train'], + eval_dataset=dataset['validation'], + cfg_modify_fn=cfg_modify_fn) + + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + self.pipeline_sentence_similarity(output_dir) + + @unittest.skip + def test_finetune_tnews(self): + """This unittest is used to reproduce the clue:tnews dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + def cfg_modify_fn(cfg): + # TODO no proper task for tnews + cfg.task = 'nli' + cfg['preprocessor'] = {'type': 'nli-tokenizer'} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'labels': [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', + '11', '12', '13', '14' + ], + 'first_sequence': + 'sentence', + 'label': + 'label', + } + } + cfg.train.max_epochs = 5 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': + int(len(dataset['train']) / 32) * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 100 + }] + return cfg + + dataset = MsDataset.load('clue', subset_name='tnews') + + self.finetune( + model_id='damo/nlp_structbert_backbone_base_std', + train_dataset=dataset['train'], + eval_dataset=dataset['validation'], + cfg_modify_fn=cfg_modify_fn) + + @unittest.skip + def test_veco_xnli(self): + """This unittest is used to reproduce the xnli dataset + veco model training results. + + Here we follow the training scenario listed in the Alicemind open source project: + https://github.com/alibaba/AliceMind/tree/main/VECO + by training the english language subset. + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + from datasets import load_dataset + langs = ['en'] + langs_eval = ['en'] + train_datasets = [] + from datasets import DownloadConfig + dc = DownloadConfig() + dc.local_files_only = False + for lang in langs: + train_datasets.append( + load_dataset('xnli', lang, split='train', download_config=dc)) + eval_datasets = [] + for lang in langs_eval: + eval_datasets.append( + load_dataset( + 'xnli', lang, split='validation', download_config=dc)) + train_len = sum([len(dataset) for dataset in train_datasets]) + labels = ['0', '1', '2'] + + def cfg_modify_fn(cfg): + cfg.task = 'nli' + cfg['preprocessor'] = {'type': 'nli-tokenizer'} + cfg['dataset'] = { + 'train': { + 'first_sequence': 'premise', + 'second_sequence': 'hypothesis', + 'labels': labels, + 'label': 'label', + } + } + cfg['train'] = { + 'work_dir': + '/tmp', + 'max_epochs': + 2, + 'dataloader': { + 'batch_size_per_gpu': 16, + 'workers_per_gpu': 0 + }, + 'optimizer': { + 'type': 'AdamW', + 'lr': 2e-5, + 'options': { + 'cumulative_iters': 8, + } + }, + 'lr_scheduler': { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': int(train_len / 16) * 2, + 'options': { + 'by_epoch': False + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1, + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 500 + }] + } + cfg['evaluation'] = { + 'dataloader': { + 'batch_size_per_gpu': 128, + 'workers_per_gpu': 0, + 'shuffle': False + } + } + return cfg + + self.finetune( + 'damo/nlp_veco_fill-mask-large', + train_datasets, + eval_datasets, + name=Trainers.nlp_veco_trainer, + cfg_modify_fn=cfg_modify_fn) + + @unittest.skip + def test_finetune_cluewsc(self): + """This unittest is used to reproduce the clue:wsc dataset + structbert model training results. + + A runnable sample of child-tuning is also showed here. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + child_tuning_type = 'ChildTuning-F' + mode = {} + if child_tuning_type is not None: + mode = {'mode': child_tuning_type, 'reserve_p': 0.2} + + def cfg_modify_fn(cfg): + cfg.task = 'nli' + cfg['preprocessor'] = {'type': 'nli-tokenizer'} + cfg['dataset'] = { + 'train': { + 'labels': ['0', '1'], + 'first_sequence': 'text', + 'second_sequence': 'text2', + 'label': 'label', + } + } + cfg.train.dataloader.batch_size_per_gpu = 16 + cfg.train.max_epochs = 30 + cfg.train.optimizer = { + 'type': + 'AdamW' if child_tuning_type is None else 'ChildTuningAdamW', + 'lr': 1e-5, + 'options': {}, + **mode, + } + cfg.train.lr_scheduler = { + 'type': + 'LinearLR', + 'start_factor': + 1.0, + 'end_factor': + 0.0, + 'total_iters': + int( + len(dataset['train']) + / cfg.train.dataloader.batch_size_per_gpu) + * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 30 + }] + return cfg + + def add_sentence2(features): + return { + 'text2': + features['target']['span2_text'] + '指代' + + features['target']['span1_text'] + } + + dataset = MsDataset.load('clue', subset_name='cluewsc2020') + dataset = { + k: v.to_hf_dataset().map(add_sentence2) + for k, v in dataset.items() + } + + kwargs = dict( + model='damo/nlp_structbert_backbone_base_std', + train_dataset=dataset['train'], + eval_dataset=dataset['validation'], + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn) + + os.environ['LOCAL_RANK'] = '0' + trainer: NlpEpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + + class CalculateFisherHook(Hook): + + @staticmethod + def forward_step(model, inputs): + inputs = to_device(inputs, trainer.device) + trainer.train_step(model, inputs) + return trainer.train_outputs['loss'] + + def before_run(self, trainer: NlpEpochBasedTrainer): + v = calculate_fisher(trainer.model, trainer.train_dataloader, + self.forward_step, 0.2) + trainer.optimizer.set_gradient_mask(v) + + if child_tuning_type == 'ChildTuning-D': + trainer.register_hook(CalculateFisherHook()) + trainer.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_text_generation.py b/tests/trainers/test_finetune_text_generation.py new file mode 100644 index 00000000..59bef51c --- /dev/null +++ b/tests/trainers/test_finetune_text_generation.py @@ -0,0 +1,173 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestFinetuneTextGeneration(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + from datasets import Dataset + + src_dataset_dict = { + 'src_txt': [ + 'This is test sentence1-1', 'This is test sentence2-1', + 'This is test sentence3-1' + ] + } + src_tgt_dataset_dict = { + 'src_txt': + src_dataset_dict['src_txt'], + 'tgt_txt': [ + 'This is test sentence1-2', 'This is test sentence2-2', + 'This is test sentence3-2' + ] + } + + self.src_dataset = MsDataset(Dataset.from_dict(src_dataset_dict)) + self.src_tgt_dataset = MsDataset( + Dataset.from_dict(src_tgt_dataset_dict)) + + self.max_epochs = 3 + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_palm(self): + + kwargs = dict( + model='damo/nlp_palm2.0_text-generation_english-base', + train_dataset=self.src_tgt_dataset, + eval_dataset=self.src_tgt_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.text_generation_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_palm_with_model_and_args(self): + + cache_path = snapshot_download( + 'damo/nlp_palm2.0_text-generation_english-base') + model = PalmForTextGeneration.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.src_tgt_dataset, + eval_dataset=self.src_tgt_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_gpt3(self): + + kwargs = dict( + model='damo/nlp_gpt3_text-generation_chinese-base', + train_dataset=self.src_dataset, + eval_dataset=self.src_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.text_generation_trainer, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_gpt3_with_model_and_args(self): + + cache_path = snapshot_download( + 'damo/nlp_gpt3_text-generation_chinese-base') + model = GPT3ForTextGeneration.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.src_dataset, + eval_dataset=self.src_dataset, + max_epochs=self.max_epochs, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skip + def test_finetune_cnndm(self): + from modelscope.msdatasets import MsDataset + dataset_dict = MsDataset.load('DuReader_robust-QG') + train_dataset = dataset_dict['train'].remap_columns({ + 'text1': 'src_txt', + 'text2': 'tgt_txt' + }) + eval_dataset = dataset_dict['validation'].remap_columns({ + 'text1': + 'src_txt', + 'text2': + 'tgt_txt' + }) + num_warmup_steps = 200 + os.environ['LOCAL_RANK'] = '0' + + def noam_lambda(current_step: int): + current_step += 1 + return min(current_step**(-0.5), + current_step * num_warmup_steps**(-1.5)) + + def cfg_modify_fn(cfg): + cfg.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': noam_lambda, + 'options': { + 'by_epoch': False + } + } + return cfg + + kwargs = dict( + model='damo/nlp_palm2.0_text-generation_chinese-base', + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn) + trainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_text_ranking.py b/tests/trainers/test_finetune_text_ranking.py new file mode 100644 index 00000000..6e97310d --- /dev/null +++ b/tests/trainers/test_finetune_text_ranking.py @@ -0,0 +1,200 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union + +import torch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from modelscope.metainfo import Trainers +from modelscope.models import Model +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class TestFinetuneSequenceClassification(unittest.TestCase): + inputs = { + 'source_sentence': ["how long it take to get a master's degree"], + 'sentences_to_compare': [ + "On average, students take about 18 to 24 months to complete a master's degree.", + 'On the other hand, some students prefer to go at a slower pace and choose to take ' + 'several years to complete their studies.', + 'It can take anywhere from two semesters' + ] + } + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def finetune(self, + model_id, + train_dataset, + eval_dataset, + name=Trainers.nlp_text_ranking_trainer, + cfg_modify_fn=None, + **kwargs): + kwargs = dict( + model=model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn, + **kwargs) + + os.environ['LOCAL_RANK'] = '0' + trainer = build_trainer(name=name, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_finetune_msmarco(self): + + def cfg_modify_fn(cfg): + neg_sample = 4 + cfg.task = 'text-ranking' + cfg['preprocessor'] = {'type': 'text-ranking'} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['title', 'text'], + 'qid_field': 'query_id', + 'neg_sample': neg_sample + }, + 'val': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['title', 'text'], + 'qid_field': 'query_id' + }, + } + cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 + cfg.train.max_epochs = 1 + cfg.train.train_batch_size = 4 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'options': { + 'by_epoch': False + } + } + cfg.model['neg_sample'] = 4 + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 15 + }] + return cfg + + # load dataset + ds = MsDataset.load('passage-ranking-demo', 'zyznull') + train_ds = ds['train'].to_hf_dataset() + dev_ds = ds['dev'].to_hf_dataset() + + model_id = 'damo/nlp_corom_passage-ranking_english-base' + self.finetune( + model_id=model_id, + train_dataset=train_ds, + eval_dataset=dev_ds, + cfg_modify_fn=cfg_modify_fn) + + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + self.pipeline_text_ranking(output_dir) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_finetune_dureader(self): + + def cfg_modify_fn(cfg): + cfg.task = 'text-ranking' + cfg['preprocessor'] = {'type': 'text-ranking'} + cfg.train.optimizer.lr = 2e-5 + cfg['dataset'] = { + 'train': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['text'], + 'qid_field': 'query_id' + }, + 'val': { + 'type': 'bert', + 'query_sequence': 'query', + 'pos_sequence': 'positive_passages', + 'neg_sequence': 'negative_passages', + 'text_fileds': ['text'], + 'qid_field': 'query_id' + }, + } + cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 + cfg.train.max_epochs = 1 + cfg.train.train_batch_size = 4 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 5000 + }] + return cfg + + # load dataset + ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull') + train_ds = ds['train'].to_hf_dataset().shard(1000, index=0) + dev_ds = ds['dev'].to_hf_dataset() + model_id = 'damo/nlp_rom_passage-ranking_chinese-base' + self.finetune( + model_id=model_id, + train_dataset=train_ds, + eval_dataset=dev_ds, + cfg_modify_fn=cfg_modify_fn) + + def pipeline_text_ranking(self, model_dir): + model = Model.from_pretrained(model_dir) + pipeline_ins = pipeline(task=Tasks.text_ranking, model=model) + print(pipeline_ins(input=self.inputs)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_token_classificatin.py b/tests/trainers/test_finetune_token_classificatin.py new file mode 100644 index 00000000..a92cee7b --- /dev/null +++ b/tests/trainers/test_finetune_token_classificatin.py @@ -0,0 +1,129 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +from functools import reduce + +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + + +class TestFinetuneTokenClassification(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def finetune(self, + model_id, + train_dataset, + eval_dataset, + name=Trainers.nlp_base_trainer, + cfg_modify_fn=None, + **kwargs): + kwargs = dict( + model=model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn, + **kwargs) + + os.environ['LOCAL_RANK'] = '0' + trainer = build_trainer(name=name, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(10): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skip + def test_word_segmentation(self): + """This unittest is used to reproduce the icwb2:pku dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + os.system( + f'curl http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip > {self.tmp_dir}/icwb2-data.zip' + ) + shutil.unpack_archive(f'{self.tmp_dir}/icwb2-data.zip', self.tmp_dir) + from datasets import load_dataset + from modelscope.preprocessors.nlp import WordSegmentationBlankSetToLabelPreprocessor + preprocessor = WordSegmentationBlankSetToLabelPreprocessor() + dataset = load_dataset( + 'text', + data_files=f'{self.tmp_dir}/icwb2-data/training/pku_training.utf8') + + def split_to_dict(examples): + return preprocessor(examples['text']) + + dataset = dataset.map(split_to_dict, batched=False) + + def reducer(x, y): + x = x.split(' ') if isinstance(x, str) else x + y = y.split(' ') if isinstance(y, str) else y + return x + y + + label_enumerate_values = list( + set(reduce(reducer, dataset['train'][:1000]['labels']))) + label_enumerate_values.sort() + + train_len = int(len(dataset['train']) * 0.7) + train_dataset = dataset['train'].select(range(train_len)) + dev_dataset = dataset['train'].select( + range(train_len, len(dataset['train']))) + + def cfg_modify_fn(cfg): + cfg.task = 'token-classification' + cfg['dataset'] = { + 'train': { + 'labels': label_enumerate_values, + 'first_sequence': 'tokens', + 'label': 'labels', + } + } + cfg['preprocessor'] = {'type': 'token-cls-tokenizer'} + cfg.train.max_epochs = 2 + cfg.train.lr_scheduler = { + 'type': 'LinearLR', + 'start_factor': 1.0, + 'end_factor': 0.0, + 'total_iters': + int(len(train_dataset) / 32) * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 50 + }] + return cfg + + self.finetune( + 'damo/nlp_structbert_backbone_base_std', + train_dataset, + dev_dataset, + cfg_modify_fn=cfg_modify_fn) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_color_enhance_trainer.py b/tests/trainers/test_image_color_enhance_trainer.py new file mode 100644 index 00000000..34d84cd2 --- /dev/null +++ b/tests/trainers/test_image_color_enhance_trainer.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import shutil +import tempfile +import unittest +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import torch +from torch.utils import data as data + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_color_enhance import ImageColorEnhance +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class PairedImageDataset(data.Dataset): + + def __init__(self, root): + super(PairedImageDataset, self).__init__() + gt_dir = osp.join(root, 'gt') + lq_dir = osp.join(root, 'lq') + self.gt_filelist = os.listdir(gt_dir) + self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4])) + self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist] + self.lq_filelist = os.listdir(lq_dir) + self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4])) + self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist] + + def _img_to_tensor(self, img): + return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type( + torch.float32) / 255. + + def __getitem__(self, index): + lq = cv2.imread(self.lq_filelist[index]) + gt = cv2.imread(self.gt_filelist[index]) + lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) + gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) + return \ + {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} + + def __len__(self): + return len(self.gt_filelist) + + def to_torch_dataset(self, + columns: Union[str, List[str]] = None, + preprocessors: Union[Callable, List[Callable]] = None, + **format_kwargs): + return self + + +class TestImageColorEnhanceTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_csrnet_image-color-enhance-models' + + self.dataset = PairedImageDataset( + './data/test/images/image_color_enhance/') + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(3): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + cache_path = snapshot_download(self.model_id) + model = ImageColorEnhance.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + max_epochs=2, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_denoise_trainer.py b/tests/trainers/test_image_denoise_trainer.py new file mode 100644 index 00000000..b742dcae --- /dev/null +++ b/tests/trainers/test_image_denoise_trainer.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_denoise import NAFNetForImageDenoise +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets.sidd_image_denoising import \ + SiddImageDenoisingDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageDenoiseTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_nafnet_image-denoise_sidd' + self.cache_path = snapshot_download(self.model_id) + self.config = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + dataset_train = MsDataset.load( + 'SIDD', + namespace='huizheng', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + dataset_val = MsDataset.load( + 'SIDD', + namespace='huizheng', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + self.dataset_train = SiddImageDenoisingDataset( + dataset_train, self.config.dataset, is_train=True) + self.dataset_val = SiddImageDenoisingDataset( + dataset_val, self.config.dataset, is_train=False) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + model = NAFNetForImageDenoise.from_pretrained(self.cache_path) + kwargs = dict( + cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + max_epochs=2, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_inpainting_trainer.py b/tests/trainers/test_image_inpainting_trainer.py new file mode 100644 index 00000000..807fe64f --- /dev/null +++ b/tests/trainers/test_image_inpainting_trainer.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.image_inpainting import FFTInpainting +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageInpaintingTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_fft_inpainting_lama' + self.cache_path = snapshot_download(self.model_id) + cfg = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + + train_data_cfg = ConfigDict( + name='PlacesToydataset', + split='train', + mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, + out_size=cfg.dataset.train_out_size, + test_mode=False) + + test_data_cfg = ConfigDict( + name='PlacesToydataset', + split='test', + mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, + out_size=cfg.dataset.val_out_size, + test_mode=True) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + mask_gen_kwargs=train_data_cfg.mask_gen_kwargs, + out_size=train_data_cfg.out_size, + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + mask_gen_kwargs=test_data_cfg.mask_gen_kwargs, + out_size=test_data_cfg.out_size, + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset) + + trainer = build_trainer( + name=Trainers.image_inpainting, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_instance_segmentation_trainer.py b/tests/trainers/test_image_instance_segmentation_trainer.py new file mode 100644 index 00000000..03f7eea3 --- /dev/null +++ b/tests/trainers/test_image_instance_segmentation_trainer.py @@ -0,0 +1,124 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile +from functools import partial + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.image_instance_segmentation import \ + CascadeMaskRCNNSwinModel +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets import \ + ImageInstanceSegmentationCocoDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_swin-b_image-instance-segmentation_coco' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu + try: + train_data_cfg = cfg.dataset.train + val_data_cfg = cfg.dataset.val + except Exception: + train_data_cfg = None + val_data_cfg = None + if train_data_cfg is None: + # use default toy data + train_data_cfg = ConfigDict( + name='pets_small', split='train', test_mode=False) + if val_data_cfg is None: + val_data_cfg = ConfigDict( + name='pets_small', split='validation', test_mode=True) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + test_mode=train_data_cfg.test_mode, + download_mode=DownloadMode.FORCE_REDOWNLOAD) + assert self.train_dataset.config_kwargs['classes'] + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.eval_dataset = MsDataset.load( + dataset_name=val_data_cfg.name, + split=val_data_cfg.split, + test_mode=val_data_cfg.test_mode, + download_mode=DownloadMode.FORCE_REDOWNLOAD) + assert self.eval_dataset.config_kwargs['classes'] + assert next( + iter(self.eval_dataset.config_kwargs['split_config'].values())) + + from mmcv.parallel import collate + + self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu) + + self.max_epochs = max_epochs + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + data_collator=self.collate_fn, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.image_instance_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + data_collator=self.collate_fn, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.image_instance_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py new file mode 100644 index 00000000..123e0098 --- /dev/null +++ b/tests/trainers/test_image_portrait_enhancement_trainer.py @@ -0,0 +1,93 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import shutil +import tempfile +import unittest +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import torch +from torch.utils import data as data + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.image_portrait_enhancement import \ + ImagePortraitEnhancement +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \ + ImagePortraitEnhancementDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImagePortraitEnhancementTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_gpen_image-portrait-enhancement' + + dataset_train = MsDataset.load( + 'image-portrait-enhancement-dataset', + namespace='modelscope', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + dataset_val = MsDataset.load( + 'image-portrait-enhancement-dataset', + namespace='modelscope', + subset_name='default', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds + + self.dataset_train = ImagePortraitEnhancementDataset( + dataset_train, is_train=True) + self.dataset_val = ImagePortraitEnhancementDataset( + dataset_val, is_train=False) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + device='gpu', + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.image_portrait_enhancement, default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = ImagePortraitEnhancement.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + device='gpu', + max_epochs=2, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.image_portrait_enhancement, default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_movie_scene_segmentation_trainer.py b/tests/trainers/test_movie_scene_segmentation_trainer.py new file mode 100644 index 00000000..f25dc92a --- /dev/null +++ b/tests/trainers/test_movie_scene_segmentation_trainer.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.movie_scene_segmentation import \ + MovieSceneSegmentationModel +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + + train_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='train', + cfg=cfg.preprocessor, + test_mode=False) + + test_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='test', + cfg=cfg.preprocessor, + test_mode=True) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + namespace=train_data_cfg.namespace, + cfg=train_data_cfg.cfg, + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + namespace=test_data_cfg.namespace, + cfg=test_data_cfg.cfg, + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + self.max_epochs = max_epochs + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = MovieSceneSegmentationModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py new file mode 100644 index 00000000..0516e569 --- /dev/null +++ b/tests/trainers/test_ofa_trainer.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import unittest + +import json + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.hub import read_config +from modelscope.utils.test_utils import test_level + + +class TestOfaTrainer(unittest.TestCase): + + def setUp(self) -> None: + self.finetune_cfg = \ + {'framework': 'pytorch', + 'task': 'ocr-recognition', + 'model': {'type': 'ofa', + 'beam_search': {'beam_size': 5, + 'max_len_b': 64, + 'min_len': 1, + 'no_repeat_ngram_size': 0}, + 'seed': 7, + 'max_src_length': 128, + 'language': 'zh', + 'gen_type': 'generation', + 'patch_image_size': 480, + 'is_document': False, + 'max_image_size': 480, + 'imagenet_default_mean_and_std': False}, + 'pipeline': {'type': 'ofa-ocr-recognition'}, + 'dataset': {'column_map': {'text': 'label'}}, + 'train': {'work_dir': 'work/ckpts/recognition', + # 'launcher': 'pytorch', + 'max_epochs': 1, + 'use_fp16': True, + 'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, + 'lr_scheduler': {'name': 'polynomial_decay', + 'warmup_proportion': 0.01, + 'lr_end': 1e-07}, + 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, + 'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01}, + 'optimizer_hook': {'type': 'TorchAMPOptimizerHook', + 'cumulative_iters': 1, + 'grad_clip': {'max_norm': 1.0, 'norm_type': 2}, + 'loss_keys': 'loss'}, + 'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion', + 'constraint_range': None, + 'drop_worst_after': 0, + 'drop_worst_ratio': 0.0, + 'ignore_eos': False, + 'ignore_prefix_size': 0, + 'label_smoothing': 0.1, + 'reg_alpha': 1.0, + 'report_accuracy': False, + 'sample_patch_num': 196, + 'sentence_avg': False, + 'use_rdrop': True}, + 'hooks': [{'type': 'BestCkptSaverHook', + 'metric_key': 'accuracy', + 'interval': 100}, + {'type': 'TextLoggerHook', 'interval': 1}, + {'type': 'IterTimerHook'}, + {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]}, + 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, + 'metrics': [{'type': 'accuracy'}]}, + 'preprocessor': []} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_std(self): + WORKSPACE = './workspace/ckpts/recognition' + os.makedirs(WORKSPACE, exist_ok=True) + config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) + with open(config_file, 'w') as writer: + json.dump(self.finetune_cfg, writer) + + pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' + + args = dict( + model=pretrained_model, + work_dir=WORKSPACE, + train_dataset=MsDataset.load( + 'ocr_fudanvi_zh', + subset_name='scene', + namespace='modelscope', + split='train[800:900]', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), + eval_dataset=MsDataset.load( + 'ocr_fudanvi_zh', + subset_name='scene', + namespace='modelscope', + split='test[:20]', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), + cfg_file=config_file) + trainer = build_trainer(name=Trainers.ofa, default_args=args) + trainer.train() + + self.assertIn( + ModelFile.TORCH_MODEL_BIN_FILE, + os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) + shutil.rmtree(WORKSPACE) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_referring_video_object_segmentation_trainer.py b/tests/trainers/test_referring_video_object_segmentation_trainer.py new file mode 100644 index 00000000..7b03eb4d --- /dev/null +++ b/tests/trainers/test_referring_video_object_segmentation_trainer.py @@ -0,0 +1,101 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.movie_scene_segmentation import \ + MovieSceneSegmentationModel +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageInstanceSegmentationTrainer(unittest.TestCase): + + model_id = 'damo/cv_swin-t_referring_video-object-segmentation' + dataset_name = 'referring_vos_toydata' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + + train_data_cfg = ConfigDict( + name=self.dataset_name, + split='train', + test_mode=False, + cfg=cfg.dataset) + + test_data_cfg = ConfigDict( + name=self.dataset_name, + split='test', + test_mode=True, + cfg=cfg.dataset) + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + cfg=train_data_cfg.cfg, + namespace='damo', + test_mode=train_data_cfg.test_mode) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + cfg=test_data_cfg.cfg, + namespace='damo', + test_mode=test_data_cfg.test_mode) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + self.max_epochs = max_epochs + + @unittest.skip('skip since the model is set to private for now') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir='./work_dir') + + trainer = build_trainer( + name=Trainers.referring_video_object_segmentation, + default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skip('skip since the model is set to private for now') + def test_trainer_with_model_and_args(self): + + cache_path = snapshot_download(self.model_id) + model = MovieSceneSegmentationModel.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir='./work_dir') + + trainer = build_trainer( + name=Trainers.referring_video_object_segmentation, + default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_team_transfer_trainer.py b/tests/trainers/test_team_transfer_trainer.py new file mode 100644 index 00000000..0f6b88bb --- /dev/null +++ b/tests/trainers/test_team_transfer_trainer.py @@ -0,0 +1,94 @@ +import os +import unittest + +import json +import requests +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.trainers.multi_modal.team.team_trainer_utils import ( + collate_fn, train_mapping, val_mapping) +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +def train_worker(device_id): + model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity' + ckpt_dir = './ckpt' + os.makedirs(ckpt_dir, exist_ok=True) + # Use epoch=1 for faster training here + cfg = Config({ + 'framework': 'pytorch', + 'task': 'multi-modal-similarity', + 'pipeline': { + 'type': 'multi-modal-similarity' + }, + 'model': { + 'type': 'team-multi-modal-similarity' + }, + 'dataset': { + 'name': 'Caltech101', + 'class_num': 101 + }, + 'preprocessor': {}, + 'train': { + 'epoch': 1, + 'batch_size': 32, + 'ckpt_dir': ckpt_dir + }, + 'evaluation': { + 'batch_size': 64 + } + }) + cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION) + cfg.dump(cfg_file) + + train_dataset = MsDataset.load( + cfg.dataset.name, + namespace='modelscope', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() + train_dataset = train_dataset.with_transform(train_mapping) + val_dataset = MsDataset.load( + cfg.dataset.name, + namespace='modelscope', + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset() + val_dataset = val_dataset.with_transform(val_mapping) + + default_args = dict( + cfg_file=cfg_file, + model=model_id, + device_id=device_id, + data_collator=collate_fn, + train_dataset=train_dataset, + val_dataset=val_dataset) + + trainer = build_trainer( + name=Trainers.image_classification_team, default_args=default_args) + trainer.train() + trainer.evaluate() + + +class TEAMTransferTrainerTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + if torch.cuda.device_count() > 0: + train_worker(device_id=0) + else: + train_worker(device_id=-1) + logger.info('Training done') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py new file mode 100644 index 00000000..c73a56a3 --- /dev/null +++ b/tests/trainers/test_trainer.py @@ -0,0 +1,454 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import IterableDataset + +from modelscope.metainfo import Metrics, Trainers +from modelscope.metrics.builder import MetricKeys +from modelscope.models.base import Model +from modelscope.trainers import build_trainer +from modelscope.trainers.base import DummyTrainer +from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks +from modelscope.utils.test_utils import create_dummy_test_dataset, test_level + + +class DummyIterableDataset(IterableDataset): + + def __iter__(self): + feat = np.random.random(size=(5, )).astype(np.float32) + labels = np.random.randint(0, 4, (1, )) + iterations = [{'feat': feat, 'labels': labels}] * 500 + return iter(iterations) + + +dummy_dataset_small = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + +dummy_dataset_big = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class TrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_train_0(self): + json_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01, + 'options': { + 'grad_clip': { + 'max_norm': 2.0 + } + } + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + 'options': { + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_iters': 2 + } + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + train_dataset=dummy_dataset_small, + eval_dataset=dummy_dataset_small, + max_epochs=3, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_train_1(self): + json_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimmizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimmizer, 2) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + train_dataset=dummy_dataset_small, + eval_dataset=dummy_dataset_small, + optimizers=(optimmizer, lr_scheduler), + max_epochs=3, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_train_with_default_config(self): + json_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimmizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimmizer, 2) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + train_dataset=dummy_dataset_big, + eval_dataset=dummy_dataset_small, + optimizers=(optimmizer, lr_scheduler), + max_epochs=3, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + + json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json') + with open(json_file, 'r') as f: + lines = [i.strip() for i in f.readlines()] + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[3])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[4])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10 + }, json.loads(lines[5])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10, + LogKeys.LR: 0.001 + }, json.loads(lines[6])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 20, + LogKeys.LR: 0.001 + }, json.loads(lines[7])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10 + }, json.loads(lines[8])) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + for i in [0, 1, 3, 4, 6, 7]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + for i in [2, 5, 8]: + self.assertIn(MetricKeys.ACCURACY, lines[i]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_train_with_iters_per_epoch(self): + json_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + model = DummyModel() + optimmizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimmizer, 2) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + optimizers=(optimmizer, lr_scheduler), + train_dataset=DummyIterableDataset(), + eval_dataset=DummyIterableDataset(), + train_iters_per_epoch=20, + val_iters_per_epoch=10, + max_epochs=3, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json') + with open(json_file, 'r') as f: + lines = [i.strip() for i in f.readlines()] + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[3])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[4])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10 + }, json.loads(lines[5])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10, + LogKeys.LR: 0.001 + }, json.loads(lines[6])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 20, + LogKeys.LR: 0.001 + }, json.loads(lines[7])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10 + }, json.loads(lines[8])) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + for i in [0, 1, 3, 4, 6, 7]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + for i in [2, 5, 8]: + self.assertIn(MetricKeys.ACCURACY, lines[i]) + + +class DummyTrainerTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_dummy(self): + default_args = dict(cfg_file='configs/examples/train.json') + trainer = build_trainer('dummy', default_args) + + trainer.train() + trainer.evaluate() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py new file mode 100644 index 00000000..0176704a --- /dev/null +++ b/tests/trainers/test_trainer_gpu.py @@ -0,0 +1,330 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np +import torch +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import IterableDataset + +from modelscope.metainfo import Metrics, Trainers +from modelscope.metrics.builder import MetricKeys +from modelscope.models.base import Model +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks +from modelscope.utils.test_utils import (DistributedTestCase, + create_dummy_test_dataset, test_level) + + +class DummyIterableDataset(IterableDataset): + + def __iter__(self): + feat = np.random.random(size=(5, )).astype(np.float32) + labels = np.random.randint(0, 4, (1, )) + iterations = [{'feat': feat, 'labels': labels}] * 500 + return iter(iterations) + + +dummy_dataset_small = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + +dummy_dataset_big = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class DummyModelForwardInputs(DummyModel): + + def forward(self, inputs): + feat, labels = inputs['feat'], inputs['labels'] + return super().forward(feat, labels) + + +def train_func(work_dir, + dist=False, + iterable_dataset=False, + forward_inputs=False, + **kwargs): + json_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': work_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 1, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + + config_path = os.path.join(work_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + if forward_inputs: + model = DummyModelForwardInputs() + else: + model = DummyModel() + optimmizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimmizer, 2) + trainer_name = Trainers.default + if iterable_dataset: + train_dataset = DummyIterableDataset() + eval_dataset = DummyIterableDataset() + else: + train_dataset = dummy_dataset_big + eval_dataset = dummy_dataset_small + _kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizers=(optimmizer, lr_scheduler), + max_epochs=3, + device='gpu', + launcher='pytorch' if dist else None, + **kwargs) + + trainer = build_trainer(trainer_name, _kwargs) + trainer.train() + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class TrainerTestSingleGpu(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_single_gpu(self): + train_func(self.tmp_dir) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + + with open(json_files[0], 'r') as f: + lines = [i.strip() for i in f.readlines()] + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 20 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[3])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 20, + LogKeys.LR: 0.01 + }, json.loads(lines[4])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 20 + }, json.loads(lines[5])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10, + LogKeys.LR: 0.001 + }, json.loads(lines[6])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 20, + LogKeys.LR: 0.001 + }, json.loads(lines[7])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 3, + LogKeys.ITER: 20 + }, json.loads(lines[8])) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + for i in [0, 1, 3, 4, 6, 7]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + for i in [2, 5, 8]: + self.assertIn(MetricKeys.ACCURACY, lines[i]) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class TrainerTestMultiGpus(DistributedTestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus(self): + self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + + with open(json_files[0], 'r') as f: + lines = [i.strip() for i in f.readlines()] + + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[0])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 1, + LogKeys.ITER: 10 + }, json.loads(lines[1])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10, + LogKeys.LR: 0.01 + }, json.loads(lines[2])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 2, + LogKeys.ITER: 10 + }, json.loads(lines[3])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.TRAIN, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10, + LogKeys.LR: 0.001 + }, json.loads(lines[4])) + self.assertDictContainsSubset( + { + LogKeys.MODE: ModeKeys.EVAL, + LogKeys.EPOCH: 3, + LogKeys.ITER: 10 + }, json.loads(lines[5])) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + for i in [0, 2, 4]: + self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) + self.assertIn(LogKeys.ITER_TIME, lines[i]) + for i in [1, 3, 5]: + self.assertIn(MetricKeys.ACCURACY, lines[i]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpus_forward_inputs(self): + self.start( + train_func, + num_gpus=2, + work_dir=self.tmp_dir, + dist=True, + forward_inputs=True) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) + + # TODO: support iters_per_epoch for dist mode + @unittest.skipIf(True, 'need to adapt to DistributedSampler') + def test_multi_gpus_with_iters_per_epoch(self): + self.start( + train_func, + num_gpus=2, + work_dir=self.tmp_dir, + dist=True, + iterable_dataset=True, + train_iters_per_epoch=20, + val_iters_per_epoch=10, + ) + + results_files = os.listdir(self.tmp_dir) + json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + + with open(json_files[0], 'r') as f: + lines = [i.strip() for i in f.readlines()] + + print(results_files, lines) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py new file mode 100644 index 00000000..f1d9e414 --- /dev/null +++ b/tests/trainers/test_trainer_with_nlp.py @@ -0,0 +1,270 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Metrics +from modelscope.models.base import Model +from modelscope.models.nlp import SbertForSequenceClassification +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.hub import read_config +from modelscope.utils.test_utils import test_level + + +class TestTrainerWithNlp(unittest.TestCase): + sentence1 = '今天气温比昨天高么?' + sentence2 = '今天湿度比昨天高么?' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.dataset = MsDataset.load( + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + kwargs = dict( + model=model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(10): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(trainer.model_dir) + + print(f'copy_src_files are {copy_src_files}') + print(f'output_files are {output_files}') + for item in copy_src_files: + if not item.startswith('.'): + self.assertIn(item, output_files) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + pipeline_sentence_similarity(output_dir) + + @unittest.skip + def test_trainer_with_backbone_head(self): + model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' + kwargs = dict( + model=model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(10): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + eval_results = trainer.evaluate( + checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) + self.assertTrue(Metrics.accuracy in eval_results) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_user_defined_config(self): + model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' + cfg = read_config(model_id) + cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} + cfg.train.work_dir = self.tmp_dir + cfg_file = os.path.join(self.tmp_dir, 'config.json') + cfg.dump(cfg_file) + kwargs = dict( + model=model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + cfg_file=cfg_file) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(20): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + eval_results = trainer.evaluate( + checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) + self.assertTrue(Metrics.accuracy in eval_results) + + @unittest.skip('skip for now before test is re-configured') + def test_trainer_with_configured_datasets(self): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + cfg: Config = read_config(model_id) + cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} + cfg.train.work_dir = self.tmp_dir + cfg.dataset = { + 'train': { + 'name': 'clue', + 'subset_name': 'afqmc', + 'split': 'train', + }, + 'val': { + 'name': 'clue', + 'subset_name': 'afqmc', + 'split': 'train', + }, + } + cfg_file = os.path.join(self.tmp_dir, 'config.json') + cfg.dump(cfg_file) + kwargs = dict(model=model_id, cfg_file=cfg_file) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(cfg.train.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + eval_results = trainer.evaluate( + checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) + self.assertTrue(Metrics.accuracy in eval_results) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_continue_train(self): + from modelscope.utils.regress_test_utils import MsRegressTool + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' + cfg: Config = read_config(model_id) + cfg.train.max_epochs = 3 + cfg.preprocessor.first_sequence = 'sentence1' + cfg.preprocessor.second_sequence = 'sentence2' + cfg.preprocessor.label = 'label' + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} + cfg.train.dataloader.batch_size_per_gpu = 2 + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 3, + 'by_epoch': False, + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'interval': 1 + }] + cfg.train.work_dir = self.tmp_dir + cfg_file = os.path.join(self.tmp_dir, 'config.json') + cfg.dump(cfg_file) + dataset = MsDataset.load('clue', subset_name='afqmc', split='train') + dataset = dataset.to_hf_dataset().select(range(4)) + kwargs = dict( + model=model_id, + train_dataset=dataset, + eval_dataset=dataset, + cfg_file=cfg_file) + + regress_tool = MsRegressTool(baseline=True) + trainer: EpochBasedTrainer = build_trainer(default_args=kwargs) + + def lazy_stop_callback(): + from modelscope.trainers.hooks.hook import Hook, Priority + + class EarlyStopHook(Hook): + PRIORITY = Priority.VERY_LOW + + def after_iter(self, trainer): + if trainer.iter == 3: + raise MsRegressTool.EarlyStopError('Test finished.') + + if 'EarlyStopHook' not in [ + hook.__class__.__name__ for hook in trainer.hooks + ]: + trainer.register_hook(EarlyStopHook()) + + with regress_tool.monitor_ms_train( + trainer, + 'trainer_continue_train', + level='strict', + lazy_stop_callback=lazy_stop_callback): + trainer.train() + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + trainer = build_trainer(default_args=kwargs) + regress_tool = MsRegressTool(baseline=False) + with regress_tool.monitor_ms_train( + trainer, 'trainer_continue_train', level='strict'): + trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_evaluation(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + cache_path = snapshot_download(model_id) + model = SbertForSequenceClassification.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + print(trainer.evaluate(cache_path + '/pytorch_model.bin')) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + cache_path = snapshot_download(model_id) + model = SbertForSequenceClassification.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + max_epochs=2, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_translation_trainer.py b/tests/trainers/test_translation_trainer.py new file mode 100644 index 00000000..71bed241 --- /dev/null +++ b/tests/trainers/test_translation_trainer.py @@ -0,0 +1,18 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.trainers.nlp import CsanmtTranslationTrainer +from modelscope.utils.test_utils import test_level + + +class TranslationTest(unittest.TestCase): + model_id = 'damo/nlp_csanmt_translation_zh2en' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + trainer = CsanmtTranslationTrainer(model=self.model_id) + trainer.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_video_summarization_trainer.py b/tests/trainers/test_video_summarization_trainer.py new file mode 100644 index 00000000..1cea1eea --- /dev/null +++ b/tests/trainers/test_video_summarization_trainer.py @@ -0,0 +1,75 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.video_summarization import PGLVideoSummarization +from modelscope.msdatasets.task_datasets import VideoSummarizationDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class VideoSummarizationTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_googlenet_pgl-video-summarization' + self.cache_path = snapshot_download(self.model_id) + self.config = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + self.dataset_train = VideoSummarizationDataset('train', + self.config.dataset, + self.cache_path) + self.dataset_val = VideoSummarizationDataset('test', + self.config.dataset, + self.cache_path) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + model = PGLVideoSummarization.from_pretrained(self.cache_path) + kwargs = dict( + cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + max_epochs=2, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/utils/__init__.py b/tests/trainers/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/utils/test_inference.py b/tests/trainers/utils/test_inference.py new file mode 100644 index 00000000..37e202e3 --- /dev/null +++ b/tests/trainers/utils/test_inference.py @@ -0,0 +1,126 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +import torch +from torch import nn +from torch.utils.data import DataLoader + +from modelscope.metrics.builder import MetricKeys +from modelscope.metrics.sequence_classification_metric import \ + SequenceClassificationMetric +from modelscope.models.base import Model +from modelscope.trainers import EpochBasedTrainer +from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test +from modelscope.utils.test_utils import (DistributedTestCase, + create_dummy_test_dataset, test_level) +from modelscope.utils.torch_utils import get_dist_info, init_dist + +dummy_dataset = create_dummy_test_dataset( + torch.rand((5, )), torch.randint(0, 4, (1, )), 20) + + +class DummyModel(nn.Module, Model): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + self.bn = nn.BatchNorm1d(4) + + def forward(self, feat, labels): + x = self.linear(feat) + + x = self.bn(x) + loss = torch.sum(x) + return dict(logits=x, loss=loss) + + +class DummyTrainer(EpochBasedTrainer): + + def __init__(self, model): + self.model = model + + +def test_func(dist=False): + dummy_model = DummyModel() + dataset = dummy_dataset.to_torch_dataset() + + dummy_loader = DataLoader( + dataset, + batch_size=2, + ) + + metric_class = SequenceClassificationMetric() + + if dist: + init_dist(launcher='pytorch') + + rank, world_size = get_dist_info() + device = torch.device(f'cuda:{rank}') + dummy_model.cuda() + + if world_size > 1: + from torch.nn.parallel.distributed import DistributedDataParallel + dummy_model = DistributedDataParallel( + dummy_model, device_ids=[torch.cuda.current_device()]) + test_func = multi_gpu_test + else: + test_func = single_gpu_test + + dummy_trainer = DummyTrainer(dummy_model) + + metric_results = test_func( + dummy_trainer, + dummy_loader, + device=device, + metric_classes=[metric_class]) + + return metric_results + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class SingleGpuTestTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_single_gpu_test(self): + metric_results = test_func() + self.assertIn(MetricKeys.ACCURACY, metric_results) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, 'distributed unittest') +class MultiGpuTestTest(DistributedTestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_multi_gpu_test(self): + self.start( + test_func, + num_gpus=2, + assert_callback=lambda x: self.assertIn(MetricKeys.ACCURACY, x), + dist=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..f1a50035 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .profiler import * # noqa F403 diff --git a/tests/utils/profiler.py b/tests/utils/profiler.py new file mode 100644 index 00000000..f5a522ef --- /dev/null +++ b/tests/utils/profiler.py @@ -0,0 +1,61 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import importlib +import sys +from functools import wraps +from typing import Any, Callable, Dict, Tuple, Type + + +def reraise(tp, value, tb): + try: + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None + tb = None + + +class Profiler: + + def __init__(self) -> None: + import cProfile + self.pr = cProfile.Profile() + + def __enter__(self): + self.pr.enable() + + def __exit__(self, tp, exc, tb): + self.pr.disable() + if tp is not None: + reraise(tp, exc, tb) + + import pstats + ps = pstats.Stats(self.pr, stream=sys.stderr).sort_stats('tottime') + ps.print_stats(20) + + +def wrapper(tp: Type[Profiler]) -> Callable[[], Callable[..., Any]]: + + def _inner(func: Callable[..., Any]) -> Callable[..., Any]: + + @wraps(func) + def executor(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: + with tp(): + return func(*args, **kwargs) + + return executor + + return _inner + + +PIPELINE_BASE_MODULE = 'modelscope.pipelines.base' +PIPELINE_BASE_CLASS = 'Pipeline' + + +def enable(): + base = importlib.import_module(PIPELINE_BASE_MODULE) + Pipeline = getattr(base, PIPELINE_BASE_CLASS) + Pipeline.__call__ = wrapper(Profiler)(Pipeline.__call__) diff --git a/tests/utils/test_ast.py b/tests/utils/test_ast.py new file mode 100644 index 00000000..0243053e --- /dev/null +++ b/tests/utils/test_ast.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import time +import unittest +from pathlib import Path + +from modelscope.utils.ast_utils import AstScaning, FilesAstScaning, load_index + +p = Path(__file__) + +MODELSCOPE_PATH = p.resolve().parents[2].joinpath('modelscope') + + +class AstScaningTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + self.test_file = os.path.join(self.tmp_dir, 'test.py') + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def test_ast_scaning_class(self): + astScaner = AstScaning() + pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', + 'text_generation_pipeline.py') + output = astScaner.generate_ast(pipeline_file) + self.assertTrue(output['imports'] is not None) + self.assertTrue(output['from_imports'] is not None) + self.assertTrue(output['decorators'] is not None) + imports, from_imports, decorators = output['imports'], output[ + 'from_imports'], output['decorators'] + self.assertIsInstance(imports, dict) + self.assertIsInstance(from_imports, dict) + self.assertIsInstance(decorators, list) + self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) + self.assertEqual(len(from_imports.keys()), 10) + self.assertTrue(from_imports['modelscope.metainfo'] is not None) + self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) + self.assertEqual(decorators, + [('PIPELINES', 'text-generation', 'text-generation')]) + + def test_files_scaning_method(self): + fileScaner = FilesAstScaning() + output = fileScaner.get_files_scan_results() + self.assertTrue(output['index'] is not None) + self.assertTrue(output['requirements'] is not None) + index, requirements = output['index'], output['requirements'] + self.assertIsInstance(index, dict) + self.assertIsInstance(requirements, dict) + self.assertIsInstance(list(index.keys())[0], tuple) + index_0 = list(index.keys())[0] + self.assertIsInstance(index[index_0], dict) + self.assertTrue(index[index_0]['imports'] is not None) + self.assertIsInstance(index[index_0]['imports'], list) + self.assertTrue(index[index_0]['module'] is not None) + self.assertIsInstance(index[index_0]['module'], str) + index_0 = list(requirements.keys())[0] + self.assertIsInstance(requirements[index_0], list) + + def test_file_mtime_md5_method(self): + fileScaner = FilesAstScaning() + # create first file + with open(self.test_file, 'w', encoding='utf-8') as f: + f.write('This is the new test!') + + md5_1 = fileScaner.files_mtime_md5(self.tmp_dir, []) + md5_2 = fileScaner.files_mtime_md5(self.tmp_dir, []) + self.assertEqual(md5_1, md5_2) + time.sleep(2) + # case of revise + with open(self.test_file, 'w', encoding='utf-8') as f: + f.write('test again') + md5_3 = fileScaner.files_mtime_md5(self.tmp_dir, []) + self.assertNotEqual(md5_1, md5_3) + + # case of create + self.test_file_new = os.path.join(self.tmp_dir, 'test_1.py') + time.sleep(2) + with open(self.test_file_new, 'w', encoding='utf-8') as f: + f.write('test again') + md5_4 = fileScaner.files_mtime_md5(self.tmp_dir, []) + self.assertNotEqual(md5_1, md5_4) + self.assertNotEqual(md5_3, md5_4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_compatibility.py b/tests/utils/test_compatibility.py new file mode 100644 index 00000000..f5222261 --- /dev/null +++ b/tests/utils/test_compatibility.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + + +class CompatibilityTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def tearDown(self): + super().tearDown() + + def test_xtcocotools(self): + from xtcocotools.coco import COCO + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py new file mode 100644 index 00000000..8b89fa68 --- /dev/null +++ b/tests/utils/test_config.py @@ -0,0 +1,232 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import copy +import tempfile +import unittest + +import json + +from modelscope.utils.config import Config, check_config + +obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} + + +class ConfigTest(unittest.TestCase): + + def test_json(self): + config_file = 'configs/examples/configuration.json' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_yaml(self): + config_file = 'configs/examples/configuration.yaml' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_py(self): + config_file = 'configs/examples/configuration.py' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + + def test_dump(self): + config_file = 'configs/examples/configuration.py' + cfg = Config.from_file(config_file) + self.assertEqual(cfg.a, 1) + self.assertEqual(cfg.b, obj['b']) + pretty_text = 'a = 1\n' + pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n" + + json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}' + yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n' + with tempfile.NamedTemporaryFile(suffix='.json') as ofile: + self.assertEqual(pretty_text, cfg.dump()) + cfg.dump(ofile.name) + with open(ofile.name, 'r') as infile: + self.assertDictEqual( + json.loads(json_str), json.loads(infile.read())) + + with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile: + cfg.dump(ofile.name) + with open(ofile.name, 'r') as infile: + self.assertEqual(yaml_str, infile.read()) + + def test_to_dict(self): + config_file = 'configs/examples/configuration.json' + cfg = Config.from_file(config_file) + d = cfg.to_dict() + print(d) + self.assertTrue(isinstance(d, dict)) + + def test_to_args(self): + + def parse_fn(args): + parser = argparse.ArgumentParser(prog='PROG') + parser.add_argument('--model-dir', default='') + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--optimizer', default='') + parser.add_argument('--weight-decay', type=float, default=1e-7) + parser.add_argument( + '--save-checkpoint-epochs', type=int, default=30) + return parser.parse_args(args) + + cfg = Config.from_file('configs/examples/plain_args.yaml') + args = cfg.to_args(parse_fn) + + self.assertEqual(args.model_dir, 'path/to/model') + self.assertAlmostEqual(args.lr, 0.01) + self.assertAlmostEqual(args.weight_decay, 1e-6) + self.assertEqual(args.optimizer, 'Adam') + self.assertEqual(args.save_checkpoint_epochs, 20) + + def test_check_config(self): + check_config('configs/cv/configuration.json') + check_config('configs/nlp/sbert_sentence_similarity.json') + + def test_merge_from_dict(self): + base_cfg = copy.deepcopy(obj) + base_cfg.update({'dict_list': [dict(l1=1), dict(l2=2)]}) + + cfg = Config(base_cfg) + + merge_dict = { + 'a': 2, + 'b.d': 'ee', + 'b.c': [3, 3, 3], + 'dict_list': { + '0': dict(l1=3) + }, + 'c': 'test' + } + + cfg1 = copy.deepcopy(cfg) + cfg1.merge_from_dict(merge_dict) + self.assertDictEqual( + cfg1._cfg_dict, { + 'a': 2, + 'b': { + 'c': [3, 3, 3], + 'd': 'ee' + }, + 'dict_list': [dict(l1=3), dict(l2=2)], + 'c': 'test' + }) + + cfg2 = copy.deepcopy(cfg) + cfg2.merge_from_dict(merge_dict, force=False) + self.assertDictEqual( + cfg2._cfg_dict, { + 'a': 1, + 'b': { + 'c': [1, 2, 3], + 'd': 'dd' + }, + 'dict_list': [dict(l1=1), dict(l2=2)], + 'c': 'test' + }) + + def test_merge_from_dict_with_list(self): + base_cfg = { + 'a': + 1, + 'b': { + 'c': [1, 2, 3], + 'd': 'dd' + }, + 'dict_list': [dict(type='l1', v=1), + dict(type='l2', v=2)], + 'dict_list2': [ + dict( + type='l1', + v=[dict(type='l1_1', v=1), + dict(type='l1_2', v=2)]), + dict(type='l2', v=2) + ] + } + cfg = Config(base_cfg) + + merge_dict_for_list = { + 'a': + 2, + 'b.c': [3, 3, 3], + 'b.d': + 'ee', + 'dict_list': [dict(type='l1', v=8), + dict(type='l3', v=8)], + 'dict_list2': [ + dict( + type='l1', + v=[ + dict(type='l1_1', v=8), + dict(type='l1_2', v=2), + dict(type='l1_3', v=8), + ]), + dict(type='l2', v=8) + ], + 'c': + 'test' + } + + cfg1 = copy.deepcopy(cfg) + cfg1.merge_from_dict(merge_dict_for_list, force=False) + self.assertDictEqual( + cfg1._cfg_dict, { + 'a': + 1, + 'b': { + 'c': [1, 2, 3], + 'd': 'dd' + }, + 'dict_list': [ + dict(type='l1', v=1), + dict(type='l2', v=2), + dict(type='l3', v=8) + ], + 'dict_list2': [ + dict( + type='l1', + v=[ + dict(type='l1_1', v=1), + dict(type='l1_2', v=2), + dict(type='l1_3', v=8), + ]), + dict(type='l2', v=2) + ], + 'c': + 'test' + }) + + cfg2 = copy.deepcopy(cfg) + cfg2.merge_from_dict(merge_dict_for_list, force=True) + self.assertDictEqual( + cfg2._cfg_dict, { + 'a': + 2, + 'b': { + 'c': [3, 3, 3], + 'd': 'ee' + }, + 'dict_list': [ + dict(type='l1', v=8), + dict(type='l2', v=2), + dict(type='l3', v=8) + ], + 'dict_list2': [ + dict( + type='l1', + v=[ + dict(type='l1_1', v=8), + dict(type='l1_2', v=2), + dict(type='l1_3', v=8), + ]), + dict(type='l2', v=8) + ], + 'c': + 'test' + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py new file mode 100644 index 00000000..0d334fda --- /dev/null +++ b/tests/utils/test_device.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import tempfile +import time +import unittest + +import torch + +from modelscope.utils.constant import Frameworks +from modelscope.utils.device import (create_device, device_placement, + verify_device) + +# import tensorflow must be imported after torch is imported when using tf1.15 +import tensorflow as tf # isort:skip + + +class DeviceTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def tearDown(self): + super().tearDown() + + def test_verify(self): + device_name, device_id = verify_device('cpu') + self.assertEqual(device_name, 'cpu') + self.assertTrue(device_id is None) + device_name, device_id = verify_device('CPU') + self.assertEqual(device_name, 'cpu') + + device_name, device_id = verify_device('gpu') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('cuda') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('cuda:0') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 0) + + device_name, device_id = verify_device('gpu:1') + self.assertEqual(device_name, 'gpu') + self.assertTrue(device_id == 1) + + with self.assertRaises(AssertionError): + verify_device('xgu') + + with self.assertRaises(AssertionError): + verify_device('') + + with self.assertRaises(AssertionError): + verify_device(None) + + def test_create_device_torch(self): + if torch.cuda.is_available(): + target_device_type = 'cuda' + target_device_index = 0 + else: + target_device_type = 'cpu' + target_device_index = None + device = create_device('gpu') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + device = create_device('gpu:0') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + device = create_device('cuda') + self.assertTrue(device.type == target_device_type) + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.index == target_device_index) + + device = create_device('cuda:0') + self.assertTrue(isinstance(device, torch.device)) + self.assertTrue(device.type == target_device_type) + self.assertTrue(device.index == target_device_index) + + def test_device_placement_cpu(self): + with device_placement(Frameworks.torch, 'cpu'): + pass + + @unittest.skip('skip this test to avoid debug logging.') + def test_device_placement_tf_gpu(self): + tf.debugging.set_log_device_placement(True) + with device_placement(Frameworks.tf, 'gpu:0'): + a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + c = tf.matmul(a, b) + s = tf.Session() + s.run(c) + tf.debugging.set_log_device_placement(False) + + def test_device_placement_torch_gpu(self): + with device_placement(Frameworks.torch, 'gpu:0'): + if torch.cuda.is_available(): + self.assertEqual(torch.cuda.current_device(), 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_registry.py b/tests/utils/test_registry.py new file mode 100644 index 00000000..0a37101d --- /dev/null +++ b/tests/utils/test_registry.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.utils.constant import Tasks +from modelscope.utils.registry import Registry, build_from_cfg, default_group + + +class RegistryTest(unittest.TestCase): + + def test_register_class_no_task(self): + MODELS = Registry('models') + self.assertTrue(MODELS.name == 'models') + self.assertTrue(default_group in MODELS.modules) + self.assertTrue(MODELS.modules[default_group] == {}) + + self.assertEqual(len(MODELS.modules), 1) + + @MODELS.register_module(module_name='cls-resnet') + class ResNetForCls(object): + pass + + self.assertTrue(default_group in MODELS.modules) + self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls) + + def test_register_class_with_task(self): + MODELS = Registry('models') + + @MODELS.register_module(Tasks.image_classification, 'SwinT') + class SwinTForCls(object): + pass + + self.assertTrue(Tasks.image_classification in MODELS.modules) + self.assertTrue( + MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls) + + @MODELS.register_module(Tasks.sentiment_analysis, 'Bert') + class BertForSentimentAnalysis(object): + pass + + self.assertTrue(Tasks.sentiment_analysis in MODELS.modules) + self.assertTrue( + MODELS.get('Bert', Tasks.sentiment_analysis) is + BertForSentimentAnalysis) + + @MODELS.register_module(Tasks.image_object_detection) + class DETR(object): + pass + + self.assertTrue(Tasks.image_object_detection in MODELS.modules) + self.assertTrue( + MODELS.get('DETR', Tasks.image_object_detection) is DETR) + + self.assertEqual(len(MODELS.modules), 4) + + def test_list(self): + MODELS = Registry('models') + + @MODELS.register_module(Tasks.image_classification, 'SwinT') + class SwinTForCls(object): + pass + + @MODELS.register_module(Tasks.sentiment_analysis, 'Bert') + class BertForSentimentAnalysis(object): + pass + + MODELS.list() + print(MODELS) + + def test_build(self): + MODELS = Registry('models') + + @MODELS.register_module(Tasks.image_classification, 'SwinT') + class SwinTForCls(object): + pass + + @MODELS.register_module(Tasks.sentiment_analysis, 'Bert') + class BertForSentimentAnalysis(object): + pass + + cfg = dict(type='SwinT') + model = build_from_cfg(cfg, MODELS, Tasks.image_classification) + self.assertTrue(isinstance(model, SwinTForCls)) + + cfg = dict(type='Bert') + model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis) + self.assertTrue(isinstance(model, BertForSentimentAnalysis)) + + with self.assertRaises(KeyError): + cfg = dict(type='Bert') + model = build_from_cfg(cfg, MODELS, Tasks.image_classification) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_type_assert.py b/tests/utils/test_type_assert.py new file mode 100644 index 00000000..5b62a269 --- /dev/null +++ b/tests/utils/test_type_assert.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import List, Union + +from modelscope.utils.type_assert import type_assert + + +class type_assertTest(unittest.TestCase): + + @type_assert(object, list, (int, str)) + def a(self, a: List[int], b: Union[int, str]): + print(a, b) + + def test_type_assert(self): + with self.assertRaises(TypeError): + self.a([1], 2) + self.a(1, [123]) + + +if __name__ == '__main__': + unittest.main()